From 8db7699e0f5fbcd552fd6db671c54006c77ea8ce Mon Sep 17 00:00:00 2001 From: Maksim Nabokikh Date: Sun, 22 Feb 2026 23:06:03 +0100 Subject: [PATCH] feat: implement device code flow in example-app (#4570) This is a KubeCon 2026 preparation: 1. Add device flow to the example-app 2. Add userinfo checker 3. Refactor the structure Signed-off-by: maksim.nabokikh --- examples/config-dev.yaml | 1 + examples/example-app/handlers.go | 141 +++++++++++ examples/example-app/handlers_device.go | 273 +++++++++++++++++++++ examples/example-app/handlers_userinfo.go | 68 +++++ examples/example-app/main.go | 253 ++----------------- examples/example-app/static/app.js | 48 ++++ examples/example-app/static/device.js | 110 +++++++++ examples/example-app/static/style.css | 232 +++++++++++++++++ examples/example-app/static/token.js | 70 ++++++ examples/example-app/templates.go | 19 ++ examples/example-app/templates/device.html | 61 +++++ examples/example-app/templates/index.html | 9 +- examples/example-app/templates/token.html | 22 +- examples/example-app/utils.go | 154 ++++++++++++ 14 files changed, 1222 insertions(+), 239 deletions(-) create mode 100644 examples/example-app/handlers.go create mode 100644 examples/example-app/handlers_device.go create mode 100644 examples/example-app/handlers_userinfo.go create mode 100644 examples/example-app/static/device.js create mode 100644 examples/example-app/templates/device.html create mode 100644 examples/example-app/utils.go diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 7bb0f2eb..94a40bff 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -129,6 +129,7 @@ staticClients: - id: example-app redirectURIs: - 'http://127.0.0.1:5555/callback' + - '/dex/device/callback' name: 'Example App' secret: ZXhhbXBsZS1hcHAtc2VjcmV0 diff --git a/examples/example-app/handlers.go b/examples/example-app/handlers.go new file mode 100644 index 00000000..fce65d77 --- /dev/null +++ b/examples/example-app/handlers.go @@ -0,0 +1,141 @@ +package main + +import ( + "fmt" + "net/http" + "net/url" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +func (a *app) handleIndex(w http.ResponseWriter, r *http.Request) { + renderIndex(w, indexPageData{ + ScopesSupported: a.scopesSupported, + LogoURI: dexLogoDataURI, + }) +} + +func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, fmt.Sprintf("failed to parse form: %v", err), http.StatusBadRequest) + return + } + + // Only use scopes that are checked in the form + scopes := r.Form["extra_scopes"] + crossClients := r.Form["cross_client"] + + // Build complete scope list with audience scopes + scopes = buildScopes(scopes, crossClients) + + connectorID := "" + if id := r.FormValue("connector_id"); id != "" { + connectorID = id + } + + authCodeURL := "" + + var authCodeOptions []oauth2.AuthCodeOption + + if a.pkce { + authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_challenge", codeChallenge)) + authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_challenge_method", "S256")) + } + + // Check if offline_access scope is present to determine offline access mode + hasOfflineAccess := false + for _, scope := range scopes { + if scope == "offline_access" { + hasOfflineAccess = true + break + } + } + + if hasOfflineAccess && !a.offlineAsScope { + // Provider uses access_type=offline instead of offline_access scope + authCodeOptions = append(authCodeOptions, oauth2.AccessTypeOffline) + // Remove offline_access from scopes as it's not supported + filteredScopes := make([]string, 0, len(scopes)) + for _, scope := range scopes { + if scope != "offline_access" { + filteredScopes = append(filteredScopes, scope) + } + } + scopes = filteredScopes + } + + authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState, authCodeOptions...) + + // Parse the auth code URL and safely add connector_id parameter if provided + u, err := url.Parse(authCodeURL) + if err != nil { + http.Error(w, "Failed to parse auth URL", http.StatusInternalServerError) + return + } + + if connectorID != "" { + query := u.Query() + query.Set("connector_id", connectorID) + u.RawQuery = query.Encode() + } + + http.Redirect(w, r, u.String(), http.StatusSeeOther) +} + +func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { + var ( + err error + token *oauth2.Token + ) + + ctx := oidc.ClientContext(r.Context(), a.client) + oauth2Config := a.oauth2Config(nil) + switch r.Method { + case http.MethodGet: + // Authorization redirect callback from OAuth2 auth flow. + if errMsg := r.FormValue("error"); errMsg != "" { + http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) + return + } + code := r.FormValue("code") + if code == "" { + http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest) + return + } + if state := r.FormValue("state"); state != exampleAppState { + http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) + return + } + + var authCodeOptions []oauth2.AuthCodeOption + if a.pkce { + authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) + } + + token, err = oauth2Config.Exchange(ctx, code, authCodeOptions...) + case http.MethodPost: + // Form request from frontend to refresh a token. + refresh := r.FormValue("refresh_token") + if refresh == "" { + http.Error(w, fmt.Sprintf("no refresh_token in request: %q", r.Form), http.StatusBadRequest) + return + } + t := &oauth2.Token{ + RefreshToken: refresh, + Expiry: time.Now().Add(-time.Hour), + } + token, err = oauth2Config.TokenSource(ctx, t).Token() + default: + http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) + return + } + + if err != nil { + http.Error(w, fmt.Sprintf("failed to get token: %v", err), http.StatusInternalServerError) + return + } + + parseAndRenderToken(w, r, a, token) +} diff --git a/examples/example-app/handlers_device.go b/examples/example-app/handlers_device.go new file mode 100644 index 00000000..40209ca2 --- /dev/null +++ b/examples/example-app/handlers_device.go @@ -0,0 +1,273 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "golang.org/x/oauth2" +) + +func (a *app) handleDeviceLogin(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse request body to get options + var reqBody struct { + Scopes []string `json:"scopes"` + CrossClients []string `json:"cross_clients"` + ConnectorID string `json:"connector_id"` + } + + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, fmt.Sprintf("failed to parse request body: %v", err), http.StatusBadRequest) + return + } + + // Build complete scope list with audience scopes (same as handleLogin) + scopes := buildScopes(reqBody.Scopes, reqBody.CrossClients) + + // Build scope string + scopeStr := strings.Join(scopes, " ") + + // Get device authorization endpoint + // Properly construct the device code endpoint URL + authURL := a.provider.Endpoint().AuthURL + deviceAuthURL := strings.TrimSuffix(authURL, "/auth") + "/device/code" + + // Request device code + data := url.Values{} + data.Set("client_id", a.clientID) + data.Set("client_secret", a.clientSecret) + data.Set("scope", scopeStr) + + // Add connector_id if specified + if reqBody.ConnectorID != "" { + data.Set("connector_id", reqBody.ConnectorID) + } + + resp, err := a.client.PostForm(deviceAuthURL, data) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to request device code: %v", err), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body := new(bytes.Buffer) + body.ReadFrom(resp.Body) + http.Error(w, fmt.Sprintf("Device code request failed: %s", body.String()), resp.StatusCode) + return + } + + var deviceResp struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` + } + + if err := json.NewDecoder(resp.Body).Decode(&deviceResp); err != nil { + http.Error(w, fmt.Sprintf("Failed to decode device response: %v", err), http.StatusInternalServerError) + return + } + + // Store device flow data with new session + sessionID := generateSessionID() + + a.deviceFlowMutex.Lock() + a.deviceFlowData.sessionID = sessionID + a.deviceFlowData.deviceCode = deviceResp.DeviceCode + a.deviceFlowData.userCode = deviceResp.UserCode + a.deviceFlowData.verificationURI = deviceResp.VerificationURI + a.deviceFlowData.pollInterval = deviceResp.Interval + if a.deviceFlowData.pollInterval == 0 { + a.deviceFlowData.pollInterval = 5 + } + a.deviceFlowData.token = nil + a.deviceFlowMutex.Unlock() + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "ok", + "session_id": sessionID, + }) +} + +func (a *app) handleDevicePage(w http.ResponseWriter, r *http.Request) { + a.deviceFlowMutex.Lock() + data := devicePageData{ + SessionID: a.deviceFlowData.sessionID, + DeviceCode: a.deviceFlowData.deviceCode, + UserCode: a.deviceFlowData.userCode, + VerificationURI: a.deviceFlowData.verificationURI, + PollInterval: a.deviceFlowData.pollInterval, + LogoURI: dexLogoDataURI, + } + a.deviceFlowMutex.Unlock() + + if data.DeviceCode == "" { + http.Error(w, "No device flow in progress", http.StatusBadRequest) + return + } + + renderDevice(w, data) +} + +func (a *app) handleDevicePoll(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + DeviceCode string `json:"device_code"` + SessionID string `json:"session_id"` + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + a.deviceFlowMutex.Lock() + storedSessionID := a.deviceFlowData.sessionID + storedDeviceCode := a.deviceFlowData.deviceCode + existingToken := a.deviceFlowData.token + a.deviceFlowMutex.Unlock() + + // Check if this session has been superseded by a new one + if req.SessionID != storedSessionID { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusGone) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": "session_expired", + "error_description": "This device flow session has been superseded by a new one", + }) + return + } + + if req.DeviceCode != storedDeviceCode { + http.Error(w, "Invalid device code", http.StatusBadRequest) + return + } + + // If we already have a token, return success + if existingToken != nil { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "status": "complete", + }) + return + } + + // Poll the token endpoint + tokenURL := a.provider.Endpoint().TokenURL + + data := url.Values{} + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code") + data.Set("device_code", req.DeviceCode) + data.Set("client_id", a.clientID) + data.Set("client_secret", a.clientSecret) + + tokenResp, err := a.client.PostForm(tokenURL, data) + if err != nil { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "status": "pending", + }) + return + } + defer tokenResp.Body.Close() + + if tokenResp.StatusCode == http.StatusOK { + // Success! We got the token + // Parse the full response including id_token + var tokenData struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + IDToken string `json:"id_token"` + } + + if err := json.NewDecoder(tokenResp.Body).Decode(&tokenData); err != nil { + http.Error(w, "Failed to decode token", http.StatusInternalServerError) + return + } + + // Create oauth2.Token with all fields + token := &oauth2.Token{ + AccessToken: tokenData.AccessToken, + TokenType: tokenData.TokenType, + RefreshToken: tokenData.RefreshToken, + } + + // Add id_token to Extra + token = token.WithExtra(map[string]interface{}{ + "id_token": tokenData.IDToken, + }) + + // Store the token + a.deviceFlowMutex.Lock() + a.deviceFlowData.token = token + a.deviceFlowMutex.Unlock() + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "status": "complete", + }) + return + } + + // Check for errors + var errorResp struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + + if err := json.NewDecoder(tokenResp.Body).Decode(&errorResp); err == nil { + if errorResp.Error == "authorization_pending" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "status": "pending", + }) + return + } + + // Other errors + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(tokenResp.StatusCode) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": errorResp.Error, + "error_description": errorResp.ErrorDescription, + }) + return + } + + // Unknown response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "status": "pending", + }) +} + +func (a *app) handleDeviceResult(w http.ResponseWriter, r *http.Request) { + a.deviceFlowMutex.Lock() + token := a.deviceFlowData.token + a.deviceFlowMutex.Unlock() + + if token == nil { + http.Error(w, "No token available", http.StatusBadRequest) + return + } + + parseAndRenderToken(w, r, a, token) +} diff --git a/examples/example-app/handlers_userinfo.go b/examples/example-app/handlers_userinfo.go new file mode 100644 index 00000000..36bab851 --- /dev/null +++ b/examples/example-app/handlers_userinfo.go @@ -0,0 +1,68 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "net/http" +) + +func (a *app) handleUserInfo(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse form to get access token + if err := r.ParseForm(); err != nil { + http.Error(w, fmt.Sprintf("Failed to parse form: %v", err), http.StatusBadRequest) + return + } + + accessToken := r.FormValue("access_token") + if accessToken == "" { + http.Error(w, "access_token is required", http.StatusBadRequest) + return + } + + // Get UserInfo endpoint from provider + userInfoEndpoint := a.provider.Endpoint().AuthURL + if len(userInfoEndpoint) > 5 { + // Replace /auth with /userinfo + userInfoEndpoint = userInfoEndpoint[:len(userInfoEndpoint)-5] + "/userinfo" + } + + // Create request to UserInfo endpoint + req, err := http.NewRequestWithContext(r.Context(), "GET", userInfoEndpoint, nil) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create request: %v", err), http.StatusInternalServerError) + return + } + + // Add Authorization header with access token + req.Header.Set("Authorization", "Bearer "+accessToken) + + // Make the request + resp, err := a.client.Do(req) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to fetch userinfo: %v", err), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + http.Error(w, fmt.Sprintf("UserInfo request failed: %s", string(body)), resp.StatusCode) + return + } + + // Parse and return the userinfo + var userInfo map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + http.Error(w, fmt.Sprintf("Failed to decode userinfo: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(userInfo) +} diff --git a/examples/example-app/main.go b/examples/example-app/main.go index c1b0e968..389d2ff0 100644 --- a/examples/example-app/main.go +++ b/examples/example-app/main.go @@ -1,21 +1,14 @@ package main import ( - "bytes" "context" - "crypto/tls" - "crypto/x509" - "encoding/json" "errors" "fmt" "log" - "net" "net/http" - "net/http/httputil" "net/url" "os" - "slices" - "time" + "sync" "github.com/coreos/go-oidc/v3/oidc" "github.com/spf13/cobra" @@ -49,55 +42,19 @@ type app struct { offlineAsScope bool client *http.Client -} - -// return an HTTP client which trusts the provided root CAs. -func httpClientForRootCAs(rootCAs string) (*http.Client, error) { - tlsConfig := tls.Config{RootCAs: x509.NewCertPool()} - rootCABytes, err := os.ReadFile(rootCAs) - if err != nil { - return nil, fmt.Errorf("failed to read root-ca: %v", err) - } - if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) { - return nil, fmt.Errorf("no certs found in root CA file %q", rootCAs) - } - return &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tlsConfig, - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - }, - }, nil -} - -type debugTransport struct { - t http.RoundTripper -} -func (d debugTransport) RoundTrip(req *http.Request) (*http.Response, error) { - reqDump, err := httputil.DumpRequest(req, true) - if err != nil { - return nil, err + // Device flow state + // Only one session is possible at a time + // Since it is an example, we don't bother locking', this is a simplicity tradeoff + deviceFlowMutex sync.Mutex + deviceFlowData struct { + sessionID string // Unique ID for current flow session + deviceCode string + userCode string + verificationURI string + pollInterval int + token *oauth2.Token } - log.Printf("%s", reqDump) - - resp, err := d.t.RoundTrip(req) - if err != nil { - return nil, err - } - - respDump, err := httputil.DumpResponse(resp, true) - if err != nil { - resp.Body.Close() - return nil, err - } - log.Printf("%s", respDump) - return resp, nil } func cmd() *cobra.Command { @@ -191,6 +148,11 @@ func cmd() *cobra.Command { http.Handle("/static/", http.StripPrefix("/static/", staticHandler)) http.HandleFunc("/", a.handleIndex) http.HandleFunc("/login", a.handleLogin) + http.HandleFunc("/device/login", a.handleDeviceLogin) + http.HandleFunc("/device", a.handleDevicePage) + http.HandleFunc("/device/poll", a.handleDevicePoll) + http.HandleFunc("/device/result", a.handleDeviceResult) + http.HandleFunc("/userinfo", a.handleUserInfo) http.HandleFunc(u.Path, a.handleCallback) switch listenURL.Scheme { @@ -224,184 +186,3 @@ func main() { os.Exit(2) } } - -func (a *app) handleIndex(w http.ResponseWriter, r *http.Request) { - renderIndex(w, indexPageData{ - ScopesSupported: a.scopesSupported, - LogoURI: dexLogoDataURI, - }) -} - -func (a *app) oauth2Config(scopes []string) *oauth2.Config { - return &oauth2.Config{ - ClientID: a.clientID, - ClientSecret: a.clientSecret, - Endpoint: a.provider.Endpoint(), - Scopes: scopes, - RedirectURL: a.redirectURI, - } -} - -func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { - if err := r.ParseForm(); err != nil { - http.Error(w, fmt.Sprintf("failed to parse form: %v", err), http.StatusBadRequest) - return - } - - // Only use scopes that are checked in the form - scopes := r.Form["extra_scopes"] - - clients := r.Form["cross_client"] - for _, client := range clients { - if client == "" { - continue - } - scopes = append(scopes, "audience:server:client_id:"+client) - } - connectorID := "" - if id := r.FormValue("connector_id"); id != "" { - connectorID = id - } - - authCodeURL := "" - scopes = uniqueStrings(scopes) - - var authCodeOptions []oauth2.AuthCodeOption - - if a.pkce { - authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_challenge", codeChallenge)) - authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_challenge_method", "S256")) - } - - // Check if offline_access scope is present to determine offline access mode - hasOfflineAccess := false - for _, scope := range scopes { - if scope == "offline_access" { - hasOfflineAccess = true - break - } - } - - if hasOfflineAccess && !a.offlineAsScope { - // Provider uses access_type=offline instead of offline_access scope - authCodeOptions = append(authCodeOptions, oauth2.AccessTypeOffline) - // Remove offline_access from scopes as it's not supported - filteredScopes := make([]string, 0, len(scopes)) - for _, scope := range scopes { - if scope != "offline_access" { - filteredScopes = append(filteredScopes, scope) - } - } - scopes = filteredScopes - } - - authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState, authCodeOptions...) - - // Parse the auth code URL and safely add connector_id parameter if provided - u, err := url.Parse(authCodeURL) - if err != nil { - http.Error(w, "Failed to parse auth URL", http.StatusInternalServerError) - return - } - - if connectorID != "" { - query := u.Query() - query.Set("connector_id", connectorID) - u.RawQuery = query.Encode() - } - - http.Redirect(w, r, u.String(), http.StatusSeeOther) -} - -func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { - var ( - err error - token *oauth2.Token - ) - - ctx := oidc.ClientContext(r.Context(), a.client) - oauth2Config := a.oauth2Config(nil) - switch r.Method { - case http.MethodGet: - // Authorization redirect callback from OAuth2 auth flow. - if errMsg := r.FormValue("error"); errMsg != "" { - http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) - return - } - code := r.FormValue("code") - if code == "" { - http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest) - return - } - if state := r.FormValue("state"); state != exampleAppState { - http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) - return - } - - var authCodeOptions []oauth2.AuthCodeOption - if a.pkce { - authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) - } - - token, err = oauth2Config.Exchange(ctx, code, authCodeOptions...) - case http.MethodPost: - // Form request from frontend to refresh a token. - refresh := r.FormValue("refresh_token") - if refresh == "" { - http.Error(w, fmt.Sprintf("no refresh_token in request: %q", r.Form), http.StatusBadRequest) - return - } - t := &oauth2.Token{ - RefreshToken: refresh, - Expiry: time.Now().Add(-time.Hour), - } - token, err = oauth2Config.TokenSource(ctx, t).Token() - default: - http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) - return - } - - if err != nil { - http.Error(w, fmt.Sprintf("failed to get token: %v", err), http.StatusInternalServerError) - return - } - - rawIDToken, ok := token.Extra("id_token").(string) - if !ok { - http.Error(w, "no id_token in token response", http.StatusInternalServerError) - return - } - - idToken, err := a.verifier.Verify(r.Context(), rawIDToken) - if err != nil { - http.Error(w, fmt.Sprintf("failed to verify ID token: %v", err), http.StatusInternalServerError) - return - } - - accessToken, ok := token.Extra("access_token").(string) - if !ok { - http.Error(w, "no access_token in token response", http.StatusInternalServerError) - return - } - - var claims json.RawMessage - if err := idToken.Claims(&claims); err != nil { - http.Error(w, fmt.Sprintf("error decoding ID token claims: %v", err), http.StatusInternalServerError) - return - } - - buff := new(bytes.Buffer) - if err := json.Indent(buff, claims, "", " "); err != nil { - http.Error(w, fmt.Sprintf("error indenting ID token claims: %v", err), http.StatusInternalServerError) - return - } - - renderToken(w, r.Context(), a.provider, a.redirectURI, rawIDToken, accessToken, token.RefreshToken, buff.String()) -} - -func uniqueStrings(values []string) []string { - slices.Sort(values) - values = slices.Compact(values) - - return values -} diff --git a/examples/example-app/static/app.js b/examples/example-app/static/app.js index fc2c350d..14a8a80a 100644 --- a/examples/example-app/static/app.js +++ b/examples/example-app/static/app.js @@ -102,5 +102,53 @@ customScopeInput.value = ""; } }); + + // Device Grant Login Handler + const deviceGrantBtn = document.getElementById("device-grant-btn"); + deviceGrantBtn?.addEventListener("click", async () => { + deviceGrantBtn.disabled = true; + deviceGrantBtn.textContent = "Loading..."; + + try { + // Collect form data similar to regular login + const form = document.getElementById("login-form"); + const formData = new FormData(form); + + // Get selected scopes + const scopes = formData.getAll("extra_scopes"); + + // Get cross-client values + const crossClients = formData.getAll("cross_client"); + + // Get connector_id if specified + const connectorId = formData.get("connector_id") || ""; + + // Initiate device flow with options + const response = await fetch('/device/login', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + scopes: scopes, + cross_clients: crossClients, + connector_id: connectorId + }) + }); + + if (response.ok) { + // Redirect to device flow page + window.location.href = '/device'; + } else { + const errorText = await response.text(); + alert('Failed to start device flow: ' + errorText); + } + } catch (error) { + alert('Error starting device flow: ' + error.message); + } finally { + deviceGrantBtn.disabled = false; + deviceGrantBtn.textContent = "Device Code Flow"; + } + }); })(); diff --git a/examples/example-app/static/device.js b/examples/example-app/static/device.js new file mode 100644 index 00000000..5896682a --- /dev/null +++ b/examples/example-app/static/device.js @@ -0,0 +1,110 @@ +(function() { + const sessionID = document.getElementById("session-id")?.value; + const deviceCode = document.getElementById("device-code")?.value; + const pollInterval = parseInt(document.getElementById("poll-interval")?.value || "5", 10); + const verificationURL = document.getElementById("verification-url")?.textContent; + const userCode = document.getElementById("user-code")?.textContent; + const statusText = document.getElementById("status-text"); + const errorMessage = document.getElementById("error-message"); + const openAuthBtn = document.getElementById("open-auth-btn"); + + let pollTimer = null; + + document.querySelectorAll(".copy-btn").forEach(btn => { + btn.addEventListener("click", async function() { + const targetId = this.getAttribute("data-copy"); + const targetElement = document.getElementById(targetId); + + if (targetElement) { + const textToCopy = targetElement.textContent; + + try { + await navigator.clipboard.writeText(textToCopy); + const originalText = this.textContent; + this.textContent = "✓"; + setTimeout(() => { + this.textContent = originalText; + }, 2000); + } catch (err) { + console.error('Failed to copy:', err); + } + } + }); + }); + + openAuthBtn?.addEventListener("click", () => { + if (verificationURL && userCode) { + const url = verificationURL + "?user_code=" + encodeURIComponent(userCode); + window.open(url, "_blank", "width=600,height=800"); + } + }); + + async function pollForToken() { + try { + const response = await fetch('/device/poll', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + session_id: sessionID, + device_code: deviceCode + }) + }); + + const data = await response.json(); + + if (response.ok && data.status === 'complete') { + statusText.textContent = "Authentication successful! Redirecting..."; + stopPolling(); + window.location.href = '/device/result'; + } else if (response.ok && data.status === 'pending') { + statusText.textContent = "Waiting for authentication..."; + } else { + const errorText = data.error_description || data.error || 'Unknown error'; + + if (data.error === 'session_expired') { + showError('This session has been superseded by a new device flow. Please start over.'); + stopPolling(); + } else if (data.error === 'expired_token' || data.error === 'access_denied') { + showError(data.error === 'expired_token' ? + 'The device code has expired. Please start over.' : + 'Authentication was denied.'); + stopPolling(); + } + } + } catch (error) { + console.error('Polling error:', error); + } + } + + function showError(message) { + errorMessage.textContent = message; + errorMessage.style.display = 'block'; + + // Hide the status indicator (contains spinner and status text) + const statusIndicator = document.querySelector('.status-indicator'); + if (statusIndicator) { + statusIndicator.style.display = 'none'; + } + } + + function startPolling() { + pollForToken(); + pollTimer = setInterval(pollForToken, pollInterval * 1000); + } + + function stopPolling() { + if (pollTimer) { + clearInterval(pollTimer); + pollTimer = null; + } + } + + if (deviceCode) { + startPolling(); + } + + window.addEventListener('beforeunload', stopPolling); +})(); + diff --git a/examples/example-app/static/style.css b/examples/example-app/static/style.css index def84224..fafca567 100644 --- a/examples/example-app/static/style.css +++ b/examples/example-app/static/style.css @@ -355,3 +355,235 @@ pre .number { color: #00f; } +/* Login Buttons Styles */ +.login-buttons { + display: flex; + flex-direction: column; + gap: 12px; + margin-bottom: 20px; +} + +.login-button { + width: 100%; + padding: 14px 24px; + font-size: 16px; + border-radius: 4px; + cursor: pointer; + transition: all 0.3s ease; + font-weight: 600; + border: 2px solid #3F9FD8; +} + +.login-button.primary { + background-color: #3F9FD8; + color: #fff; +} + +.login-button.primary:hover { + background-color: #357FAA; + border-color: #357FAA; +} + +.login-button.secondary { + background-color: #fff; + color: #3F9FD8; +} + +.login-button.secondary:hover { + background-color: #f0f8ff; +} + +/* Device Flow Page Styles */ +.device-flow-container { + background-color: #fff; + padding: 30px; + border-radius: 8px; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); + width: 100%; +} + +.device-instructions h2 { + margin-top: 0; + color: #333; + text-align: center; +} + +.instruction-text { + text-align: center; + color: #666; + margin-bottom: 25px; +} + +.verification-info { + display: flex; + flex-direction: column; + gap: 20px; + margin-bottom: 25px; +} + +.info-item { + display: flex; + flex-direction: column; + gap: 8px; +} + +.info-item label { + font-weight: 600; + color: #333; + font-size: 14px; +} + +.code-display { + display: flex; + align-items: center; + gap: 10px; + background-color: #f5f5f5; + padding: 12px; + border-radius: 4px; + border: 1px solid #ddd; +} + +.code-display.large { + padding: 20px; +} + +.code-display code { + flex: 1; + font-family: 'Courier New', Courier, monospace; + font-size: 14px; + word-break: break-all; +} + +.code-display code.user-code { + font-size: 24px; + font-weight: bold; + letter-spacing: 2px; + color: #3F9FD8; +} + +.copy-btn { + background: none; + border: none; + cursor: pointer; + font-size: 18px; + padding: 5px 10px; + border-radius: 4px; + transition: background-color 0.2s; +} + +.copy-btn:hover { + background-color: #e0e0e0; +} + +.copy-btn:active { + background-color: #d0d0d0; +} + +.actions { + text-align: center; +} + +.primary-button { + padding: 12px 32px; + font-size: 16px; + background-color: #3F9FD8; + color: #fff; + border: none; + border-radius: 4px; + cursor: pointer; + font-weight: 600; + transition: background-color 0.3s; +} + +.primary-button:hover { + background-color: #357FAA; +} + +.polling-status { + margin-top: 30px; + padding-top: 30px; + border-top: 1px solid #eee; +} + +.status-indicator { + display: flex; + align-items: center; + justify-content: center; + gap: 15px; + color: #666; +} + +.spinner { + width: 20px; + height: 20px; + border: 3px solid #f3f3f3; + border-top: 3px solid #3F9FD8; + border-radius: 50%; + animation: spin 1s linear infinite; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +.error-message { + margin-top: 15px; + padding: 12px; + background-color: #fee; + border: 1px solid #fcc; + border-radius: 4px; + color: #c00; + text-align: center; +} + +.device-data { + display: none; +} + +/* UserInfo Styles */ +#userinfo-section { + margin-top: 10px; +} + +.fetch-userinfo-btn { + padding: 10px 20px; + background-color: #3F9FD8; + color: white; + border: none; + border-radius: 4px; + cursor: pointer; + font-size: 14px; + font-weight: 500; + transition: background-color 0.2s; +} + +.fetch-userinfo-btn:hover { + background-color: #357FAA; +} + +.userinfo-loading { + display: flex; + align-items: center; + gap: 10px; + color: #666; + margin-top: 10px; +} + +.userinfo-loading .spinner { + width: 16px; + height: 16px; + border: 2px solid #f3f3f3; + border-top: 2px solid #3F9FD8; + border-radius: 50%; + animation: spin 1s linear infinite; +} + +#userinfo-claims { + margin-top: 15px; +} + +#userinfo-error { + margin-top: 10px; +} + diff --git a/examples/example-app/static/token.js b/examples/example-app/static/token.js index c147ce32..6d410a85 100644 --- a/examples/example-app/static/token.js +++ b/examples/example-app/static/token.js @@ -80,3 +80,73 @@ function showCopyFeedback(message) { }, 2000); } +// UserInfo functionality +document.addEventListener("DOMContentLoaded", function() { + const form = document.getElementById("userinfo-form"); + if (form) { + form.addEventListener("submit", fetchUserInfo); + } +}); + +async function fetchUserInfo(event) { + event.preventDefault(); + + const form = event.target; + const loading = document.getElementById("userinfo-loading"); + const error = document.getElementById("userinfo-error"); + const claimsElement = document.getElementById("userinfo-claims"); + const submitButton = form.querySelector('button[type="submit"]'); + + // Hide error and claims from previous attempts + error.style.display = "none"; + claimsElement.style.display = "none"; + + // Show loading, hide button + submitButton.style.display = "none"; + loading.style.display = "flex"; + + try { + const formData = new FormData(form); + + // Convert FormData to URL-encoded string + const urlEncodedData = new URLSearchParams(formData).toString(); + + const response = await fetch("/userinfo", { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded" + }, + body: urlEncodedData + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(errorText || `HTTP ${response.status}`); + } + + const userinfo = await response.json(); + + // Display the userinfo claims + const code = claimsElement.querySelector("code"); + const formattedJson = JSON.stringify(userinfo, null, 2); + code.textContent = formattedJson; + + // Apply syntax highlighting + try { + code.innerHTML = syntaxHighlight(userinfo); + } catch (e) { + console.error("Failed to highlight JSON:", e); + } + + claimsElement.style.display = "block"; + + } catch (err) { + console.error("Failed to fetch userinfo:", err); + error.textContent = "Failed to fetch UserInfo: " + err.message; + error.style.display = "block"; + submitButton.style.display = "inline-block"; + } finally { + loading.style.display = "none"; + } +} + diff --git a/examples/example-app/templates.go b/examples/example-app/templates.go index 67ff709d..b3ae7396 100644 --- a/examples/example-app/templates.go +++ b/examples/example-app/templates.go @@ -29,6 +29,7 @@ const dexLogoDataURI = "/static/dex-glyph-color.svg" var ( indexTmpl *template.Template tokenTmpl *template.Template + deviceTmpl *template.Template staticHandler http.Handler ) @@ -44,6 +45,11 @@ func init() { log.Fatalf("failed to parse token template: %v", err) } + deviceTmpl, err = template.ParseFS(templatesFS, "templates/device.html") + if err != nil { + log.Fatalf("failed to parse device template: %v", err) + } + // Create handler for static files staticSubFS, err := fs.Sub(staticFS, "static") if err != nil { @@ -56,11 +62,24 @@ func renderIndex(w http.ResponseWriter, data indexPageData) { renderTemplate(w, indexTmpl, data) } +func renderDevice(w http.ResponseWriter, data devicePageData) { + renderTemplate(w, deviceTmpl, data) +} + type indexPageData struct { ScopesSupported []string LogoURI string } +type devicePageData struct { + SessionID string + DeviceCode string + UserCode string + VerificationURI string + PollInterval int + LogoURI string +} + type tokenTmplData struct { IDToken string IDTokenJWTLink string diff --git a/examples/example-app/templates/device.html b/examples/example-app/templates/device.html new file mode 100644 index 00000000..092ec371 --- /dev/null +++ b/examples/example-app/templates/device.html @@ -0,0 +1,61 @@ + + + + + + Device Login - Example App + + + +
+ +
+
+

Device Login

+

Please authenticate on your device:

+ +
+
+ +
+ {{.VerificationURI}} + +
+
+ +
+ +
+ {{.UserCode}} + +
+
+
+ +
+ +
+
+ +
+
+
+ Waiting for authentication... +
+ +
+ + +
+
+ + + + + diff --git a/examples/example-app/templates/index.html b/examples/example-app/templates/index.html index 494920d6..063b154f 100644 --- a/examples/example-app/templates/index.html +++ b/examples/example-app/templates/index.html @@ -14,8 +14,13 @@ This is an example application for Dex OpenID Connect provider.
Learn more in the documentation. -
- +
Advanced options diff --git a/examples/example-app/templates/token.html b/examples/example-app/templates/token.html index f830e16b..b003deeb 100644 --- a/examples/example-app/templates/token.html +++ b/examples/example-app/templates/token.html @@ -29,11 +29,31 @@ {{ if .Claims }}
-
Claims:
+
ID Token Claims:
{{ .Claims }}
{{ end }} + {{ if .AccessToken }} +
+
UserInfo:
+
+
+ + +
+ + + +
+
+ {{ end }} + {{ if .RefreshToken }}
Refresh Token:
diff --git a/examples/example-app/utils.go b/examples/example-app/utils.go new file mode 100644 index 00000000..099c1062 --- /dev/null +++ b/examples/example-app/utils.go @@ -0,0 +1,154 @@ +package main + +import ( + "bytes" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "encoding/json" + "fmt" + "log" + "net" + "net/http" + "net/http/httputil" + "os" + "slices" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +// generateSessionID creates a random session identifier +func generateSessionID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + // Fallback to timestamp if random fails + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + return hex.EncodeToString(b) +} + +// buildScopes constructs a scope list from base scopes and cross-client IDs +func buildScopes(baseScopes []string, crossClients []string) []string { + scopes := make([]string, len(baseScopes)) + copy(scopes, baseScopes) + + // Add audience scopes for cross-client authorization + for _, client := range crossClients { + if client != "" { + scopes = append(scopes, "audience:server:client_id:"+client) + } + } + + return uniqueStrings(scopes) +} + +func (a *app) oauth2Config(scopes []string) *oauth2.Config { + return &oauth2.Config{ + ClientID: a.clientID, + ClientSecret: a.clientSecret, + Endpoint: a.provider.Endpoint(), + Scopes: scopes, + RedirectURL: a.redirectURI, + } +} + +func uniqueStrings(values []string) []string { + slices.Sort(values) + values = slices.Compact(values) + return values +} + +// return an HTTP client which trusts the provided root CAs. +func httpClientForRootCAs(rootCAs string) (*http.Client, error) { + tlsConfig := tls.Config{RootCAs: x509.NewCertPool()} + rootCABytes, err := os.ReadFile(rootCAs) + if err != nil { + return nil, fmt.Errorf("failed to read root-ca: %v", err) + } + if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) { + return nil, fmt.Errorf("no certs found in root CA file %q", rootCAs) + } + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tlsConfig, + Proxy: http.ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + }, nil +} + +type debugTransport struct { + t http.RoundTripper +} + +func (d debugTransport) RoundTrip(req *http.Request) (*http.Response, error) { + reqDump, err := httputil.DumpRequest(req, true) + if err != nil { + return nil, err + } + log.Printf("%s", reqDump) + + resp, err := d.t.RoundTrip(req) + if err != nil { + return nil, err + } + + respDump, err := httputil.DumpResponse(resp, true) + if err != nil { + resp.Body.Close() + return nil, err + } + log.Printf("%s", respDump) + return resp, nil +} + +func encodeToken(idToken *oidc.IDToken) (string, error) { + var claims json.RawMessage + if err := idToken.Claims(&claims); err != nil { + return "", fmt.Errorf("error decoding ID token claims: %v", err) + } + + buff := new(bytes.Buffer) + if err := json.Indent(buff, claims, "", " "); err != nil { + return "", fmt.Errorf("error indenting ID token claims: %v", err) + } + return buff.String(), nil +} + +func parseAndRenderToken(w http.ResponseWriter, r *http.Request, a *app, token *oauth2.Token) { + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + http.Error(w, "no id_token in token response", http.StatusInternalServerError) + return + } + + idToken, err := a.verifier.Verify(r.Context(), rawIDToken) + if err != nil { + http.Error(w, fmt.Sprintf("failed to verify ID token: %v", err), http.StatusInternalServerError) + return + } + + accessToken, ok := token.Extra("access_token").(string) + if !ok { + accessToken = token.AccessToken + if accessToken == "" { + http.Error(w, "no access_token in token response", http.StatusInternalServerError) + return + } + } + + buf, err := encodeToken(idToken) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + renderToken(w, r.Context(), a.provider, a.redirectURI, rawIDToken, accessToken, token.RefreshToken, buf) +}