Browse Source

Add support to PKCE in OIDC connector (#3777)

Signed-off-by: johnvan7 <giovanni.vella98@gmail.com>
Signed-off-by: Giovanni Vella <giovanni.vella98@gmail.com>
pull/4572/head
Giovanni Vella 3 weeks ago committed by GitHub
parent
commit
25591eeaf4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 8
      connector/bitbucketcloud/bitbucketcloud.go
  2. 2
      connector/bitbucketcloud/bitbucketcloud_test.go
  3. 4
      connector/connector.go
  4. 8
      connector/gitea/gitea.go
  5. 4
      connector/gitea/gitea_test.go
  6. 8
      connector/github/github.go
  7. 6
      connector/github/github_test.go
  8. 8
      connector/gitlab/gitlab.go
  9. 14
      connector/gitlab/gitlab_test.go
  10. 8
      connector/google/google.go
  11. 2
      connector/google/google_test.go
  12. 8
      connector/linkedin/linkedin.go
  13. 8
      connector/microsoft/microsoft.go
  14. 8
      connector/microsoft/microsoft_test.go
  15. 8
      connector/mock/connectortest.go
  16. 101
      connector/oidc/oidc.go
  17. 42
      connector/oidc/oidc_test.go
  18. 7
      connector/openshift/openshift.go
  19. 2
      connector/openshift/openshift_test.go
  20. 16
      server/handlers.go

8
connector/bitbucketcloud/bitbucketcloud.go

@ -111,12 +111,12 @@ func (b *bitbucketConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi
}
}
func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if b.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI)
}
return b.oauth2Config(scopes).AuthCodeURL(state), nil
return b.oauth2Config(scopes).AuthCodeURL(state), nil, nil
}
type oauth2Error struct {
@ -131,7 +131,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}
func (b *bitbucketConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (b *bitbucketConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}

2
connector/bitbucketcloud/bitbucketcloud_test.go

@ -102,7 +102,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) {
expectNil(t, err)
bitbucketConnector := bitbucketConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()}
identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, req)
identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Username, "some-login")

4
connector/connector.go

@ -63,10 +63,10 @@ type CallbackConnector interface {
// requested if one has already been issues. There's no good general answer
// for these kind of restrictions, and may require this package to become more
// aware of the global set of user/connector interactions.
LoginURL(s Scopes, callbackURL, state string) (string, error)
LoginURL(s Scopes, callbackURL, state string) (string, []byte, error)
// Handle the callback to the server and return an identity.
HandleCallback(s Scopes, r *http.Request) (identity Identity, err error)
HandleCallback(s Scopes, connData []byte, r *http.Request) (identity Identity, err error)
}
// SAMLConnector represents SAML connectors which implement the HTTP POST binding.

8
connector/gitea/gitea.go

@ -102,11 +102,11 @@ func (c *giteaConnector) oauth2Config(_ connector.Scopes) *oauth2.Config {
}
}
func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL)
}
return c.oauth2Config(scopes).AuthCodeURL(state), nil
return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil
}
type oauth2Error struct {
@ -121,7 +121,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}
func (c *giteaConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *giteaConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}

4
connector/gitea/gitea_test.go

@ -30,14 +30,14 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) {
expectNil(t, err)
c := giteaConnector{baseURL: s.URL, httpClient: newClient()}
identity, err := c.HandleCallback(connector.Scopes{}, req)
identity, err := c.HandleCallback(connector.Scopes{}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Username, "some@email.com")
expectEquals(t, identity.UserID, "12345678")
c = giteaConnector{baseURL: s.URL, httpClient: newClient()}
identity, err = c.HandleCallback(connector.Scopes{}, req)
identity, err = c.HandleCallback(connector.Scopes{}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Username, "some@email.com")

8
connector/github/github.go

@ -194,12 +194,12 @@ func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config {
}
}
func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}
return c.oauth2Config(scopes).AuthCodeURL(state), nil
return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil
}
type oauth2Error struct {
@ -214,7 +214,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}
func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *githubConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}

6
connector/github/github_test.go

@ -152,7 +152,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) {
expectNil(t, err)
c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Username, "some-login")
@ -160,7 +160,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) {
expectEquals(t, 0, len(identity.Groups))
c = githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), loadAllGroups: true}
identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Username, "some-login")
@ -193,7 +193,7 @@ func TestLoginUsedAsIDWhenConfigured(t *testing.T) {
expectNil(t, err)
c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), useLoginAsID: true}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.UserID, "some-login")

8
connector/gitlab/gitlab.go

@ -122,11 +122,11 @@ func (c *gitlabConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config {
}
}
func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL)
}
return c.oauth2Config(scopes).AuthCodeURL(state), nil
return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil
}
type oauth2Error struct {
@ -141,7 +141,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}
func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *gitlabConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}

14
connector/gitlab/gitlab_test.go

@ -247,7 +247,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) {
expectNil(t, err)
c := gitlabConnector{baseURL: s.URL, httpClient: newClient()}
identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: false}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Username, "some@email.com")
@ -255,7 +255,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) {
expectEquals(t, 0, len(identity.Groups))
c = gitlabConnector{baseURL: s.URL, httpClient: newClient()}
identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Username, "some@email.com")
@ -283,7 +283,7 @@ func TestLoginUsedAsIDWhenConfigured(t *testing.T) {
expectNil(t, err)
c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), useLoginAsID: true}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.UserID, "joebloggs")
@ -310,7 +310,7 @@ func TestLoginWithTeamWhitelisted(t *testing.T) {
expectNil(t, err)
c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), groups: []string{"team-1"}}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.UserID, "12345678")
@ -337,7 +337,7 @@ func TestLoginWithTeamNonWhitelisted(t *testing.T) {
expectNil(t, err)
c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), groups: []string{"team-2"}}
_, err = c.HandleCallback(connector.Scopes{Groups: true}, req)
_, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req)
expectNotNil(t, err, "HandleCallback error")
expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups")
@ -371,7 +371,7 @@ func TestRefresh(t *testing.T) {
})
expectNil(t, err)
identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, req)
identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Username, "some@email.com")
expectEquals(t, identity.UserID, "12345678")
@ -435,7 +435,7 @@ func TestGroupsWithPermission(t *testing.T) {
expectNil(t, err)
c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), getGroupsPermission: true}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Groups, []string{

8
connector/google/google.go

@ -168,9 +168,9 @@ func (c *googleConnector) Close() error {
return nil
}
func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}
var opts []oauth2.AuthCodeOption
@ -186,7 +186,7 @@ func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType))
}
return c.oauth2Config.AuthCodeURL(state, opts...), nil
return c.oauth2Config.AuthCodeURL(state, opts...), nil, nil
}
type oauth2Error struct {
@ -201,7 +201,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}
func (c *googleConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *googleConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}

2
connector/google/google_test.go

@ -439,7 +439,7 @@ func TestPromptTypeConfig(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, test.expectedPromptTypeValue, conn.promptType)
loginURL, err := conn.LoginURL(connector.Scopes{OfflineAccess: true}, ts.URL+"/callback", "state")
loginURL, _, err := conn.LoginURL(connector.Scopes{OfflineAccess: true}, ts.URL+"/callback", "state")
assert.Nil(t, err)
urlp, err := url.Parse(loginURL)

8
connector/linkedin/linkedin.go

@ -62,17 +62,17 @@ var (
)
// LoginURL returns an access token request URL
func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.oauth2Config.RedirectURL != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q",
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q",
callbackURL, c.oauth2Config.RedirectURL)
}
return c.oauth2Config.AuthCodeURL(state), nil
return c.oauth2Config.AuthCodeURL(state), nil, nil
}
// HandleCallback handles HTTP redirect from LinkedIn
func (c *linkedInConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *linkedInConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}

8
connector/microsoft/microsoft.go

@ -175,9 +175,9 @@ func (c *microsoftConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi
}
}
func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}
var options []oauth2.AuthCodeOption
@ -188,10 +188,10 @@ func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, stat
options = append(options, oauth2.SetAuthURLParam("domain_hint", c.domainHint))
}
return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil
return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil, nil
}
func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *microsoftConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}

8
connector/microsoft/microsoft_test.go

@ -39,7 +39,7 @@ func TestLoginURL(t *testing.T) {
tenant: tenant,
}
loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState)
loginURL, _, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState)
parsedLoginURL, _ := url.Parse(loginURL)
queryParams := parsedLoginURL.Query()
@ -70,7 +70,7 @@ func TestLoginURLWithOptions(t *testing.T) {
domainHint: domainHint,
}
loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state")
loginURL, _, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state")
parsedLoginURL, _ := url.Parse(loginURL)
queryParams := parsedLoginURL.Query()
@ -91,7 +91,7 @@ func TestUserIdentityFromGraphAPI(t *testing.T) {
req, _ := http.NewRequest("GET", s.URL, nil)
c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant}
identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: false}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Username, "Jane Doe")
expectEquals(t, identity.UserID, "S56767889")
@ -114,7 +114,7 @@ func TestUserGroupsFromGraphAPI(t *testing.T) {
req, _ := http.NewRequest("GET", s.URL, nil)
c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Groups, []string{"a", "b"})
}

8
connector/mock/connectortest.go

@ -43,21 +43,21 @@ type Callback struct {
}
// LoginURL returns the URL to redirect the user to login with.
func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) {
u, err := url.Parse(callbackURL)
if err != nil {
return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
return "", nil, fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
}
v := u.Query()
v.Set("state", state)
u.RawQuery = v.Encode()
return u.String(), nil
return u.String(), nil, nil
}
var connectorData = []byte("foobar")
// HandleCallback parses the request and returns the user's identity
func (m *Callback) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) {
func (m *Callback) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (connector.Identity, error) {
return m.Identity, nil
}

101
connector/oidc/oidc.go

@ -21,6 +21,20 @@ import (
"github.com/dexidp/dex/pkg/httpclient"
)
const (
codeChallengeMethodPlain = "plain"
codeChallengeMethodS256 = "S256"
)
func contains(arr []string, item string) bool {
for _, itemFromArray := range arr {
if itemFromArray == item {
return true
}
}
return false
}
// Config holds configuration options for OpenID Connect logins.
type Config struct {
Issuer string `json:"issuer"`
@ -84,6 +98,10 @@ type Config struct {
// PromptType will be used for the prompt parameter (when offline_access, by default prompt=consent)
PromptType *string `json:"promptType"`
// PKCEChallenge specifies which PKCE algorithm will be used
// If not setted it will be auto-detected the best-fit for the connector.
PKCEChallenge string `json:"pkceChallenge"`
// OverrideClaimMapping will be used to override the options defined in claimMappings.
// i.e. if there are 'email' and `preferred_email` claims available, by default Dex will always use the `email` claim independent of the ClaimMapping.EmailKey.
// This setting allows you to override the default behavior of Dex and enforce the mappings defined in `claimMapping`.
@ -224,6 +242,25 @@ func knownBrokenAuthHeaderProvider(issuerURL string) bool {
return false
}
// PKCEChallengeData is used to store info for PKCE Challenge method and verifier
// in the connectorData
type PKCEChallengeData struct {
CodeChallenge string `json:"codeChallenge"`
CodeChallengeMethod string `json:"codeChallengeMethod"`
}
// Returns an AuthCodeOption according to the provided codeChallengeMethod
func getAuthCodeOptionForCodeChallenge(codeVerifier, codeChallengeMethod string) (oauth2.AuthCodeOption, error) {
switch codeChallengeMethod {
case codeChallengeMethodPlain:
return oauth2.VerifierOption(codeVerifier), nil
case codeChallengeMethodS256:
return oauth2.S256ChallengeOption(codeVerifier), nil
default:
return nil, fmt.Errorf("unknown challenge method (%v)", codeChallengeMethod)
}
}
// Open returns a connector which can be used to login users through an upstream
// OpenID Connect provider.
func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) {
@ -282,6 +319,27 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
}
}
// Obtain CodeChallengeMethodsSupported from the provider
var metadata struct {
CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
}
if err := provider.Claims(&metadata); err != nil {
logger.Warn("failed to parse provider metadata")
}
// if PKCEChallenge method has not been setted in the config, auto-detect the best fit
if c.PKCEChallenge == "" {
if contains(metadata.CodeChallengeMethodsSupported, codeChallengeMethodS256) {
c.PKCEChallenge = codeChallengeMethodS256
} else if contains(metadata.CodeChallengeMethodsSupported, codeChallengeMethodPlain) {
c.PKCEChallenge = codeChallengeMethodPlain
}
} else {
// if PKCEChallenge method has been setted in the config, check if it is supported
if !contains(metadata.CodeChallengeMethodsSupported, c.PKCEChallenge) {
logger.Warn("provided PKCEChallenge method not supported by the connector")
}
}
clientID := c.ClientID
return &oidcConnector{
provider: provider,
@ -316,6 +374,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
groupsFilter: groupsFilter,
groupsPrefix: c.ClaimMutations.ModifyGroupNames.Prefix,
groupsSuffix: c.ClaimMutations.ModifyGroupNames.Suffix,
pkceChallenge: c.PKCEChallenge,
}, nil
}
@ -348,6 +407,7 @@ type oidcConnector struct {
groupsFilter *regexp.Regexp
groupsPrefix string
groupsSuffix string
pkceChallenge string
}
func (c *oidcConnector) Close() error {
@ -355,12 +415,13 @@ func (c *oidcConnector) Close() error {
return nil
}
func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}
var opts []oauth2.AuthCodeOption
var connectorData []byte
if len(c.acrValues) > 0 {
acrValues := strings.Join(c.acrValues, " ")
@ -370,7 +431,25 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
if s.OfflineAccess {
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType))
}
return c.oauth2Config.AuthCodeURL(state, opts...), nil
if c.pkceChallenge != "" {
codeVerifier := oauth2.GenerateVerifier()
authCodeOption, err := getAuthCodeOptionForCodeChallenge(codeVerifier, c.pkceChallenge)
if err != nil {
return "", nil, fmt.Errorf("oidc: failed to get PKCE AuthCodeOption for CodeChallenge: %v", err)
}
data := PKCEChallengeData{
CodeChallenge: codeVerifier,
CodeChallengeMethod: c.pkceChallenge,
}
connectorData, err = json.Marshal(data)
if err != nil {
return "", nil, fmt.Errorf("oidc: failed to create PKCEChallenge data: %v", err)
}
opts = append(opts, authCodeOption)
}
return c.oauth2Config.AuthCodeURL(state, opts...), connectorData, nil
}
type oauth2Error struct {
@ -393,7 +472,7 @@ const (
exchangeCaller
)
func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *oidcConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
@ -401,7 +480,19 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient)
token, err := c.oauth2Config.Exchange(ctx, q.Get("code"))
var opts []oauth2.AuthCodeOption
if c.pkceChallenge != "" {
var data PKCEChallengeData
if err := json.Unmarshal(connData, &data); err != nil {
return identity, fmt.Errorf("oidc: failed to parse PKCEChallenge data: %v", err)
}
if data.CodeChallenge == "" {
return identity, fmt.Errorf("oidc: invalid PKCE CodeChallenge")
}
opts = append(opts, oauth2.VerifierOption(data.CodeChallenge))
}
token, err := c.oauth2Config.Exchange(ctx, q.Get("code"), opts...)
if err != nil {
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}

42
connector/oidc/oidc_test.go

@ -67,6 +67,7 @@ func TestHandleCallback(t *testing.T) {
newGroupFromClaims []NewGroupFromClaims
groupsPrefix string
groupsSuffix string
pkceChallenge string
}{
{
name: "simpleCase",
@ -484,6 +485,40 @@ func TestHandleCallback(t *testing.T) {
"email_verified": true,
},
},
{
name: "S256PKCEChallenge",
userIDKey: "", // not configured
userNameKey: "", // not configured
pkceChallenge: "S256",
expectUserID: "subvalue",
expectUserName: "namevalue",
expectGroups: []string{"group1", "group2"},
expectedEmailField: "emailvalue",
token: map[string]interface{}{
"sub": "subvalue",
"name": "namevalue",
"groups": []string{"group1", "group2"},
"email": "emailvalue",
"email_verified": true,
},
},
{
name: "plainPKCEChallenge",
userIDKey: "", // not configured
userNameKey: "", // not configured
pkceChallenge: "plain",
expectUserID: "subvalue",
expectUserName: "namevalue",
expectGroups: []string{"group1", "group2"},
expectedEmailField: "emailvalue",
token: map[string]interface{}{
"sub": "subvalue",
"name": "namevalue",
"groups": []string{"group1", "group2"},
"email": "emailvalue",
"email_verified": true,
},
},
}
for _, tc := range tests {
@ -515,6 +550,7 @@ func TestHandleCallback(t *testing.T) {
InsecureEnableGroups: true,
BasicAuthUnsupported: &basicAuth,
OverrideClaimMapping: tc.overrideClaimMapping,
PKCEChallenge: tc.pkceChallenge,
}
config.ClaimMapping.PreferredUsernameKey = tc.preferredUsernameKey
config.ClaimMapping.EmailKey = tc.emailKey
@ -534,7 +570,11 @@ func TestHandleCallback(t *testing.T) {
t.Fatal("failed to create request", err)
}
identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req)
connectorDataStrTemplate := `{"codeChallenge":"abcdefgh123456qwertuiop89101112uvpwizABC234","codeChallengeMethod":"%s"}`
connectorDataStr := fmt.Sprintf(connectorDataStrTemplate, config.PKCEChallenge)
connectorData := []byte(connectorDataStr)
identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, connectorData, req)
if err != nil {
t.Fatal("handle callback failed", err)
}

7
connector/openshift/openshift.go

@ -138,12 +138,12 @@ func (c *openshiftConnector) Close() error {
}
// LoginURL returns the URL to redirect the user to login with.
func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q",
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q",
callbackURL, c.redirectURI)
}
return c.oauth2Config.AuthCodeURL(state), nil
return c.oauth2Config.AuthCodeURL(state), nil, nil
}
type oauth2Error struct {
@ -160,6 +160,7 @@ func (e *oauth2Error) Error() string {
// HandleCallback parses the request and returns the user's identity
func (c *openshiftConnector) HandleCallback(s connector.Scopes,
connData []byte,
r *http.Request,
) (identity connector.Identity, err error) {
q := r.URL.Query()

2
connector/openshift/openshift_test.go

@ -175,7 +175,7 @@ func TestCallbackIdentity(t *testing.T) {
TokenURL: fmt.Sprintf("%s/oauth/token", s.URL),
},
}}
identity, err := oc.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := oc.HandleCallback(connector.Scopes{Groups: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.UserID, "12345")

16
server/handlers.go

@ -271,12 +271,24 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
// Use the auth request ID as the "state" token.
//
// TODO(ericchiang): Is this appropriate or should we also be using a nonce?
callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID)
callbackURL, connData, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID)
if err != nil {
s.logger.ErrorContext(r.Context(), "connector returned error when creating callback", "connector_id", connID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.")
return
}
if len(connData) > 0 {
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
a.ConnectorData = connData
return a, nil
}
err := s.storage.UpdateAuthRequest(ctx, authReq.ID, updater)
if err != nil {
s.logger.ErrorContext(r.Context(), "Failed to set connector data on auth request", "connector_id", connID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Database error.")
return
}
}
http.Redirect(w, r, callbackURL, http.StatusFound)
case connector.PasswordConnector:
loginURL := url.URL{
@ -472,7 +484,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
s.renderError(r, w, http.StatusBadRequest, "Invalid request")
return
}
identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), r)
identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), authReq.ConnectorData, r)
case connector.SAMLConnector:
if r.Method != http.MethodPost {
s.logger.ErrorContext(r.Context(), "OAuth2 request mapped to SAML connector")

Loading…
Cancel
Save