Browse Source

Code review fixes

Signed-off-by: maksim.nabokikh <max.nabokih@gmail.com>
pull/4645/head
maksim.nabokikh 19 hours ago
parent
commit
51574eef3a
  1. 62
      server/handlers.go
  2. 55
      server/handlers_test.go

62
server/handlers.go

@ -721,10 +721,11 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
var userIdentity *storage.UserIdentity
if featureflags.SessionsEnabled.Enabled() {
now := s.now()
existing, err := s.storage.GetUserIdentity(ctx, identity.UserID, authReq.ConnectorID)
ui, err := s.storage.GetUserIdentity(ctx, identity.UserID, authReq.ConnectorID)
switch {
case err == storage.ErrNotFound:
ui := storage.UserIdentity{
case err != nil && errors.Is(err, storage.ErrNotFound):
ui = storage.UserIdentity{
UserID: identity.UserID,
ConnectorID: authReq.ConnectorID,
Claims: claims,
@ -734,23 +735,28 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
}
if err := s.storage.CreateUserIdentity(ctx, ui); err != nil {
s.logger.ErrorContext(ctx, "failed to create user identity", "err", err)
} else {
userIdentity = &ui
return "", false, err
}
case err == nil:
existing.Claims = claims
existing.LastLogin = now
if err := s.storage.UpdateUserIdentity(ctx, identity.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
old.Claims = claims
old.LastLogin = now
if len(identity.ConnectorData) > 0 {
old.Claims = claims
old.LastLogin = now
return old, nil
}
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to update user identity", "err", err)
return "", false, err
}
userIdentity = &existing
// Update the existing UserIdentity obj with new claims to use them later in the flow.
ui.Claims = claims
ui.LastLogin = now
default:
s.logger.ErrorContext(ctx, "failed to get user identity", "err", err)
return "", false, err
}
userIdentity = &ui
}
// we can skip the redirect to /approval and go ahead and send code if it's not required
@ -830,6 +836,18 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
s.renderError(r, w, http.StatusInternalServerError, "Approval rejected.")
return
}
// Persist user-approved scopes as consent for this client.
if featureflags.SessionsEnabled.Enabled() {
if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
if old.Consents == nil {
old.Consents = make(map[string][]string)
}
old.Consents[authReq.ClientID] = authReq.Scopes
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to update user identity consents", "err", err)
}
}
s.sendCodeResponse(w, r, authReq)
}
}
@ -972,35 +990,23 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
u.RawQuery = q.Encode()
}
// Persist approved scopes as user consent for this client.
if featureflags.SessionsEnabled.Enabled() {
if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
if old.Consents == nil {
old.Consents = make(map[string][]string)
}
old.Consents[authReq.ClientID] = authReq.Scopes
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to update user identity consents", "err", err)
}
}
http.Redirect(w, r, u.String(), http.StatusSeeOther)
}
// scopesCoveredByConsent checks whether the approved scopes cover all requested scopes.
// Technical scopes (openid, offline_access) are excluded from the comparison.
// The openid scope is excluded from the comparison as it is a technical scope
// that does not require user consent.
func scopesCoveredByConsent(approved, requested []string) bool {
approvedSet := make(map[string]bool, len(approved))
approvedSet := make(map[string]struct{}, len(approved))
for _, s := range approved {
approvedSet[s] = true
approvedSet[s] = struct{}{}
}
for _, scope := range requested {
if scope == scopeOpenID || scope == scopeOfflineAccess {
if scope == scopeOpenID {
continue
}
if !approvedSet[scope] {
if _, ok := approvedSet[scope]; !ok {
return false
}
}

55
server/handlers_test.go

@ -548,12 +548,12 @@ func TestFinalizeLoginCreatesUserIdentity(t *testing.T) {
require.Equal(t, 303, rr.Code)
ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID)
require.NoError(t, err)
require.NoError(t, err, "UserIdentity should exist after login")
require.Equal(t, "0-385-28089-0", ui.UserID)
require.Equal(t, connID, ui.ConnectorID)
require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email)
require.NotZero(t, ui.CreatedAt)
require.NotZero(t, ui.LastLogin)
require.NotZero(t, ui.CreatedAt, "CreatedAt should be set")
require.NotZero(t, ui.LastLogin, "LastLogin should be set")
}
func TestFinalizeLoginUpdatesUserIdentity(t *testing.T) {
@ -612,16 +612,12 @@ func TestFinalizeLoginUpdatesUserIdentity(t *testing.T) {
require.Equal(t, 303, rr.Code)
ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID)
require.NoError(t, err)
// Claims should be refreshed from the connector
require.Equal(t, "Kilgore Trout", ui.Claims.Username)
require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email)
// LastLogin should be updated
require.True(t, ui.LastLogin.After(oldTime))
// CreatedAt should NOT change
require.Equal(t, oldTime, ui.CreatedAt)
// Existing consents should be preserved
require.Equal(t, []string{"openid"}, ui.Consents["existing-client"])
require.NoError(t, err, "UserIdentity should exist after login")
require.Equal(t, "Kilgore Trout", ui.Claims.Username, "claims should be refreshed from the connector")
require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email, "claims should be refreshed from the connector")
require.True(t, ui.LastLogin.After(oldTime), "LastLogin should be updated")
require.Equal(t, oldTime, ui.CreatedAt, "CreatedAt should not change on update")
require.Equal(t, []string{"openid"}, ui.Consents["existing-client"], "existing consents should be preserved")
}
func TestFinalizeLoginSkipsUserIdentityWhenDisabled(t *testing.T) {
@ -665,7 +661,7 @@ func TestFinalizeLoginSkipsUserIdentityWhenDisabled(t *testing.T) {
require.Equal(t, 303, rr.Code)
_, err = s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID)
require.ErrorIs(t, err, storage.ErrNotFound)
require.ErrorIs(t, err, storage.ErrNotFound, "UserIdentity should not be created when sessions disabled")
}
func TestSkipApprovalWithExistingConsent(t *testing.T) {
@ -714,12 +710,19 @@ func TestSkipApprovalWithExistingConsent(t *testing.T) {
wantPath: "/approval",
},
{
name: "Only technical scopes - skip with empty consent",
name: "Only openid scope - skip with empty consent",
consents: map[string][]string{"test": {}},
scopes: []string{"openid", "offline_access"},
scopes: []string{"openid"},
clientID: "test",
wantPath: "/callback/cb",
},
{
name: "offline_access requires consent",
consents: map[string][]string{"test": {}},
scopes: []string{"openid", "offline_access"},
clientID: "test",
wantPath: "/approval",
},
}
for _, tc := range tests {
@ -819,11 +822,11 @@ func TestConsentPersistedOnApproval(t *testing.T) {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
s.ServeHTTP(rr, req)
require.Equal(t, http.StatusSeeOther, rr.Code)
require.Equal(t, http.StatusSeeOther, rr.Code, "approval should redirect")
ui, err := s.storage.GetUserIdentity(ctx, userID, connectorID)
require.NoError(t, err)
require.Equal(t, []string{"openid", "email", "profile"}, ui.Consents[clientID])
require.NoError(t, err, "UserIdentity should exist")
require.Equal(t, []string{"openid", "email", "profile"}, ui.Consents[clientID], "approved scopes should be persisted")
}
func TestScopesCoveredByConsent(t *testing.T) {
@ -846,9 +849,21 @@ func TestScopesCoveredByConsent(t *testing.T) {
want: false,
},
{
name: "Only technical scopes",
name: "Only openid scope skipped",
approved: []string{},
requested: []string{"openid"},
want: true,
},
{
name: "offline_access requires consent",
approved: []string{},
requested: []string{"openid", "offline_access"},
want: false,
},
{
name: "offline_access covered by consent",
approved: []string{"offline_access"},
requested: []string{"openid", "offline_access"},
want: true,
},
{

Loading…
Cancel
Save