From 12339f2cef5a367bcafcb5a275e6c078f54ca6c4 Mon Sep 17 00:00:00 2001 From: Maksim Nabokikh Date: Mon, 16 Mar 2026 13:53:27 +0100 Subject: [PATCH] feat: implement user identity creation and persisting consent (#4645) Signed-off-by: maksim.nabokikh --- server/handlers.go | 83 +++++++++ server/handlers_test.go | 393 ++++++++++++++++++++++++++++++++++++++++ server/server.go | 5 + 3 files changed, 481 insertions(+) diff --git a/server/handlers.go b/server/handlers.go index 815b4451..20fd85bf 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -23,6 +23,7 @@ import ( "github.com/gorilla/mux" "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/pkg/featureflags" "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" ) @@ -717,11 +718,60 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, } } + // Create or update UserIdentity to persist user claims across sessions. + var userIdentity *storage.UserIdentity + if featureflags.SessionsEnabled.Enabled() { + now := s.now() + + ui, err := s.storage.GetUserIdentity(ctx, identity.UserID, authReq.ConnectorID) + switch { + case err != nil && errors.Is(err, storage.ErrNotFound): + ui = storage.UserIdentity{ + UserID: identity.UserID, + ConnectorID: authReq.ConnectorID, + Claims: claims, + Consents: make(map[string][]string), + CreatedAt: now, + LastLogin: now, + } + if err := s.storage.CreateUserIdentity(ctx, ui); err != nil { + s.logger.ErrorContext(ctx, "failed to create user identity", "err", err) + return "", false, err + } + case err == nil: + if err := s.storage.UpdateUserIdentity(ctx, identity.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { + if len(identity.ConnectorData) > 0 { + old.Claims = claims + old.LastLogin = now + return old, nil + } + return old, nil + }); err != nil { + s.logger.ErrorContext(ctx, "failed to update user identity", "err", err) + return "", false, err + } + // Update the existing UserIdentity obj with new claims to use them later in the flow. + ui.Claims = claims + ui.LastLogin = now + default: + s.logger.ErrorContext(ctx, "failed to get user identity", "err", err) + return "", false, err + } + userIdentity = &ui + } + // we can skip the redirect to /approval and go ahead and send code if it's not required if s.skipApproval && !authReq.ForceApprovalPrompt { return "", true, nil } + // Skip approval if user already consented to the requested scopes for this client. + if !authReq.ForceApprovalPrompt && userIdentity != nil { + if scopesCoveredByConsent(userIdentity.Consents[authReq.ClientID], authReq.Scopes) { + return "", true, nil + } + } + // an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original // flow would be unable to poll for the result at the /approval endpoint h := hmac.New(sha256.New, authReq.HMACKey) @@ -787,6 +837,18 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { s.renderError(r, w, http.StatusInternalServerError, "Approval rejected.") return } + // Persist user-approved scopes as consent for this client. + if featureflags.SessionsEnabled.Enabled() { + if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { + if old.Consents == nil { + old.Consents = make(map[string][]string) + } + old.Consents[authReq.ClientID] = authReq.Scopes + return old, nil + }); err != nil { + s.logger.ErrorContext(ctx, "failed to update user identity consents", "err", err) + } + } s.sendCodeResponse(w, r, authReq) } } @@ -932,6 +994,27 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe http.Redirect(w, r, u.String(), http.StatusSeeOther) } +// scopesCoveredByConsent checks whether the approved scopes cover all requested scopes. +// The openid scope is excluded from the comparison as it is a technical scope +// that does not require user consent. +func scopesCoveredByConsent(approved, requested []string) bool { + approvedSet := make(map[string]struct{}, len(approved)) + for _, s := range approved { + approvedSet[s] = struct{}{} + } + + for _, scope := range requested { + if scope == scopeOpenID { + continue + } + if _, ok := approvedSet[scope]; !ok { + return false + } + } + + return true +} + func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request, storage.Client)) { ctx := r.Context() clientID, clientSecret, ok := r.BasicAuth() diff --git a/server/handlers_test.go b/server/handlers_test.go index 12c664f5..933e9e4d 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -3,6 +3,9 @@ package server import ( "bytes" "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -495,6 +498,396 @@ func TestHandlePassword_LocalPasswordDBClaims(t *testing.T) { require.Equal(t, []string{"team-a", "team-a/admins"}, claims.Groups) } +func setSessionsEnabled(t *testing.T, enabled bool) { + t.Helper() + if enabled { + t.Setenv("DEX_SESSIONS_ENABLED", "true") + } else { + t.Setenv("DEX_SESSIONS_ENABLED", "false") + } +} + +func TestFinalizeLoginCreatesUserIdentity(t *testing.T) { + ctx := t.Context() + setSessionsEnabled(t, true) + + connID := "mockPw" + authReqID := "test-create-ui" + expiry := time.Now().Add(100 * time.Second) + + httpServer, s := newTestServer(t, func(c *Config) { + c.SkipApprovalScreen = true + c.Now = time.Now + }) + defer httpServer.Close() + + sc := storage.Connector{ + ID: connID, + Type: "mockPassword", + Name: "MockPassword", + ResourceVersion: "1", + Config: []byte(`{"username": "foo", "password": "password"}`), + } + require.NoError(t, s.storage.CreateConnector(ctx, sc)) + _, err := s.OpenConnector(sc) + require.NoError(t, err) + + authReq := storage.AuthRequest{ + ID: authReqID, + ConnectorID: connID, + RedirectURI: "cb", + Expiry: expiry, + ResponseTypes: []string{responseTypeCode}, + } + require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq)) + + rr := httptest.NewRecorder() + reqPath := fmt.Sprintf("/auth/%s/login?state=%s&back=&login=foo&password=password", connID, authReqID) + s.handlePasswordLogin(rr, httptest.NewRequest("POST", reqPath, nil)) + + require.Equal(t, 303, rr.Code) + + ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID) + require.NoError(t, err, "UserIdentity should exist after login") + require.Equal(t, "0-385-28089-0", ui.UserID) + require.Equal(t, connID, ui.ConnectorID) + require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email) + require.NotZero(t, ui.CreatedAt, "CreatedAt should be set") + require.NotZero(t, ui.LastLogin, "LastLogin should be set") +} + +func TestFinalizeLoginUpdatesUserIdentity(t *testing.T) { + ctx := t.Context() + setSessionsEnabled(t, true) + + connID := "mockPw" + authReqID := "test-update-ui" + expiry := time.Now().Add(100 * time.Second) + oldTime := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + + httpServer, s := newTestServer(t, func(c *Config) { + c.SkipApprovalScreen = true + c.Now = time.Now + }) + defer httpServer.Close() + + sc := storage.Connector{ + ID: connID, + Type: "mockPassword", + Name: "MockPassword", + ResourceVersion: "1", + Config: []byte(`{"username": "foo", "password": "password"}`), + } + require.NoError(t, s.storage.CreateConnector(ctx, sc)) + _, err := s.OpenConnector(sc) + require.NoError(t, err) + + // Pre-create UserIdentity with old data + require.NoError(t, s.storage.CreateUserIdentity(ctx, storage.UserIdentity{ + UserID: "0-385-28089-0", + ConnectorID: connID, + Claims: storage.Claims{ + UserID: "0-385-28089-0", + Username: "Old Name", + Email: "old@example.com", + }, + Consents: map[string][]string{"existing-client": {"openid"}}, + CreatedAt: oldTime, + LastLogin: oldTime, + })) + + authReq := storage.AuthRequest{ + ID: authReqID, + ConnectorID: connID, + RedirectURI: "cb", + Expiry: expiry, + ResponseTypes: []string{responseTypeCode}, + } + require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq)) + + rr := httptest.NewRecorder() + reqPath := fmt.Sprintf("/auth/%s/login?state=%s&back=&login=foo&password=password", connID, authReqID) + s.handlePasswordLogin(rr, httptest.NewRequest("POST", reqPath, nil)) + + require.Equal(t, 303, rr.Code) + + ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID) + require.NoError(t, err, "UserIdentity should exist after login") + require.Equal(t, "Kilgore Trout", ui.Claims.Username, "claims should be refreshed from the connector") + require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email, "claims should be refreshed from the connector") + require.True(t, ui.LastLogin.After(oldTime), "LastLogin should be updated") + require.Equal(t, oldTime, ui.CreatedAt, "CreatedAt should not change on update") + require.Equal(t, []string{"openid"}, ui.Consents["existing-client"], "existing consents should be preserved") +} + +func TestFinalizeLoginSkipsUserIdentityWhenDisabled(t *testing.T) { + ctx := t.Context() + setSessionsEnabled(t, false) + + connID := "mockPw" + authReqID := "test-no-ui" + expiry := time.Now().Add(100 * time.Second) + + httpServer, s := newTestServer(t, func(c *Config) { + c.SkipApprovalScreen = true + c.Now = time.Now + }) + defer httpServer.Close() + + sc := storage.Connector{ + ID: connID, + Type: "mockPassword", + Name: "MockPassword", + ResourceVersion: "1", + Config: []byte(`{"username": "foo", "password": "password"}`), + } + require.NoError(t, s.storage.CreateConnector(ctx, sc)) + _, err := s.OpenConnector(sc) + require.NoError(t, err) + + authReq := storage.AuthRequest{ + ID: authReqID, + ConnectorID: connID, + RedirectURI: "cb", + Expiry: expiry, + ResponseTypes: []string{responseTypeCode}, + } + require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq)) + + rr := httptest.NewRecorder() + reqPath := fmt.Sprintf("/auth/%s/login?state=%s&back=&login=foo&password=password", connID, authReqID) + s.handlePasswordLogin(rr, httptest.NewRequest("POST", reqPath, nil)) + + require.Equal(t, 303, rr.Code) + + _, err = s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID) + require.ErrorIs(t, err, storage.ErrNotFound, "UserIdentity should not be created when sessions disabled") +} + +func TestSkipApprovalWithExistingConsent(t *testing.T) { + ctx := t.Context() + setSessionsEnabled(t, true) + + connID := "mock" + authReqID := "test-consent-skip" + expiry := time.Now().Add(100 * time.Second) + + tests := []struct { + name string + consents map[string][]string + scopes []string + clientID string + forcePrompt bool + wantPath string + }{ + { + name: "Existing consent covers requested scopes", + consents: map[string][]string{"test": {"email", "profile"}}, + scopes: []string{"openid", "email", "profile"}, + clientID: "test", + wantPath: "/callback/cb", + }, + { + name: "Existing consent missing a scope", + consents: map[string][]string{"test": {"email"}}, + scopes: []string{"openid", "email", "profile"}, + clientID: "test", + wantPath: "/approval", + }, + { + name: "Force approval overrides consent", + consents: map[string][]string{"test": {"email", "profile"}}, + scopes: []string{"openid", "email", "profile"}, + clientID: "test", + forcePrompt: true, + wantPath: "/approval", + }, + { + name: "No consent for this client", + consents: map[string][]string{"other-client": {"email"}}, + scopes: []string{"openid", "email"}, + clientID: "test", + wantPath: "/approval", + }, + { + name: "Only openid scope - skip with empty consent", + consents: map[string][]string{"test": {}}, + scopes: []string{"openid"}, + clientID: "test", + wantPath: "/callback/cb", + }, + { + name: "offline_access requires consent", + consents: map[string][]string{"test": {}}, + scopes: []string{"openid", "offline_access"}, + clientID: "test", + wantPath: "/approval", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + httpServer, s := newTestServer(t, func(c *Config) { + c.SkipApprovalScreen = false + c.Now = time.Now + }) + defer httpServer.Close() + + // Pre-create UserIdentity with consents + require.NoError(t, s.storage.CreateUserIdentity(ctx, storage.UserIdentity{ + UserID: "0-385-28089-0", + ConnectorID: connID, + Claims: storage.Claims{ + UserID: "0-385-28089-0", + Username: "Kilgore Trout", + Email: "kilgore@kilgore.trout", + EmailVerified: true, + }, + Consents: tc.consents, + CreatedAt: time.Now(), + LastLogin: time.Now(), + })) + + authReq := storage.AuthRequest{ + ID: authReqID, + ConnectorID: connID, + ClientID: tc.clientID, + RedirectURI: "cb", + Expiry: expiry, + ResponseTypes: []string{responseTypeCode}, + Scopes: tc.scopes, + ForceApprovalPrompt: tc.forcePrompt, + } + require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq)) + + rr := httptest.NewRecorder() + reqPath := fmt.Sprintf("/callback/%s?state=%s", connID, authReqID) + s.handleConnectorCallback(rr, httptest.NewRequest("GET", reqPath, nil)) + + require.Equal(t, 303, rr.Code) + cb, err := url.Parse(rr.Result().Header.Get("Location")) + require.NoError(t, err) + require.Equal(t, tc.wantPath, cb.Path) + }) + } +} + +func TestConsentPersistedOnApproval(t *testing.T) { + ctx := t.Context() + setSessionsEnabled(t, true) + + httpServer, s := newTestServer(t, nil) + defer httpServer.Close() + + userID := "test-user" + connectorID := "mock" + clientID := "test" + + // Pre-create UserIdentity (would have been created during login) + require.NoError(t, s.storage.CreateUserIdentity(ctx, storage.UserIdentity{ + UserID: userID, + ConnectorID: connectorID, + Claims: storage.Claims{UserID: userID}, + Consents: make(map[string][]string), + CreatedAt: time.Now(), + LastLogin: time.Now(), + })) + + authReq := storage.AuthRequest{ + ID: "approval-consent-test", + ClientID: clientID, + ConnectorID: connectorID, + ResponseTypes: []string{responseTypeCode}, + RedirectURI: "https://client.example/callback", + Expiry: time.Now().Add(time.Minute), + LoggedIn: true, + Claims: storage.Claims{UserID: userID}, + Scopes: []string{"openid", "email", "profile"}, + HMACKey: []byte("consent-test-key"), + } + require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq)) + + h := hmac.New(sha256.New, authReq.HMACKey) + h.Write([]byte(authReq.ID)) + mac := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + + form := url.Values{ + "approval": {"approve"}, + "req": {authReq.ID}, + "hmac": {mac}, + } + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/approval", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + s.ServeHTTP(rr, req) + + require.Equal(t, http.StatusSeeOther, rr.Code, "approval should redirect") + + ui, err := s.storage.GetUserIdentity(ctx, userID, connectorID) + require.NoError(t, err, "UserIdentity should exist") + require.Equal(t, []string{"openid", "email", "profile"}, ui.Consents[clientID], "approved scopes should be persisted") +} + +func TestScopesCoveredByConsent(t *testing.T) { + tests := []struct { + name string + approved []string + requested []string + want bool + }{ + { + name: "All scopes covered", + approved: []string{"email", "profile"}, + requested: []string{"openid", "email", "profile"}, + want: true, + }, + { + name: "Missing scope", + approved: []string{"email"}, + requested: []string{"openid", "email", "groups"}, + want: false, + }, + { + name: "Only openid scope skipped", + approved: []string{}, + requested: []string{"openid"}, + want: true, + }, + { + name: "offline_access requires consent", + approved: []string{}, + requested: []string{"openid", "offline_access"}, + want: false, + }, + { + name: "offline_access covered by consent", + approved: []string{"offline_access"}, + requested: []string{"openid", "offline_access"}, + want: true, + }, + { + name: "Nil approved", + approved: nil, + requested: []string{"email"}, + want: false, + }, + { + name: "Empty requested", + approved: []string{"email"}, + requested: []string{}, + want: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := scopesCoveredByConsent(tc.approved, tc.requested) + require.Equal(t, tc.want, got) + }) + } +} + func TestHandlePasswordLoginWithSkipApproval(t *testing.T) { ctx := t.Context() diff --git a/server/server.go b/server/server.go index e6945c72..e63cb278 100644 --- a/server/server.go +++ b/server/server.go @@ -45,6 +45,7 @@ import ( "github.com/dexidp/dex/connector/oidc" "github.com/dexidp/dex/connector/openshift" "github.com/dexidp/dex/connector/saml" + "github.com/dexidp/dex/pkg/featureflags" "github.com/dexidp/dex/server/signer" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/web" @@ -377,6 +378,10 @@ func newServer(ctx context.Context, c Config) (*Server, error) { return nil, fmt.Errorf("server: failed to open all connectors (%d/%d)", failedCount, len(storageConnectors)) } + if featureflags.SessionsEnabled.Enabled() { + s.logger.InfoContext(ctx, "sessions feature flag is enabled") + } + instrumentHandler := func(_ string, handler http.Handler) http.HandlerFunc { return handler.ServeHTTP }