Browse Source

Code review fixes

Signed-off-by: maksim.nabokikh <max.nabokih@gmail.com>
pull/4645/head
maksim.nabokikh 1 month ago
parent
commit
51574eef3a
  1. 62
      server/handlers.go
  2. 55
      server/handlers_test.go

62
server/handlers.go

@ -721,10 +721,11 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
var userIdentity *storage.UserIdentity var userIdentity *storage.UserIdentity
if featureflags.SessionsEnabled.Enabled() { if featureflags.SessionsEnabled.Enabled() {
now := s.now() now := s.now()
existing, err := s.storage.GetUserIdentity(ctx, identity.UserID, authReq.ConnectorID)
ui, err := s.storage.GetUserIdentity(ctx, identity.UserID, authReq.ConnectorID)
switch { switch {
case err == storage.ErrNotFound: case err != nil && errors.Is(err, storage.ErrNotFound):
ui := storage.UserIdentity{ ui = storage.UserIdentity{
UserID: identity.UserID, UserID: identity.UserID,
ConnectorID: authReq.ConnectorID, ConnectorID: authReq.ConnectorID,
Claims: claims, Claims: claims,
@ -734,23 +735,28 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
} }
if err := s.storage.CreateUserIdentity(ctx, ui); err != nil { if err := s.storage.CreateUserIdentity(ctx, ui); err != nil {
s.logger.ErrorContext(ctx, "failed to create user identity", "err", err) s.logger.ErrorContext(ctx, "failed to create user identity", "err", err)
} else { return "", false, err
userIdentity = &ui
} }
case err == nil: case err == nil:
existing.Claims = claims
existing.LastLogin = now
if err := s.storage.UpdateUserIdentity(ctx, identity.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { if err := s.storage.UpdateUserIdentity(ctx, identity.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
old.Claims = claims if len(identity.ConnectorData) > 0 {
old.LastLogin = now old.Claims = claims
old.LastLogin = now
return old, nil
}
return old, nil return old, nil
}); err != nil { }); err != nil {
s.logger.ErrorContext(ctx, "failed to update user identity", "err", err) s.logger.ErrorContext(ctx, "failed to update user identity", "err", err)
return "", false, err
} }
userIdentity = &existing // Update the existing UserIdentity obj with new claims to use them later in the flow.
ui.Claims = claims
ui.LastLogin = now
default: default:
s.logger.ErrorContext(ctx, "failed to get user identity", "err", err) s.logger.ErrorContext(ctx, "failed to get user identity", "err", err)
return "", false, err
} }
userIdentity = &ui
} }
// we can skip the redirect to /approval and go ahead and send code if it's not required // we can skip the redirect to /approval and go ahead and send code if it's not required
@ -830,6 +836,18 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
s.renderError(r, w, http.StatusInternalServerError, "Approval rejected.") s.renderError(r, w, http.StatusInternalServerError, "Approval rejected.")
return return
} }
// Persist user-approved scopes as consent for this client.
if featureflags.SessionsEnabled.Enabled() {
if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
if old.Consents == nil {
old.Consents = make(map[string][]string)
}
old.Consents[authReq.ClientID] = authReq.Scopes
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to update user identity consents", "err", err)
}
}
s.sendCodeResponse(w, r, authReq) s.sendCodeResponse(w, r, authReq)
} }
} }
@ -972,35 +990,23 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
} }
// Persist approved scopes as user consent for this client.
if featureflags.SessionsEnabled.Enabled() {
if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
if old.Consents == nil {
old.Consents = make(map[string][]string)
}
old.Consents[authReq.ClientID] = authReq.Scopes
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to update user identity consents", "err", err)
}
}
http.Redirect(w, r, u.String(), http.StatusSeeOther) http.Redirect(w, r, u.String(), http.StatusSeeOther)
} }
// scopesCoveredByConsent checks whether the approved scopes cover all requested scopes. // scopesCoveredByConsent checks whether the approved scopes cover all requested scopes.
// Technical scopes (openid, offline_access) are excluded from the comparison. // The openid scope is excluded from the comparison as it is a technical scope
// that does not require user consent.
func scopesCoveredByConsent(approved, requested []string) bool { func scopesCoveredByConsent(approved, requested []string) bool {
approvedSet := make(map[string]bool, len(approved)) approvedSet := make(map[string]struct{}, len(approved))
for _, s := range approved { for _, s := range approved {
approvedSet[s] = true approvedSet[s] = struct{}{}
} }
for _, scope := range requested { for _, scope := range requested {
if scope == scopeOpenID || scope == scopeOfflineAccess { if scope == scopeOpenID {
continue continue
} }
if !approvedSet[scope] { if _, ok := approvedSet[scope]; !ok {
return false return false
} }
} }

55
server/handlers_test.go

@ -548,12 +548,12 @@ func TestFinalizeLoginCreatesUserIdentity(t *testing.T) {
require.Equal(t, 303, rr.Code) require.Equal(t, 303, rr.Code)
ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID) ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID)
require.NoError(t, err) require.NoError(t, err, "UserIdentity should exist after login")
require.Equal(t, "0-385-28089-0", ui.UserID) require.Equal(t, "0-385-28089-0", ui.UserID)
require.Equal(t, connID, ui.ConnectorID) require.Equal(t, connID, ui.ConnectorID)
require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email) require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email)
require.NotZero(t, ui.CreatedAt) require.NotZero(t, ui.CreatedAt, "CreatedAt should be set")
require.NotZero(t, ui.LastLogin) require.NotZero(t, ui.LastLogin, "LastLogin should be set")
} }
func TestFinalizeLoginUpdatesUserIdentity(t *testing.T) { func TestFinalizeLoginUpdatesUserIdentity(t *testing.T) {
@ -612,16 +612,12 @@ func TestFinalizeLoginUpdatesUserIdentity(t *testing.T) {
require.Equal(t, 303, rr.Code) require.Equal(t, 303, rr.Code)
ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID) ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID)
require.NoError(t, err) require.NoError(t, err, "UserIdentity should exist after login")
// Claims should be refreshed from the connector require.Equal(t, "Kilgore Trout", ui.Claims.Username, "claims should be refreshed from the connector")
require.Equal(t, "Kilgore Trout", ui.Claims.Username) require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email, "claims should be refreshed from the connector")
require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email) require.True(t, ui.LastLogin.After(oldTime), "LastLogin should be updated")
// LastLogin should be updated require.Equal(t, oldTime, ui.CreatedAt, "CreatedAt should not change on update")
require.True(t, ui.LastLogin.After(oldTime)) require.Equal(t, []string{"openid"}, ui.Consents["existing-client"], "existing consents should be preserved")
// CreatedAt should NOT change
require.Equal(t, oldTime, ui.CreatedAt)
// Existing consents should be preserved
require.Equal(t, []string{"openid"}, ui.Consents["existing-client"])
} }
func TestFinalizeLoginSkipsUserIdentityWhenDisabled(t *testing.T) { func TestFinalizeLoginSkipsUserIdentityWhenDisabled(t *testing.T) {
@ -665,7 +661,7 @@ func TestFinalizeLoginSkipsUserIdentityWhenDisabled(t *testing.T) {
require.Equal(t, 303, rr.Code) require.Equal(t, 303, rr.Code)
_, err = s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID) _, err = s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID)
require.ErrorIs(t, err, storage.ErrNotFound) require.ErrorIs(t, err, storage.ErrNotFound, "UserIdentity should not be created when sessions disabled")
} }
func TestSkipApprovalWithExistingConsent(t *testing.T) { func TestSkipApprovalWithExistingConsent(t *testing.T) {
@ -714,12 +710,19 @@ func TestSkipApprovalWithExistingConsent(t *testing.T) {
wantPath: "/approval", wantPath: "/approval",
}, },
{ {
name: "Only technical scopes - skip with empty consent", name: "Only openid scope - skip with empty consent",
consents: map[string][]string{"test": {}}, consents: map[string][]string{"test": {}},
scopes: []string{"openid", "offline_access"}, scopes: []string{"openid"},
clientID: "test", clientID: "test",
wantPath: "/callback/cb", wantPath: "/callback/cb",
}, },
{
name: "offline_access requires consent",
consents: map[string][]string{"test": {}},
scopes: []string{"openid", "offline_access"},
clientID: "test",
wantPath: "/approval",
},
} }
for _, tc := range tests { for _, tc := range tests {
@ -819,11 +822,11 @@ func TestConsentPersistedOnApproval(t *testing.T) {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
s.ServeHTTP(rr, req) s.ServeHTTP(rr, req)
require.Equal(t, http.StatusSeeOther, rr.Code) require.Equal(t, http.StatusSeeOther, rr.Code, "approval should redirect")
ui, err := s.storage.GetUserIdentity(ctx, userID, connectorID) ui, err := s.storage.GetUserIdentity(ctx, userID, connectorID)
require.NoError(t, err) require.NoError(t, err, "UserIdentity should exist")
require.Equal(t, []string{"openid", "email", "profile"}, ui.Consents[clientID]) require.Equal(t, []string{"openid", "email", "profile"}, ui.Consents[clientID], "approved scopes should be persisted")
} }
func TestScopesCoveredByConsent(t *testing.T) { func TestScopesCoveredByConsent(t *testing.T) {
@ -846,9 +849,21 @@ func TestScopesCoveredByConsent(t *testing.T) {
want: false, want: false,
}, },
{ {
name: "Only technical scopes", name: "Only openid scope skipped",
approved: []string{},
requested: []string{"openid"},
want: true,
},
{
name: "offline_access requires consent",
approved: []string{}, approved: []string{},
requested: []string{"openid", "offline_access"}, requested: []string{"openid", "offline_access"},
want: false,
},
{
name: "offline_access covered by consent",
approved: []string{"offline_access"},
requested: []string{"openid", "offline_access"},
want: true, want: true,
}, },
{ {

Loading…
Cancel
Save