mirror of https://github.com/dexidp/dex.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
407 lines
11 KiB
407 lines
11 KiB
package main |
|
|
|
import ( |
|
"bytes" |
|
"context" |
|
"crypto/tls" |
|
"crypto/x509" |
|
"encoding/json" |
|
"errors" |
|
"fmt" |
|
"log" |
|
"net" |
|
"net/http" |
|
"net/http/httputil" |
|
"net/url" |
|
"os" |
|
"slices" |
|
"time" |
|
|
|
"github.com/coreos/go-oidc/v3/oidc" |
|
"github.com/spf13/cobra" |
|
"golang.org/x/oauth2" |
|
) |
|
|
|
// exampleAppState is the fixed OAuth2 "state" value this example sends with
// every authorization request and expects to receive back on the callback.
const exampleAppState = "I wish to wash my irish wristwatch"
|
|
|
var ( |
|
codeVerifier string |
|
codeChallenge string |
|
) |
|
|
|
func init() { |
|
codeVerifier = oauth2.GenerateVerifier() |
|
codeChallenge = oauth2.S256ChallengeFromVerifier(codeVerifier) |
|
} |
|
|
|
type app struct { |
|
clientID string |
|
clientSecret string |
|
pkce bool |
|
redirectURI string |
|
|
|
verifier *oidc.IDTokenVerifier |
|
provider *oidc.Provider |
|
scopesSupported []string |
|
|
|
// Does the provider use "offline_access" scope to request a refresh token |
|
// or does it use "access_type=offline" (e.g. Google)? |
|
offlineAsScope bool |
|
|
|
client *http.Client |
|
} |
|
|
|
// httpClientForRootCAs returns an HTTP client whose TLS configuration trusts
// the PEM-encoded certificate authorities stored in the file at path rootCAs.
//
// It returns an error when the file cannot be read or when no certificate
// could be parsed from it. The read error is wrapped with %w, so callers can
// inspect the cause with errors.Is / errors.As.
func httpClientForRootCAs(rootCAs string) (*http.Client, error) {
	rootCABytes, err := os.ReadFile(rootCAs)
	if err != nil {
		return nil, fmt.Errorf("failed to read root-ca: %w", err)
	}

	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM(rootCABytes) {
		return nil, fmt.Errorf("no certs found in root CA file %q", rootCAs)
	}

	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	return &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{RootCAs: pool},
			Proxy:           http.ProxyFromEnvironment,
			// DialContext supersedes the deprecated Dial field and lets the
			// transport abandon a dial once the request no longer needs it.
			DialContext:           dialer.DialContext,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		},
	}, nil
}
|
|
|
// debugTransport is an http.RoundTripper that logs a dump of every request
// and response travelling through the wrapped round tripper t.
type debugTransport struct {
	t http.RoundTripper
}

// RoundTrip logs the outgoing request (body included), delegates to the
// wrapped round tripper and logs the response (body included) before handing
// it back. If dumping the response fails, the response body is closed and the
// dump error is returned instead of the response.
func (d debugTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	dumpedReq, dumpErr := httputil.DumpRequest(req, true)
	if dumpErr != nil {
		return nil, dumpErr
	}
	log.Printf("%s", dumpedReq)

	resp, tripErr := d.t.RoundTrip(req)
	if tripErr != nil {
		return nil, tripErr
	}

	dumpedResp, dumpErr := httputil.DumpResponse(resp, true)
	if dumpErr != nil {
		resp.Body.Close()
		return nil, dumpErr
	}
	log.Printf("%s", dumpedResp)

	return resp, nil
}
|
|
|
func cmd() *cobra.Command { |
|
var ( |
|
a app |
|
issuerURL string |
|
listen string |
|
tlsCert string |
|
tlsKey string |
|
rootCAs string |
|
debug bool |
|
) |
|
c := cobra.Command{ |
|
Use: "example-app", |
|
Short: "An example OpenID Connect client", |
|
Long: "", |
|
RunE: func(cmd *cobra.Command, args []string) error { |
|
if len(args) != 0 { |
|
return errors.New("surplus arguments provided") |
|
} |
|
|
|
u, err := url.Parse(a.redirectURI) |
|
if err != nil { |
|
return fmt.Errorf("parse redirect-uri: %v", err) |
|
} |
|
listenURL, err := url.Parse(listen) |
|
if err != nil { |
|
return fmt.Errorf("parse listen address: %v", err) |
|
} |
|
|
|
if rootCAs != "" { |
|
client, err := httpClientForRootCAs(rootCAs) |
|
if err != nil { |
|
return err |
|
} |
|
a.client = client |
|
} |
|
|
|
if debug { |
|
if a.client == nil { |
|
a.client = &http.Client{ |
|
Transport: debugTransport{http.DefaultTransport}, |
|
} |
|
} else { |
|
a.client.Transport = debugTransport{a.client.Transport} |
|
} |
|
} |
|
|
|
if a.client == nil { |
|
a.client = http.DefaultClient |
|
} |
|
|
|
// TODO(ericchiang): Retry with backoff |
|
ctx := oidc.ClientContext(context.Background(), a.client) |
|
provider, err := oidc.NewProvider(ctx, issuerURL) |
|
if err != nil { |
|
return fmt.Errorf("failed to query provider %q: %v", issuerURL, err) |
|
} |
|
|
|
var s struct { |
|
// What scopes does a provider support? |
|
// |
|
// See: https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata |
|
ScopesSupported []string `json:"scopes_supported"` |
|
} |
|
if err := provider.Claims(&s); err != nil { |
|
return fmt.Errorf("failed to parse provider scopes_supported: %v", err) |
|
} |
|
|
|
if len(s.ScopesSupported) == 0 { |
|
// scopes_supported is a "RECOMMENDED" discovery claim, not a required |
|
// one. If missing, assume that the provider follows the spec and has |
|
// an "offline_access" scope. |
|
a.offlineAsScope = true |
|
} else { |
|
// See if scopes_supported has the "offline_access" scope. |
|
a.offlineAsScope = func() bool { |
|
for _, scope := range s.ScopesSupported { |
|
if scope == oidc.ScopeOfflineAccess { |
|
return true |
|
} |
|
} |
|
return false |
|
}() |
|
} |
|
|
|
a.provider = provider |
|
a.verifier = provider.Verifier(&oidc.Config{ClientID: a.clientID}) |
|
a.scopesSupported = s.ScopesSupported |
|
|
|
http.Handle("/static/", http.StripPrefix("/static/", staticHandler)) |
|
http.HandleFunc("/", a.handleIndex) |
|
http.HandleFunc("/login", a.handleLogin) |
|
http.HandleFunc(u.Path, a.handleCallback) |
|
|
|
switch listenURL.Scheme { |
|
case "http": |
|
log.Printf("listening on %s", listen) |
|
return http.ListenAndServe(listenURL.Host, nil) |
|
case "https": |
|
log.Printf("listening on %s", listen) |
|
return http.ListenAndServeTLS(listenURL.Host, tlsCert, tlsKey, nil) |
|
default: |
|
return fmt.Errorf("listen address %q is not using http or https", listen) |
|
} |
|
}, |
|
} |
|
c.Flags().StringVar(&a.clientID, "client-id", "example-app", "OAuth2 client ID of this application.") |
|
c.Flags().StringVar(&a.clientSecret, "client-secret", "ZXhhbXBsZS1hcHAtc2VjcmV0", "OAuth2 client secret of this application.") |
|
c.Flags().BoolVar(&a.pkce, "pkce", true, "Use PKCE flow for the code exchange.") |
|
c.Flags().StringVar(&a.redirectURI, "redirect-uri", "http://127.0.0.1:5555/callback", "Callback URL for OAuth2 responses.") |
|
c.Flags().StringVar(&issuerURL, "issuer", "http://127.0.0.1:5556/dex", "URL of the OpenID Connect issuer.") |
|
c.Flags().StringVar(&listen, "listen", "http://127.0.0.1:5555", "HTTP(S) address to listen at.") |
|
c.Flags().StringVar(&tlsCert, "tls-cert", "", "X509 cert file to present when serving HTTPS.") |
|
c.Flags().StringVar(&tlsKey, "tls-key", "", "Private key for the HTTPS cert.") |
|
c.Flags().StringVar(&rootCAs, "issuer-root-ca", "", "Root certificate authorities for the issuer. Defaults to host certs.") |
|
c.Flags().BoolVar(&debug, "debug", false, "Print all request and responses from the OpenID Connect issuer.") |
|
return &c |
|
} |
|
|
|
func main() { |
|
if err := cmd().Execute(); err != nil { |
|
fmt.Fprintf(os.Stderr, "error: %v\n", err) |
|
os.Exit(2) |
|
} |
|
} |
|
|
|
func (a *app) handleIndex(w http.ResponseWriter, r *http.Request) { |
|
renderIndex(w, indexPageData{ |
|
ScopesSupported: a.scopesSupported, |
|
LogoURI: dexLogoDataURI, |
|
}) |
|
} |
|
|
|
func (a *app) oauth2Config(scopes []string) *oauth2.Config { |
|
return &oauth2.Config{ |
|
ClientID: a.clientID, |
|
ClientSecret: a.clientSecret, |
|
Endpoint: a.provider.Endpoint(), |
|
Scopes: scopes, |
|
RedirectURL: a.redirectURI, |
|
} |
|
} |
|
|
|
func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { |
|
if err := r.ParseForm(); err != nil { |
|
http.Error(w, fmt.Sprintf("failed to parse form: %v", err), http.StatusBadRequest) |
|
return |
|
} |
|
|
|
// Only use scopes that are checked in the form |
|
scopes := r.Form["extra_scopes"] |
|
|
|
clients := r.Form["cross_client"] |
|
for _, client := range clients { |
|
if client == "" { |
|
continue |
|
} |
|
scopes = append(scopes, "audience:server:client_id:"+client) |
|
} |
|
connectorID := "" |
|
if id := r.FormValue("connector_id"); id != "" { |
|
connectorID = id |
|
} |
|
|
|
authCodeURL := "" |
|
scopes = uniqueStrings(scopes) |
|
|
|
var authCodeOptions []oauth2.AuthCodeOption |
|
|
|
if a.pkce { |
|
authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_challenge", codeChallenge)) |
|
authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_challenge_method", "S256")) |
|
} |
|
|
|
// Check if offline_access scope is present to determine offline access mode |
|
hasOfflineAccess := false |
|
for _, scope := range scopes { |
|
if scope == "offline_access" { |
|
hasOfflineAccess = true |
|
break |
|
} |
|
} |
|
|
|
if hasOfflineAccess && !a.offlineAsScope { |
|
// Provider uses access_type=offline instead of offline_access scope |
|
authCodeOptions = append(authCodeOptions, oauth2.AccessTypeOffline) |
|
// Remove offline_access from scopes as it's not supported |
|
filteredScopes := make([]string, 0, len(scopes)) |
|
for _, scope := range scopes { |
|
if scope != "offline_access" { |
|
filteredScopes = append(filteredScopes, scope) |
|
} |
|
} |
|
scopes = filteredScopes |
|
} |
|
|
|
authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState, authCodeOptions...) |
|
|
|
// Parse the auth code URL and safely add connector_id parameter if provided |
|
u, err := url.Parse(authCodeURL) |
|
if err != nil { |
|
http.Error(w, "Failed to parse auth URL", http.StatusInternalServerError) |
|
return |
|
} |
|
|
|
if connectorID != "" { |
|
query := u.Query() |
|
query.Set("connector_id", connectorID) |
|
u.RawQuery = query.Encode() |
|
} |
|
|
|
http.Redirect(w, r, u.String(), http.StatusSeeOther) |
|
} |
|
|
|
func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { |
|
var ( |
|
err error |
|
token *oauth2.Token |
|
) |
|
|
|
ctx := oidc.ClientContext(r.Context(), a.client) |
|
oauth2Config := a.oauth2Config(nil) |
|
switch r.Method { |
|
case http.MethodGet: |
|
// Authorization redirect callback from OAuth2 auth flow. |
|
if errMsg := r.FormValue("error"); errMsg != "" { |
|
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) |
|
return |
|
} |
|
code := r.FormValue("code") |
|
if code == "" { |
|
http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest) |
|
return |
|
} |
|
if state := r.FormValue("state"); state != exampleAppState { |
|
http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) |
|
return |
|
} |
|
|
|
var authCodeOptions []oauth2.AuthCodeOption |
|
if a.pkce { |
|
authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("code_verifier", codeVerifier)) |
|
} |
|
|
|
token, err = oauth2Config.Exchange(ctx, code, authCodeOptions...) |
|
case http.MethodPost: |
|
// Form request from frontend to refresh a token. |
|
refresh := r.FormValue("refresh_token") |
|
if refresh == "" { |
|
http.Error(w, fmt.Sprintf("no refresh_token in request: %q", r.Form), http.StatusBadRequest) |
|
return |
|
} |
|
t := &oauth2.Token{ |
|
RefreshToken: refresh, |
|
Expiry: time.Now().Add(-time.Hour), |
|
} |
|
token, err = oauth2Config.TokenSource(ctx, t).Token() |
|
default: |
|
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) |
|
return |
|
} |
|
|
|
if err != nil { |
|
http.Error(w, fmt.Sprintf("failed to get token: %v", err), http.StatusInternalServerError) |
|
return |
|
} |
|
|
|
rawIDToken, ok := token.Extra("id_token").(string) |
|
if !ok { |
|
http.Error(w, "no id_token in token response", http.StatusInternalServerError) |
|
return |
|
} |
|
|
|
idToken, err := a.verifier.Verify(r.Context(), rawIDToken) |
|
if err != nil { |
|
http.Error(w, fmt.Sprintf("failed to verify ID token: %v", err), http.StatusInternalServerError) |
|
return |
|
} |
|
|
|
accessToken, ok := token.Extra("access_token").(string) |
|
if !ok { |
|
http.Error(w, "no access_token in token response", http.StatusInternalServerError) |
|
return |
|
} |
|
|
|
var claims json.RawMessage |
|
if err := idToken.Claims(&claims); err != nil { |
|
http.Error(w, fmt.Sprintf("error decoding ID token claims: %v", err), http.StatusInternalServerError) |
|
return |
|
} |
|
|
|
buff := new(bytes.Buffer) |
|
if err := json.Indent(buff, claims, "", " "); err != nil { |
|
http.Error(w, fmt.Sprintf("error indenting ID token claims: %v", err), http.StatusInternalServerError) |
|
return |
|
} |
|
|
|
renderToken(w, r.Context(), a.provider, a.redirectURI, rawIDToken, accessToken, token.RefreshToken, buff.String()) |
|
} |
|
|
|
// uniqueStrings returns the distinct elements of values in ascending order.
//
// Sorting and compaction are performed on a copy, so the caller's slice is
// left untouched. This matters when the slice is owned by someone else, e.g.
// taken straight out of a parsed request form.
func uniqueStrings(values []string) []string {
	sorted := slices.Clone(values)
	slices.Sort(sorted)
	return slices.Compact(sorted)
}
|
|
|