diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 8e1fe724..1c152a1d 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -107,6 +107,11 @@ type Config struct { // This setting allows you to override the default behavior of Dex and enforce the mappings defined in `claimMapping`. OverrideClaimMapping bool `json:"overrideClaimMapping"` // defaults to false + // ForceQueryResponseModeSet ensures the `response_mode` query parameter to be explicitly set in the LoginURL. + // Although the OIDC specification defines query to be the default, + // some implementations require this parameter to be set in order to deliver the code as a query. + ForceQueryResponseMode bool `json:"forceQueryResponseMode"` + ClaimMapping struct { // Configurable key which contains the preferred username claims PreferredUsernameKey string `json:"preferred_username"` // defaults to "preferred_username" @@ -367,6 +372,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, userIDKey: c.UserIDKey, userNameKey: c.UserNameKey, overrideClaimMapping: c.OverrideClaimMapping, + forceQueryResponseMode: c.ForceQueryResponseMode, preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, emailKey: c.ClaimMapping.EmailKey, groupsKey: c.ClaimMapping.GroupsKey, @@ -401,6 +407,7 @@ type oidcConnector struct { userIDKey string userNameKey string overrideClaimMapping bool + forceQueryResponseMode bool preferredUsernameKey string emailKey string groupsKey string @@ -429,6 +436,10 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) opts = append(opts, oauth2.SetAuthURLParam("acr_values", acrValues)) } + if c.forceQueryResponseMode { + opts = append(opts, oauth2.SetAuthURLParam("response_mode", "query")) + } + if s.OfflineAccess { opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType)) } diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 71a30b6e..9cb5952f 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -13,6 +13,7 @@ import ( "log/slog" "net/http" "net/http/httptest" + "net/url" "reflect" "strings" "testing" @@ -44,6 +45,91 @@ func TestKnownBrokenAuthHeaderProvider(t *testing.T) { } } +func TestLoginURL(t *testing.T) { + t.Helper() + + tests := []struct { + name string + forceQueryResponseModeSet bool + expectedURLValues map[string]string + expectedNonExistingURLValues []string + }{ + { + name: "default", + expectedURLValues: map[string]string{ + "scope": "openid email groups", + }, + expectedNonExistingURLValues: []string{"response_mode"}, + }, + { + name: "forceResponseMode", + forceQueryResponseModeSet: true, + expectedURLValues: map[string]string{ + "scope": "openid email groups", + "response_mode": "query", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + idTokenDesired := true + testServer, err := setupServer(nil, idTokenDesired) + if err != nil { + t.Fatal("failed to setup test server", err) + } + defer testServer.Close() + + scopes := []string{"email", "groups"} + + serverURL := testServer.URL + config := Config{ + Issuer: serverURL, + ClientID: "clientID", + ClientSecret: "clientSecret", + Scopes: scopes, + RedirectURI: fmt.Sprintf("%s/callback", serverURL), + InsecureEnableGroups: true, + } + if tc.forceQueryResponseModeSet { + config.ForceQueryResponseMode = true + } + + conn, err := newConnector(config) + if err != nil { + t.Fatal("failed to create new connector", err) + } + + state := "" // state is handled by the oAuth library, so no need to test it + loginURL, err := conn.LoginURL(connector.Scopes{}, config.RedirectURI, state) + if err != nil { + t.Error("unable to get login url", err) + } + + u, err := url.Parse(loginURL) + if err != nil { + t.Fatal("failed to parse login url", err) + } + + actualQueryParams := u.Query() + + if tc.expectedURLValues != nil { + for param, values := range tc.expectedURLValues { + expectEquals(t, actualQueryParams.Get(param), values) + } + } + + if tc.expectedNonExistingURLValues != nil { + for _, param := range tc.expectedNonExistingURLValues { + if actualQueryParams.Has(param) { + t.Error("Unexpected query param found", param) + } + } + } + }) + } +} + func TestHandleCallback(t *testing.T) { t.Helper()