diff --git a/server/handlers.go b/server/handlers.go index 522f3700..ac73bee4 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -718,9 +718,10 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, } // Create or update UserIdentity to persist user claims across sessions. + var userIdentity *storage.UserIdentity if featureflags.SessionsEnabled.Enabled() { now := s.now() - _, err := s.storage.GetUserIdentity(ctx, identity.UserID, authReq.ConnectorID) + existing, err := s.storage.GetUserIdentity(ctx, identity.UserID, authReq.ConnectorID) switch { case err == storage.ErrNotFound: ui := storage.UserIdentity{ @@ -733,8 +734,12 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, } if err := s.storage.CreateUserIdentity(ctx, ui); err != nil { s.logger.ErrorContext(ctx, "failed to create user identity", "err", err) + } else { + userIdentity = &ui } case err == nil: + existing.Claims = claims + existing.LastLogin = now if err := s.storage.UpdateUserIdentity(ctx, identity.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { old.Claims = claims old.LastLogin = now @@ -742,6 +747,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, }); err != nil { s.logger.ErrorContext(ctx, "failed to update user identity", "err", err) } + userIdentity = &existing default: s.logger.ErrorContext(ctx, "failed to get user identity", "err", err) } @@ -753,8 +759,8 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, } // Skip approval if user already consented to the requested scopes for this client. - if !authReq.ForceApprovalPrompt && featureflags.SessionsEnabled.Enabled() { - if s.hasExistingConsent(ctx, identity.UserID, authReq.ConnectorID, authReq.ClientID, authReq.Scopes) { + if !authReq.ForceApprovalPrompt && userIdentity != nil { + if scopesCoveredByConsent(userIdentity.Consents[authReq.ClientID], authReq.Scopes) { return "", true, nil } } @@ -982,26 +988,15 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe http.Redirect(w, r, u.String(), http.StatusSeeOther) } -// hasExistingConsent checks whether the user has already consented to the requested -// scopes for the given client. Technical scopes (openid, offline_access) are excluded -// from the comparison. Returns false on any error as a safe default. -func (s *Server) hasExistingConsent(ctx context.Context, userID, connectorID, clientID string, scopes []string) bool { - ui, err := s.storage.GetUserIdentity(ctx, userID, connectorID) - if err != nil { - return false - } - - approved, ok := ui.Consents[clientID] - if !ok { - return false - } - +// scopesCoveredByConsent checks whether the approved scopes cover all requested scopes. +// Technical scopes (openid, offline_access) are excluded from the comparison. +func scopesCoveredByConsent(approved, requested []string) bool { approvedSet := make(map[string]bool, len(approved)) for _, s := range approved { approvedSet[s] = true } - for _, scope := range scopes { + for _, scope := range requested { if scope == scopeOpenID || scope == scopeOfflineAccess { continue } diff --git a/server/handlers_test.go b/server/handlers_test.go index 5eb1b564..23196a27 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -826,76 +826,48 @@ func TestConsentPersistedOnApproval(t *testing.T) { require.Equal(t, []string{"openid", "email", "profile"}, ui.Consents[clientID]) } -func TestHasExistingConsent(t *testing.T) { - ctx := t.Context() - - httpServer, s := newTestServer(t, nil) - defer httpServer.Close() - - userID := "consent-user" - connID := "mock" - +func TestScopesCoveredByConsent(t *testing.T) { tests := []struct { - name string - consents map[string][]string - clientID string - scopes []string - want bool + name string + approved []string + requested []string + want bool }{ { - name: "All scopes covered", - consents: map[string][]string{"client-a": {"email", "profile"}}, - clientID: "client-a", - scopes: []string{"openid", "email", "profile"}, - want: true, + name: "All scopes covered", + approved: []string{"email", "profile"}, + requested: []string{"openid", "email", "profile"}, + want: true, }, { - name: "Missing scope", - consents: map[string][]string{"client-a": {"email"}}, - clientID: "client-a", - scopes: []string{"openid", "email", "groups"}, - want: false, + name: "Missing scope", + approved: []string{"email"}, + requested: []string{"openid", "email", "groups"}, + want: false, }, { - name: "Only technical scopes", - consents: map[string][]string{"client-a": {}}, - clientID: "client-a", - scopes: []string{"openid", "offline_access"}, - want: true, + name: "Only technical scopes", + approved: []string{}, + requested: []string{"openid", "offline_access"}, + want: true, }, { - name: "No consent for client", - consents: map[string][]string{"other": {"email"}}, - clientID: "client-a", - scopes: []string{"email"}, - want: false, + name: "Nil approved", + approved: nil, + requested: []string{"email"}, + want: false, }, { - name: "No UserIdentity at all", - consents: nil, - clientID: "client-a", - scopes: []string{"email"}, - want: false, + name: "Empty requested", + approved: []string{"email"}, + requested: []string{}, + want: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - // Clean up any previous identity - _ = s.storage.DeleteUserIdentity(ctx, userID, connID) - - if tc.consents != nil { - require.NoError(t, s.storage.CreateUserIdentity(ctx, storage.UserIdentity{ - UserID: userID, - ConnectorID: connID, - Claims: storage.Claims{UserID: userID}, - Consents: tc.consents, - CreatedAt: time.Now(), - LastLogin: time.Now(), - })) - } - - got := s.hasExistingConsent(ctx, userID, connID, tc.clientID, tc.scopes) + got := scopesCoveredByConsent(tc.approved, tc.requested) require.Equal(t, tc.want, got) }) }