diff --git a/examples/example-app/main.go b/examples/example-app/main.go index 451bea5b..af566704 100644 --- a/examples/example-app/main.go +++ b/examples/example-app/main.go @@ -3,8 +3,11 @@ package main import ( "bytes" "context" + "crypto/rand" + "crypto/sha256" "crypto/tls" "crypto/x509" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -24,9 +27,20 @@ import ( const exampleAppState = "I wish to wash my irish wristwatch" +var ( + codeVerifier string + codeChallenge string +) + +func init() { + codeVerifier = generateCodeVerifier() + codeChallenge = generateCodeChallenge(codeVerifier) +} + type app struct { clientID string clientSecret string + pkce bool redirectURI string verifier *oidc.IDTokenVerifier @@ -193,6 +207,7 @@ func cmd() *cobra.Command { } c.Flags().StringVar(&a.clientID, "client-id", "example-app", "OAuth2 client ID of this application.") c.Flags().StringVar(&a.clientSecret, "client-secret", "ZXhhbXBsZS1hcHAtc2VjcmV0", "OAuth2 client secret of this application.") + c.Flags().BoolVar(&a.pkce, "pkce", true, "Use PKCE flow for the code exchange.") c.Flags().StringVar(&a.redirectURI, "redirect-uri", "http://127.0.0.1:5555/callback", "Callback URL for OAuth2 responses.") c.Flags().StringVar(&issuerURL, "issuer", "http://127.0.0.1:5556/dex", "URL of the OpenID Connect issuer.") c.Flags().StringVar(&listen, "listen", "http://127.0.0.1:5555", "HTTP(S) address to listen at.") @@ -243,14 +258,22 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { authCodeURL := "" scopes = append(scopes, "openid", "profile", "email") - if r.FormValue("offline_access") != "yes" { - authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState) - } else if a.offlineAsScope { + + var authCodeOptions []oauth2.AuthCodeOption + + if a.pkce { + authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_challenge", codeChallenge)) + authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_challenge_method", "S256")) + } + + a.oauth2Config(scopes) + if r.FormValue("offline_access") == "yes" { + authCodeOptions = append(authCodeOptions, oauth2.AccessTypeOffline) + } + if a.offlineAsScope { scopes = append(scopes, "offline_access") - authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState) - } else { - authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState, oauth2.AccessTypeOffline) } + authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState, authCodeOptions...) if connectorID != "" { authCodeURL = authCodeURL + "&connector_id=" + connectorID } @@ -282,7 +305,13 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) return } - token, err = oauth2Config.Exchange(ctx, code) + + var authCodeOptions []oauth2.AuthCodeOption + if a.pkce { + authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) + } + + token, err = oauth2Config.Exchange(ctx, code, authCodeOptions...) case http.MethodPost: // Form request from frontend to refresh a token. refresh := r.FormValue("refresh_token") @@ -337,3 +366,16 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { renderToken(w, a.redirectURI, rawIDToken, accessToken, token.RefreshToken, buff.String()) } + +func generateCodeVerifier() string { + bytes := make([]byte, 64) // 86 symbols Base64URL + if _, err := rand.Read(bytes); err != nil { + log.Fatalf("rand.Read error: %v", err) + } + return base64.RawURLEncoding.EncodeToString(bytes) +} + +func generateCodeChallenge(verifier string) string { + hash := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(hash[:]) +}