From 51574eef3a819ef97e6103ae3006be76758c2614 Mon Sep 17 00:00:00 2001 From: "maksim.nabokikh" Date: Mon, 16 Mar 2026 10:37:06 +0100 Subject: [PATCH] Code review fixes Signed-off-by: maksim.nabokikh --- server/handlers.go | 62 ++++++++++++++++++++++------------------- server/handlers_test.go | 55 +++++++++++++++++++++++------------- 2 files changed, 69 insertions(+), 48 deletions(-) diff --git a/server/handlers.go b/server/handlers.go index ac73bee4..a571db05 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -721,10 +721,11 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, var userIdentity *storage.UserIdentity if featureflags.SessionsEnabled.Enabled() { now := s.now() - existing, err := s.storage.GetUserIdentity(ctx, identity.UserID, authReq.ConnectorID) + + ui, err := s.storage.GetUserIdentity(ctx, identity.UserID, authReq.ConnectorID) switch { - case err == storage.ErrNotFound: - ui := storage.UserIdentity{ + case err != nil && errors.Is(err, storage.ErrNotFound): + ui = storage.UserIdentity{ UserID: identity.UserID, ConnectorID: authReq.ConnectorID, Claims: claims, @@ -734,23 +735,28 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, } if err := s.storage.CreateUserIdentity(ctx, ui); err != nil { s.logger.ErrorContext(ctx, "failed to create user identity", "err", err) - } else { - userIdentity = &ui + return "", false, err } case err == nil: - existing.Claims = claims - existing.LastLogin = now if err := s.storage.UpdateUserIdentity(ctx, identity.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { - old.Claims = claims - old.LastLogin = now + if len(identity.ConnectorData) > 0 { + old.Claims = claims + old.LastLogin = now + return old, nil + } return old, nil }); err != nil { s.logger.ErrorContext(ctx, "failed to update user identity", "err", err) + return "", false, err } - userIdentity = &existing + // Update the existing UserIdentity obj with new claims to use them later in the flow. + ui.Claims = claims + ui.LastLogin = now default: s.logger.ErrorContext(ctx, "failed to get user identity", "err", err) + return "", false, err } + userIdentity = &ui } // we can skip the redirect to /approval and go ahead and send code if it's not required @@ -830,6 +836,18 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { s.renderError(r, w, http.StatusInternalServerError, "Approval rejected.") return } + // Persist user-approved scopes as consent for this client. + if featureflags.SessionsEnabled.Enabled() { + if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { + if old.Consents == nil { + old.Consents = make(map[string][]string) + } + old.Consents[authReq.ClientID] = authReq.Scopes + return old, nil + }); err != nil { + s.logger.ErrorContext(ctx, "failed to update user identity consents", "err", err) + } + } s.sendCodeResponse(w, r, authReq) } } @@ -972,35 +990,23 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe u.RawQuery = q.Encode() } - // Persist approved scopes as user consent for this client. - if featureflags.SessionsEnabled.Enabled() { - if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { - if old.Consents == nil { - old.Consents = make(map[string][]string) - } - old.Consents[authReq.ClientID] = authReq.Scopes - return old, nil - }); err != nil { - s.logger.ErrorContext(ctx, "failed to update user identity consents", "err", err) - } - } - http.Redirect(w, r, u.String(), http.StatusSeeOther) } // scopesCoveredByConsent checks whether the approved scopes cover all requested scopes. -// Technical scopes (openid, offline_access) are excluded from the comparison. +// The openid scope is excluded from the comparison as it is a technical scope +// that does not require user consent. func scopesCoveredByConsent(approved, requested []string) bool { - approvedSet := make(map[string]bool, len(approved)) + approvedSet := make(map[string]struct{}, len(approved)) for _, s := range approved { - approvedSet[s] = true + approvedSet[s] = struct{}{} } for _, scope := range requested { - if scope == scopeOpenID || scope == scopeOfflineAccess { + if scope == scopeOpenID { continue } - if !approvedSet[scope] { + if _, ok := approvedSet[scope]; !ok { return false } } diff --git a/server/handlers_test.go b/server/handlers_test.go index 23196a27..933e9e4d 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -548,12 +548,12 @@ func TestFinalizeLoginCreatesUserIdentity(t *testing.T) { require.Equal(t, 303, rr.Code) ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID) - require.NoError(t, err) + require.NoError(t, err, "UserIdentity should exist after login") require.Equal(t, "0-385-28089-0", ui.UserID) require.Equal(t, connID, ui.ConnectorID) require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email) - require.NotZero(t, ui.CreatedAt) - require.NotZero(t, ui.LastLogin) + require.NotZero(t, ui.CreatedAt, "CreatedAt should be set") + require.NotZero(t, ui.LastLogin, "LastLogin should be set") } func TestFinalizeLoginUpdatesUserIdentity(t *testing.T) { @@ -612,16 +612,12 @@ func TestFinalizeLoginUpdatesUserIdentity(t *testing.T) { require.Equal(t, 303, rr.Code) ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID) - require.NoError(t, err) - // Claims should be refreshed from the connector - require.Equal(t, "Kilgore Trout", ui.Claims.Username) - require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email) - // LastLogin should be updated - require.True(t, ui.LastLogin.After(oldTime)) - // CreatedAt should NOT change - require.Equal(t, oldTime, ui.CreatedAt) - // Existing consents should be preserved - require.Equal(t, []string{"openid"}, ui.Consents["existing-client"]) + require.NoError(t, err, "UserIdentity should exist after login") + require.Equal(t, "Kilgore Trout", ui.Claims.Username, "claims should be refreshed from the connector") + require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email, "claims should be refreshed from the connector") + require.True(t, ui.LastLogin.After(oldTime), "LastLogin should be updated") + require.Equal(t, oldTime, ui.CreatedAt, "CreatedAt should not change on update") + require.Equal(t, []string{"openid"}, ui.Consents["existing-client"], "existing consents should be preserved") } func TestFinalizeLoginSkipsUserIdentityWhenDisabled(t *testing.T) { @@ -665,7 +661,7 @@ func TestFinalizeLoginSkipsUserIdentityWhenDisabled(t *testing.T) { require.Equal(t, 303, rr.Code) _, err = s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID) - require.ErrorIs(t, err, storage.ErrNotFound) + require.ErrorIs(t, err, storage.ErrNotFound, "UserIdentity should not be created when sessions disabled") } func TestSkipApprovalWithExistingConsent(t *testing.T) { @@ -714,12 +710,19 @@ func TestSkipApprovalWithExistingConsent(t *testing.T) { wantPath: "/approval", }, { - name: "Only technical scopes - skip with empty consent", + name: "Only openid scope - skip with empty consent", consents: map[string][]string{"test": {}}, - scopes: []string{"openid", "offline_access"}, + scopes: []string{"openid"}, clientID: "test", wantPath: "/callback/cb", }, + { + name: "offline_access requires consent", + consents: map[string][]string{"test": {}}, + scopes: []string{"openid", "offline_access"}, + clientID: "test", + wantPath: "/approval", + }, } for _, tc := range tests { @@ -819,11 +822,11 @@ func TestConsentPersistedOnApproval(t *testing.T) { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") s.ServeHTTP(rr, req) - require.Equal(t, http.StatusSeeOther, rr.Code) + require.Equal(t, http.StatusSeeOther, rr.Code, "approval should redirect") ui, err := s.storage.GetUserIdentity(ctx, userID, connectorID) - require.NoError(t, err) - require.Equal(t, []string{"openid", "email", "profile"}, ui.Consents[clientID]) + require.NoError(t, err, "UserIdentity should exist") + require.Equal(t, []string{"openid", "email", "profile"}, ui.Consents[clientID], "approved scopes should be persisted") } func TestScopesCoveredByConsent(t *testing.T) { @@ -846,9 +849,21 @@ func TestScopesCoveredByConsent(t *testing.T) { want: false, }, { - name: "Only technical scopes", + name: "Only openid scope skipped", + approved: []string{}, + requested: []string{"openid"}, + want: true, + }, + { + name: "offline_access requires consent", approved: []string{}, requested: []string{"openid", "offline_access"}, + want: false, + }, + { + name: "offline_access covered by consent", + approved: []string{"offline_access"}, + requested: []string{"openid", "offline_access"}, want: true, }, {