|
|
|
|
@ -3,8 +3,11 @@ package main
|
|
|
|
|
import ( |
|
|
|
|
"bytes" |
|
|
|
|
"context" |
|
|
|
|
"crypto/rand" |
|
|
|
|
"crypto/sha256" |
|
|
|
|
"crypto/tls" |
|
|
|
|
"crypto/x509" |
|
|
|
|
"encoding/base64" |
|
|
|
|
"encoding/json" |
|
|
|
|
"errors" |
|
|
|
|
"fmt" |
|
|
|
|
@ -24,9 +27,20 @@ import (
|
|
|
|
|
|
|
|
|
|
const exampleAppState = "I wish to wash my irish wristwatch" |
|
|
|
|
|
|
|
|
|
var ( |
|
|
|
|
codeVerifier string |
|
|
|
|
codeChallenge string |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
func init() { |
|
|
|
|
codeVerifier = generateCodeVerifier() |
|
|
|
|
codeChallenge = generateCodeChallenge(codeVerifier) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type app struct { |
|
|
|
|
clientID string |
|
|
|
|
clientSecret string |
|
|
|
|
pkce bool |
|
|
|
|
redirectURI string |
|
|
|
|
|
|
|
|
|
verifier *oidc.IDTokenVerifier |
|
|
|
|
@ -193,6 +207,7 @@ func cmd() *cobra.Command {
|
|
|
|
|
} |
|
|
|
|
c.Flags().StringVar(&a.clientID, "client-id", "example-app", "OAuth2 client ID of this application.") |
|
|
|
|
c.Flags().StringVar(&a.clientSecret, "client-secret", "ZXhhbXBsZS1hcHAtc2VjcmV0", "OAuth2 client secret of this application.") |
|
|
|
|
c.Flags().BoolVar(&a.pkce, "pkce", true, "Use PKCE flow for the code exchange.") |
|
|
|
|
c.Flags().StringVar(&a.redirectURI, "redirect-uri", "http://127.0.0.1:5555/callback", "Callback URL for OAuth2 responses.") |
|
|
|
|
c.Flags().StringVar(&issuerURL, "issuer", "http://127.0.0.1:5556/dex", "URL of the OpenID Connect issuer.") |
|
|
|
|
c.Flags().StringVar(&listen, "listen", "http://127.0.0.1:5555", "HTTP(S) address to listen at.") |
|
|
|
|
@ -243,14 +258,22 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
|
|
|
|
|
authCodeURL := "" |
|
|
|
|
scopes = append(scopes, "openid", "profile", "email") |
|
|
|
|
if r.FormValue("offline_access") != "yes" { |
|
|
|
|
authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState) |
|
|
|
|
} else if a.offlineAsScope { |
|
|
|
|
|
|
|
|
|
var authCodeOptions []oauth2.AuthCodeOption |
|
|
|
|
|
|
|
|
|
if a.pkce { |
|
|
|
|
authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_challenge", codeChallenge)) |
|
|
|
|
authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_challenge_method", "S256")) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
a.oauth2Config(scopes) |
|
|
|
|
if r.FormValue("offline_access") == "yes" { |
|
|
|
|
authCodeOptions = append(authCodeOptions, oauth2.AccessTypeOffline) |
|
|
|
|
} |
|
|
|
|
if a.offlineAsScope { |
|
|
|
|
scopes = append(scopes, "offline_access") |
|
|
|
|
authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState) |
|
|
|
|
} else { |
|
|
|
|
authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState, oauth2.AccessTypeOffline) |
|
|
|
|
} |
|
|
|
|
authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState, authCodeOptions...) |
|
|
|
|
if connectorID != "" { |
|
|
|
|
authCodeURL = authCodeURL + "&connector_id=" + connectorID |
|
|
|
|
} |
|
|
|
|
@ -282,7 +305,13 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
token, err = oauth2Config.Exchange(ctx, code) |
|
|
|
|
|
|
|
|
|
var authCodeOptions []oauth2.AuthCodeOption |
|
|
|
|
if a.pkce { |
|
|
|
|
authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
token, err = oauth2Config.Exchange(ctx, code, authCodeOptions...) |
|
|
|
|
case http.MethodPost: |
|
|
|
|
// Form request from frontend to refresh a token.
|
|
|
|
|
refresh := r.FormValue("refresh_token") |
|
|
|
|
@ -337,3 +366,16 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
|
|
|
|
|
renderToken(w, a.redirectURI, rawIDToken, accessToken, token.RefreshToken, buff.String()) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func generateCodeVerifier() string { |
|
|
|
|
bytes := make([]byte, 64) // 86 symbols Base64URL
|
|
|
|
|
if _, err := rand.Read(bytes); err != nil { |
|
|
|
|
log.Fatalf("rand.Read error: %v", err) |
|
|
|
|
} |
|
|
|
|
return base64.RawURLEncoding.EncodeToString(bytes) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func generateCodeChallenge(verifier string) string { |
|
|
|
|
hash := sha256.Sum256([]byte(verifier)) |
|
|
|
|
return base64.RawURLEncoding.EncodeToString(hash[:]) |
|
|
|
|
} |
|
|
|
|
|