Browse Source

feat: implement user identity creation and persisting consent

Signed-off-by: maksim.nabokikh <max.nabokih@gmail.com>
pull/4645/head
maksim.nabokikh 3 days ago
parent
commit
896051de33
  1. 82
      server/handlers.go
  2. 406
      server/handlers_test.go
  3. 5
      server/server.go

82
server/handlers.go

@ -23,6 +23,7 @@ import (
"github.com/gorilla/mux"
"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/featureflags"
"github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage"
)
@ -716,11 +717,48 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
}
}
// Create or update UserIdentity to persist user claims across sessions.
if featureflags.SessionsEnabled.Enabled() {
now := s.now()
_, err := s.storage.GetUserIdentity(ctx, identity.UserID, authReq.ConnectorID)
switch {
case err == storage.ErrNotFound:
ui := storage.UserIdentity{
UserID: identity.UserID,
ConnectorID: authReq.ConnectorID,
Claims: claims,
Consents: make(map[string][]string),
CreatedAt: now,
LastLogin: now,
}
if err := s.storage.CreateUserIdentity(ctx, ui); err != nil {
s.logger.ErrorContext(ctx, "failed to create user identity", "err", err)
}
case err == nil:
if err := s.storage.UpdateUserIdentity(ctx, identity.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
old.Claims = claims
old.LastLogin = now
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to update user identity", "err", err)
}
default:
s.logger.ErrorContext(ctx, "failed to get user identity", "err", err)
}
}
// we can skip the redirect to /approval and go ahead and send code if it's not required
if s.skipApproval && !authReq.ForceApprovalPrompt {
return "", true, nil
}
// Skip approval if user already consented to the requested scopes for this client.
if !authReq.ForceApprovalPrompt && featureflags.SessionsEnabled.Enabled() {
if s.hasExistingConsent(ctx, identity.UserID, authReq.ConnectorID, authReq.ClientID, authReq.Scopes) {
return "", true, nil
}
}
// an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original
// flow would be unable to poll for the result at the /approval endpoint
h := hmac.New(sha256.New, authReq.HMACKey)
@ -928,9 +966,53 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
u.RawQuery = q.Encode()
}
// Persist approved scopes as user consent for this client.
if featureflags.SessionsEnabled.Enabled() {
if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
if old.Consents == nil {
old.Consents = make(map[string][]string)
}
old.Consents[authReq.ClientID] = authReq.Scopes
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to update user identity consents", "err", err)
}
}
http.Redirect(w, r, u.String(), http.StatusSeeOther)
}
// hasExistingConsent checks whether the user has already consented to the requested
// scopes for the given client. Technical scopes (openid, offline_access) are excluded
// from the comparison. Returns false on any error as a safe default.
func (s *Server) hasExistingConsent(ctx context.Context, userID, connectorID, clientID string, scopes []string) bool {
ui, err := s.storage.GetUserIdentity(ctx, userID, connectorID)
if err != nil {
return false
}
approved, ok := ui.Consents[clientID]
if !ok {
return false
}
approvedSet := make(map[string]bool, len(approved))
for _, s := range approved {
approvedSet[s] = true
}
for _, scope := range scopes {
if scope == scopeOpenID || scope == scopeOfflineAccess {
continue
}
if !approvedSet[scope] {
return false
}
}
return true
}
func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request, storage.Client)) {
ctx := r.Context()
clientID, clientSecret, ok := r.BasicAuth()

406
server/handlers_test.go

@ -3,6 +3,9 @@ package server
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@ -495,6 +498,409 @@ func TestHandlePassword_LocalPasswordDBClaims(t *testing.T) {
require.Equal(t, []string{"team-a", "team-a/admins"}, claims.Groups)
}
func setSessionsEnabled(t *testing.T, enabled bool) {
t.Helper()
if enabled {
t.Setenv("DEX_SESSIONS_ENABLED", "true")
} else {
t.Setenv("DEX_SESSIONS_ENABLED", "false")
}
}
func TestFinalizeLoginCreatesUserIdentity(t *testing.T) {
ctx := t.Context()
setSessionsEnabled(t, true)
connID := "mockPw"
authReqID := "test-create-ui"
expiry := time.Now().Add(100 * time.Second)
httpServer, s := newTestServer(t, func(c *Config) {
c.SkipApprovalScreen = true
c.Now = time.Now
})
defer httpServer.Close()
sc := storage.Connector{
ID: connID,
Type: "mockPassword",
Name: "MockPassword",
ResourceVersion: "1",
Config: []byte(`{"username": "foo", "password": "password"}`),
}
require.NoError(t, s.storage.CreateConnector(ctx, sc))
_, err := s.OpenConnector(sc)
require.NoError(t, err)
authReq := storage.AuthRequest{
ID: authReqID,
ConnectorID: connID,
RedirectURI: "cb",
Expiry: expiry,
ResponseTypes: []string{responseTypeCode},
}
require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq))
rr := httptest.NewRecorder()
reqPath := fmt.Sprintf("/auth/%s/login?state=%s&back=&login=foo&password=password", connID, authReqID)
s.handlePasswordLogin(rr, httptest.NewRequest("POST", reqPath, nil))
require.Equal(t, 303, rr.Code)
ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID)
require.NoError(t, err)
require.Equal(t, "0-385-28089-0", ui.UserID)
require.Equal(t, connID, ui.ConnectorID)
require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email)
require.NotZero(t, ui.CreatedAt)
require.NotZero(t, ui.LastLogin)
}
func TestFinalizeLoginUpdatesUserIdentity(t *testing.T) {
ctx := t.Context()
setSessionsEnabled(t, true)
connID := "mockPw"
authReqID := "test-update-ui"
expiry := time.Now().Add(100 * time.Second)
oldTime := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
httpServer, s := newTestServer(t, func(c *Config) {
c.SkipApprovalScreen = true
c.Now = time.Now
})
defer httpServer.Close()
sc := storage.Connector{
ID: connID,
Type: "mockPassword",
Name: "MockPassword",
ResourceVersion: "1",
Config: []byte(`{"username": "foo", "password": "password"}`),
}
require.NoError(t, s.storage.CreateConnector(ctx, sc))
_, err := s.OpenConnector(sc)
require.NoError(t, err)
// Pre-create UserIdentity with old data
require.NoError(t, s.storage.CreateUserIdentity(ctx, storage.UserIdentity{
UserID: "0-385-28089-0",
ConnectorID: connID,
Claims: storage.Claims{
UserID: "0-385-28089-0",
Username: "Old Name",
Email: "old@example.com",
},
Consents: map[string][]string{"existing-client": {"openid"}},
CreatedAt: oldTime,
LastLogin: oldTime,
}))
authReq := storage.AuthRequest{
ID: authReqID,
ConnectorID: connID,
RedirectURI: "cb",
Expiry: expiry,
ResponseTypes: []string{responseTypeCode},
}
require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq))
rr := httptest.NewRecorder()
reqPath := fmt.Sprintf("/auth/%s/login?state=%s&back=&login=foo&password=password", connID, authReqID)
s.handlePasswordLogin(rr, httptest.NewRequest("POST", reqPath, nil))
require.Equal(t, 303, rr.Code)
ui, err := s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID)
require.NoError(t, err)
// Claims should be refreshed from the connector
require.Equal(t, "Kilgore Trout", ui.Claims.Username)
require.Equal(t, "kilgore@kilgore.trout", ui.Claims.Email)
// LastLogin should be updated
require.True(t, ui.LastLogin.After(oldTime))
// CreatedAt should NOT change
require.Equal(t, oldTime, ui.CreatedAt)
// Existing consents should be preserved
require.Equal(t, []string{"openid"}, ui.Consents["existing-client"])
}
func TestFinalizeLoginSkipsUserIdentityWhenDisabled(t *testing.T) {
ctx := t.Context()
setSessionsEnabled(t, false)
connID := "mockPw"
authReqID := "test-no-ui"
expiry := time.Now().Add(100 * time.Second)
httpServer, s := newTestServer(t, func(c *Config) {
c.SkipApprovalScreen = true
c.Now = time.Now
})
defer httpServer.Close()
sc := storage.Connector{
ID: connID,
Type: "mockPassword",
Name: "MockPassword",
ResourceVersion: "1",
Config: []byte(`{"username": "foo", "password": "password"}`),
}
require.NoError(t, s.storage.CreateConnector(ctx, sc))
_, err := s.OpenConnector(sc)
require.NoError(t, err)
authReq := storage.AuthRequest{
ID: authReqID,
ConnectorID: connID,
RedirectURI: "cb",
Expiry: expiry,
ResponseTypes: []string{responseTypeCode},
}
require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq))
rr := httptest.NewRecorder()
reqPath := fmt.Sprintf("/auth/%s/login?state=%s&back=&login=foo&password=password", connID, authReqID)
s.handlePasswordLogin(rr, httptest.NewRequest("POST", reqPath, nil))
require.Equal(t, 303, rr.Code)
_, err = s.storage.GetUserIdentity(ctx, "0-385-28089-0", connID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestSkipApprovalWithExistingConsent(t *testing.T) {
ctx := t.Context()
setSessionsEnabled(t, true)
connID := "mock"
authReqID := "test-consent-skip"
expiry := time.Now().Add(100 * time.Second)
tests := []struct {
name string
consents map[string][]string
scopes []string
clientID string
forcePrompt bool
wantPath string
}{
{
name: "Existing consent covers requested scopes",
consents: map[string][]string{"test": {"email", "profile"}},
scopes: []string{"openid", "email", "profile"},
clientID: "test",
wantPath: "/callback/cb",
},
{
name: "Existing consent missing a scope",
consents: map[string][]string{"test": {"email"}},
scopes: []string{"openid", "email", "profile"},
clientID: "test",
wantPath: "/approval",
},
{
name: "Force approval overrides consent",
consents: map[string][]string{"test": {"email", "profile"}},
scopes: []string{"openid", "email", "profile"},
clientID: "test",
forcePrompt: true,
wantPath: "/approval",
},
{
name: "No consent for this client",
consents: map[string][]string{"other-client": {"email"}},
scopes: []string{"openid", "email"},
clientID: "test",
wantPath: "/approval",
},
{
name: "Only technical scopes - skip with empty consent",
consents: map[string][]string{"test": {}},
scopes: []string{"openid", "offline_access"},
clientID: "test",
wantPath: "/callback/cb",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
httpServer, s := newTestServer(t, func(c *Config) {
c.SkipApprovalScreen = false
c.Now = time.Now
})
defer httpServer.Close()
// Pre-create UserIdentity with consents
require.NoError(t, s.storage.CreateUserIdentity(ctx, storage.UserIdentity{
UserID: "0-385-28089-0",
ConnectorID: connID,
Claims: storage.Claims{
UserID: "0-385-28089-0",
Username: "Kilgore Trout",
Email: "kilgore@kilgore.trout",
EmailVerified: true,
},
Consents: tc.consents,
CreatedAt: time.Now(),
LastLogin: time.Now(),
}))
authReq := storage.AuthRequest{
ID: authReqID,
ConnectorID: connID,
ClientID: tc.clientID,
RedirectURI: "cb",
Expiry: expiry,
ResponseTypes: []string{responseTypeCode},
Scopes: tc.scopes,
ForceApprovalPrompt: tc.forcePrompt,
}
require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq))
rr := httptest.NewRecorder()
reqPath := fmt.Sprintf("/callback/%s?state=%s", connID, authReqID)
s.handleConnectorCallback(rr, httptest.NewRequest("GET", reqPath, nil))
require.Equal(t, 303, rr.Code)
cb, err := url.Parse(rr.Result().Header.Get("Location"))
require.NoError(t, err)
require.Equal(t, tc.wantPath, cb.Path)
})
}
}
func TestConsentPersistedOnApproval(t *testing.T) {
ctx := t.Context()
setSessionsEnabled(t, true)
httpServer, s := newTestServer(t, nil)
defer httpServer.Close()
userID := "test-user"
connectorID := "mock"
clientID := "test"
// Pre-create UserIdentity (would have been created during login)
require.NoError(t, s.storage.CreateUserIdentity(ctx, storage.UserIdentity{
UserID: userID,
ConnectorID: connectorID,
Claims: storage.Claims{UserID: userID},
Consents: make(map[string][]string),
CreatedAt: time.Now(),
LastLogin: time.Now(),
}))
authReq := storage.AuthRequest{
ID: "approval-consent-test",
ClientID: clientID,
ConnectorID: connectorID,
ResponseTypes: []string{responseTypeCode},
RedirectURI: "https://client.example/callback",
Expiry: time.Now().Add(time.Minute),
LoggedIn: true,
Claims: storage.Claims{UserID: userID},
Scopes: []string{"openid", "email", "profile"},
HMACKey: []byte("consent-test-key"),
}
require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq))
h := hmac.New(sha256.New, authReq.HMACKey)
h.Write([]byte(authReq.ID))
mac := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
form := url.Values{
"approval": {"approve"},
"req": {authReq.ID},
"hmac": {mac},
}
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/approval", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
s.ServeHTTP(rr, req)
require.Equal(t, http.StatusSeeOther, rr.Code)
ui, err := s.storage.GetUserIdentity(ctx, userID, connectorID)
require.NoError(t, err)
require.Equal(t, []string{"openid", "email", "profile"}, ui.Consents[clientID])
}
func TestHasExistingConsent(t *testing.T) {
ctx := t.Context()
httpServer, s := newTestServer(t, nil)
defer httpServer.Close()
userID := "consent-user"
connID := "mock"
tests := []struct {
name string
consents map[string][]string
clientID string
scopes []string
want bool
}{
{
name: "All scopes covered",
consents: map[string][]string{"client-a": {"email", "profile"}},
clientID: "client-a",
scopes: []string{"openid", "email", "profile"},
want: true,
},
{
name: "Missing scope",
consents: map[string][]string{"client-a": {"email"}},
clientID: "client-a",
scopes: []string{"openid", "email", "groups"},
want: false,
},
{
name: "Only technical scopes",
consents: map[string][]string{"client-a": {}},
clientID: "client-a",
scopes: []string{"openid", "offline_access"},
want: true,
},
{
name: "No consent for client",
consents: map[string][]string{"other": {"email"}},
clientID: "client-a",
scopes: []string{"email"},
want: false,
},
{
name: "No UserIdentity at all",
consents: nil,
clientID: "client-a",
scopes: []string{"email"},
want: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Clean up any previous identity
_ = s.storage.DeleteUserIdentity(ctx, userID, connID)
if tc.consents != nil {
require.NoError(t, s.storage.CreateUserIdentity(ctx, storage.UserIdentity{
UserID: userID,
ConnectorID: connID,
Claims: storage.Claims{UserID: userID},
Consents: tc.consents,
CreatedAt: time.Now(),
LastLogin: time.Now(),
}))
}
got := s.hasExistingConsent(ctx, userID, connID, tc.clientID, tc.scopes)
require.Equal(t, tc.want, got)
})
}
}
func TestHandlePasswordLoginWithSkipApproval(t *testing.T) {
ctx := t.Context()

5
server/server.go

@ -45,6 +45,7 @@ import (
"github.com/dexidp/dex/connector/oidc"
"github.com/dexidp/dex/connector/openshift"
"github.com/dexidp/dex/connector/saml"
"github.com/dexidp/dex/pkg/featureflags"
"github.com/dexidp/dex/server/signer"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/web"
@ -377,6 +378,10 @@ func newServer(ctx context.Context, c Config) (*Server, error) {
return nil, fmt.Errorf("server: failed to open all connectors (%d/%d)", failedCount, len(storageConnectors))
}
if featureflags.SessionsEnabled.Enabled() {
s.logger.InfoContext(ctx, "sessions feature flag is enabled")
}
instrumentHandler := func(_ string, handler http.Handler) http.HandlerFunc {
return handler.ServeHTTP
}

Loading…
Cancel
Save