From 0ac4d403063026297429b563e424c22636127b59 Mon Sep 17 00:00:00 2001 From: Julius Foitzik Date: Fri, 17 Oct 2025 16:54:16 +0200 Subject: [PATCH 1/9] remember me revert snapshot fix lint error snaphost snapshot Signed-off-by: Julius Foitzik --- FEATURE-NOTES.md | 27 ++ cmd/dex/config.go | 11 + cmd/dex/serve.go | 4 + internal/jwt/signature.go | 59 ++++ remember-me/handler.go | 173 +++++++++++ remember-me/handler_test.go | 519 ++++++++++++++++++++++++++++++++ server/handlers.go | 66 ++++ server/oauth2.go | 12 +- server/server.go | 21 +- storage/conformance/gen_jwks.go | 2 +- storage/memory/memory.go | 40 +++ storage/storage.go | 19 +- 12 files changed, 940 insertions(+), 13 deletions(-) create mode 100644 FEATURE-NOTES.md create mode 100644 internal/jwt/signature.go create mode 100644 remember-me/handler.go create mode 100644 remember-me/handler_test.go diff --git a/FEATURE-NOTES.md b/FEATURE-NOTES.md new file mode 100644 index 00000000..153af150 --- /dev/null +++ b/FEATURE-NOTES.md @@ -0,0 +1,27 @@ +Goal is to come up with a compact and minimal design for https://github.com/dexidp/dex/issues/32 + + +# Notes + + - Use cookies to identify a returning user + - Sign cookie using one of the internal private keys + - Verify cookie when present + - Generate and store cookie or cookie encrypted value on Store + - Requires extension of storage.Storage interface + - Feature Flag + - Only introduce code in code-path of Password-based login providers + - Write cookie to response on success login (see Ref(1)) + - Cookie ExpiresIn should be less or equal to the minted JWT ExpiresIn + - I think simple store the storage.Claims+identity.ConnectorData and inject them into finalizeLogin, if the user is already logged in + - An ActiveSession is only valid once it has an associated AccessToken + - The ActiveSession should expire after a configurable amount of time + + + +Ref(1): + +probably here, and only in the case of http.MethodPost. + +```go +func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { +``` diff --git a/cmd/dex/config.go b/cmd/dex/config.go index aa49a181..ee48951e 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -33,6 +33,9 @@ type Config struct { Expiry Expiry `json:"expiry"` Logger Logger `json:"logger"` + // Experimental Feature to remember users when they authenticated using a password-based connector + Sessions Sessions `json:"sessions"` + Frontend server.WebConfig `json:"frontend"` // StaticConnectors are user defined connectors specified in the ConfigMap @@ -51,6 +54,10 @@ type Config struct { // querying the storage. Cannot be specified without enabling a passwords // database. StaticPasswords []password `json:"staticPasswords"` + + // If enabled, the server will maintain a active sessions for password connectors + // to identify returning users to avoid re-entering credentials if the session is still valid. + ExperimentalEnableRememberMe bool `json:"experimentalEnableRememberMe"` } // Validate the configuration @@ -165,6 +172,10 @@ type Web struct { ClientRemoteIP ClientRemoteIP `json:"clientRemoteIP"` } +type Sessions struct { + Enable bool `json:"enable"` +} + type ClientRemoteIP struct { Header string `json:"header"` TrustedProxies []string `json:"trustedProxies"` diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index ac715e60..9a3dc481 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -307,6 +307,10 @@ func runServe(options serveOptions) error { PrometheusRegistry: prometheusRegistry, HealthChecker: healthChecker, ContinueOnConnectorFailure: featureflags.ContinueOnConnectorFailure.Enabled(), + EnableRememberMe: c.Sessions.Enable, + } + if c.Sessions.Enable { + logger.Info("remember me experimental feature enabled") } if c.Expiry.SigningKeys != "" { signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) diff --git a/internal/jwt/signature.go b/internal/jwt/signature.go new file mode 100644 index 00000000..ad6d2afd --- /dev/null +++ b/internal/jwt/signature.go @@ -0,0 +1,59 @@ +package jwt + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "errors" + "fmt" + + "github.com/go-jose/go-jose/v4" +) + +// Determine the signature algorithm for a JWT. +func SignatureAlgorithm(jwk *jose.JSONWebKey) (alg jose.SignatureAlgorithm, err error) { + if jwk.Key == nil { + return alg, errors.New("no signing key") + } + switch key := jwk.Key.(type) { + case *rsa.PrivateKey: + // Because OIDC mandates that we support RS256, we always return that + // value. In the future, we might want to make this configurable on a + // per client basis. For example allowing PS256 or ECDSA variants. + // + // See https://github.com/dexidp/dex/issues/692 + return jose.RS256, nil + case *ecdsa.PrivateKey: + // We don't actually support ECDSA keys yet, but they're tested for + // in case we want to in the future. + // + // These values are prescribed depending on the ECDSA key type. We + // can't return different values. + switch key.Params() { + case elliptic.P256().Params(): + return jose.ES256, nil + case elliptic.P384().Params(): + return jose.ES384, nil + case elliptic.P521().Params(): + return jose.ES512, nil + default: + return alg, errors.New("unsupported ecdsa curve") + } + default: + return alg, fmt.Errorf("unsupported signing key type %T", key) + } +} + +func SignPayload(key *jose.JSONWebKey, alg jose.SignatureAlgorithm, payload []byte) (jws string, err error) { + signingKey := jose.SigningKey{Key: key, Algorithm: alg} + + signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) + if err != nil { + return "", fmt.Errorf("new signer: %v", err) + } + signature, err := signer.Sign(payload) + if err != nil { + return "", fmt.Errorf("signing payload: %v", err) + } + return signature.CompactSerialize() +} diff --git a/remember-me/handler.go b/remember-me/handler.go new file mode 100644 index 00000000..ab790a9b --- /dev/null +++ b/remember-me/handler.go @@ -0,0 +1,173 @@ +package rememberme + +import ( + "context" + "crypto/sha3" + "errors" + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/internal/jwt" + "github.com/dexidp/dex/storage" +) + +const ACTIVE_SESSION_COOKIE_NAME = "dex_active_session_cookie" + +type AuthContext struct { + connectorName string + identity *connector.Identity + configuredExpiryDuration time.Duration +} + +func NewAnonymousAuthContext(connectorName string, configuredExpiryDuration time.Duration) AuthContext { + return AuthContext{connectorName, nil, configuredExpiryDuration} +} + +func NewAuthContextWithIdentity(connectorName string, identity connector.Identity, configuredExpiryDuration time.Duration) AuthContext { + return AuthContext{connectorName, &identity, configuredExpiryDuration} +} + +type GetOrUnsetCookie struct { + cookie *http.Cookie + unset bool +} + +func (c GetOrUnsetCookie) Empty() bool { + return c.unset == false && c.cookie == nil +} + +func (c GetOrUnsetCookie) Get() (*http.Cookie, bool) { + // TODO(juf): would prefer to not return internal pointer + return c.cookie, c.unset +} + +func RequestUnsetCookie(cookieName string) GetOrUnsetCookie { + return GetOrUnsetCookie{ + &http.Cookie{Name: cookieName, Path: "/", MaxAge: -1, Secure: true, HttpOnly: true, SameSite: http.SameSiteStrictMode}, true, + } +} + +func RequestSetCookie(cookie http.Cookie) GetOrUnsetCookie { + return GetOrUnsetCookie{ + &cookie, false, + } +} + +type RememberMeCtx struct { + Session storage.ActiveSession + Cookie GetOrUnsetCookie +} + +func (ctx RememberMeCtx) IsValid() bool { + return ctx.Session.Expiry.After(time.Now()) +} + +// connector_cookie_name creates a string which is used to identify the cookie that matches the given connector. +// The purpose is to avoid having one cookie for multiple providers where you only authenticate once and suddenly would have +// access to other connectors. +func connector_cookie_name(connName string) string { + return fmt.Sprintf("%s_%s", ACTIVE_SESSION_COOKIE_NAME, connName) +} + +func HandleRememberMe(ctx context.Context, logger *slog.Logger, req *http.Request, data AuthContext, store storage.Storage, sessionStore storage.ActiveSessionStorage) (*RememberMeCtx, error) { + keys, err := store.GetKeys(ctx) + if err != nil { + logger.ErrorContext(req.Context(), "failed to get keys", "err", err) + return nil, err + } + signAlg, err := jwt.SignatureAlgorithm(keys.SigningKey) + if err != nil { + logger.ErrorContext(req.Context(), "failed to get signAlg", "err", err) + return nil, err + } + if val, found := extractCookie(req, data.connectorName); found { + cookieName := connector_cookie_name(data.connectorName) + logger.DebugContext(req.Context(), "returning user cookie found, checking for active session", "connectorName", data.connectorName) + session, err := sessionStore.GetSession(ctx, val) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return &RememberMeCtx{ + Session: session, + Cookie: RequestUnsetCookie(cookieName), + }, nil + } + logger.ErrorContext(req.Context(), "failed to get active session", "err", err, "connectorName", data.connectorName) + return nil, err + } + cookie := GetOrUnsetCookie{nil, false} + if session.Expiry.Before(time.Now()) { + logger.DebugContext(req.Context(), "session expired unsetting cookie", "connectorName", data.connectorName) + cookie = RequestUnsetCookie(cookieName) + } + return &RememberMeCtx{ + Session: session, + Cookie: cookie, + }, nil + } else { + if data.identity == nil { + logger.DebugContext(req.Context(), "identity is empty, returning early", "connectorName", data.connectorName) + return nil, storage.ErrNotFound + } + h := sha3.New512() + h.Write([]byte(data.identity.Email)) + for _, g := range data.identity.Groups { + h.Write([]byte(g)) + } + h.Write([]byte(data.identity.UserID)) + h.Write([]byte(data.identity.Username)) + h.Write([]byte(data.identity.PreferredUsername)) + hash := fmt.Sprintf("%x", h.Sum(nil)) + signedHash, err := jwt.SignPayload(keys.SigningKey, signAlg, []byte(hash)) + if err != nil { + logger.ErrorContext(req.Context(), "failed to get sign payload", "err", err, "connectorName", data.connectorName) + return nil, err + } + // TODO(juf): Double check what we need to persist and are given + // in the context of whether we need to make an "auto-redirect" + // Because technically we do not return the ID nor RefreshToken to the user + // instead we redirect him back to the caller with an authCode + session := storage.ActiveSession{ + Identity: *data.identity, // TODO(juf): Avoid nil pointer + // TODO(juf): Think about changing to use Token IssuedAt date instead of now to have + // alignment with the token + Expiry: time.Now().Add(data.configuredExpiryDuration), + } + logger.DebugContext(req.Context(), "creating active session for user", "connectorName", data.connectorName) + if err := sessionStore.CreateSession(ctx, signedHash, session); err != nil { + logger.ErrorContext(req.Context(), "failed to store active session", "err", err, "connectorName", data.connectorName) + return nil, err + } + + // TODO(juf): SET COOKIE + return &RememberMeCtx{ + Session: session, + Cookie: RequestSetCookie(http.Cookie{ + Name: connector_cookie_name(data.connectorName), + Value: signedHash, + Path: "/", + Domain: "", // TODO(juf): Check if we need to set this + Expires: session.Expiry, + MaxAge: int(time.Until(session.Expiry).Seconds()), + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + }), + }, nil + } +} + +func extractCookie(req *http.Request, connName string) (value string, found bool) { + cookies := req.Cookies() + if len(cookies) > 0 { + for _, ck := range cookies { + if ck.Name != connector_cookie_name(connName) { + continue + } + return ck.Value, true + } + } + return "", false +} diff --git a/remember-me/handler_test.go b/remember-me/handler_test.go new file mode 100644 index 00000000..88398609 --- /dev/null +++ b/remember-me/handler_test.go @@ -0,0 +1,519 @@ +package rememberme + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha3" + "encoding/base64" + "errors" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/internal/jwt" + "github.com/dexidp/dex/storage" + "github.com/go-jose/go-jose/v4" +) + +var _ storage.Storage = (*mockStorage)(nil) +var _ storage.ActiveSessionStorage = (*mockStorage)(nil) + +// mockStorage implements storage.Storage for testing key retrieval. +type mockStorage struct { + keys storage.Keys + err error +} + +// CreateSession implements storage.ActiveSessionStorage. +func (m *mockStorage) CreateSession(ctx context.Context, identifier string, data storage.ActiveSession) error { + panic("unimplemented") +} + +// GetSession implements storage.ActiveSessionStorage. +func (m *mockStorage) GetSession(ctx context.Context, identifier string) (storage.ActiveSession, error) { + panic("unimplemented") +} + +// CreateAuthCode implements storage.Storage. +func (m *mockStorage) CreateAuthCode(ctx context.Context, c storage.AuthCode) error { + panic("unimplemented") +} + +// CreateAuthRequest implements storage.Storage. +func (m *mockStorage) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) error { + panic("unimplemented") +} + +// CreateClient implements storage.Storage. +func (m *mockStorage) CreateClient(ctx context.Context, c storage.Client) error { + panic("unimplemented") +} + +// CreateConnector implements storage.Storage. +func (m *mockStorage) CreateConnector(ctx context.Context, c storage.Connector) error { + panic("unimplemented") +} + +// CreateDeviceRequest implements storage.Storage. +func (m *mockStorage) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) error { + panic("unimplemented") +} + +// CreateDeviceToken implements storage.Storage. +func (m *mockStorage) CreateDeviceToken(ctx context.Context, d storage.DeviceToken) error { + panic("unimplemented") +} + +// CreateOfflineSessions implements storage.Storage. +func (m *mockStorage) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessions) error { + panic("unimplemented") +} + +// CreatePassword implements storage.Storage. +func (m *mockStorage) CreatePassword(ctx context.Context, p storage.Password) error { + panic("unimplemented") +} + +// CreateRefresh implements storage.Storage. +func (m *mockStorage) CreateRefresh(ctx context.Context, r storage.RefreshToken) error { + panic("unimplemented") +} + +// DeleteAuthCode implements storage.Storage. +func (m *mockStorage) DeleteAuthCode(ctx context.Context, code string) error { + panic("unimplemented") +} + +// DeleteAuthRequest implements storage.Storage. +func (m *mockStorage) DeleteAuthRequest(ctx context.Context, id string) error { + panic("unimplemented") +} + +// DeleteClient implements storage.Storage. +func (m *mockStorage) DeleteClient(ctx context.Context, id string) error { + panic("unimplemented") +} + +// DeleteConnector implements storage.Storage. +func (m *mockStorage) DeleteConnector(ctx context.Context, id string) error { + panic("unimplemented") +} + +// DeleteOfflineSessions implements storage.Storage. +func (m *mockStorage) DeleteOfflineSessions(ctx context.Context, userID string, connID string) error { + panic("unimplemented") +} + +// DeletePassword implements storage.Storage. +func (m *mockStorage) DeletePassword(ctx context.Context, email string) error { + panic("unimplemented") +} + +// DeleteRefresh implements storage.Storage. +func (m *mockStorage) DeleteRefresh(ctx context.Context, id string) error { + panic("unimplemented") +} + +// GarbageCollect implements storage.Storage. +func (m *mockStorage) GarbageCollect(ctx context.Context, now time.Time) (storage.GCResult, error) { + panic("unimplemented") +} + +// GetAuthCode implements storage.Storage. +func (m *mockStorage) GetAuthCode(ctx context.Context, id string) (storage.AuthCode, error) { + panic("unimplemented") +} + +// GetAuthRequest implements storage.Storage. +func (m *mockStorage) GetAuthRequest(ctx context.Context, id string) (storage.AuthRequest, error) { + panic("unimplemented") +} + +// GetClient implements storage.Storage. +func (m *mockStorage) GetClient(ctx context.Context, id string) (storage.Client, error) { + panic("unimplemented") +} + +// GetConnector implements storage.Storage. +func (m *mockStorage) GetConnector(ctx context.Context, id string) (storage.Connector, error) { + panic("unimplemented") +} + +// GetDeviceRequest implements storage.Storage. +func (m *mockStorage) GetDeviceRequest(ctx context.Context, userCode string) (storage.DeviceRequest, error) { + panic("unimplemented") +} + +// GetDeviceToken implements storage.Storage. +func (m *mockStorage) GetDeviceToken(ctx context.Context, deviceCode string) (storage.DeviceToken, error) { + panic("unimplemented") +} + +// GetOfflineSessions implements storage.Storage. +func (m *mockStorage) GetOfflineSessions(ctx context.Context, userID string, connID string) (storage.OfflineSessions, error) { + panic("unimplemented") +} + +// GetPassword implements storage.Storage. +func (m *mockStorage) GetPassword(ctx context.Context, email string) (storage.Password, error) { + panic("unimplemented") +} + +// GetRefresh implements storage.Storage. +func (m *mockStorage) GetRefresh(ctx context.Context, id string) (storage.RefreshToken, error) { + panic("unimplemented") +} + +// ListClients implements storage.Storage. +func (m *mockStorage) ListClients(ctx context.Context) ([]storage.Client, error) { + panic("unimplemented") +} + +// ListConnectors implements storage.Storage. +func (m *mockStorage) ListConnectors(ctx context.Context) ([]storage.Connector, error) { + panic("unimplemented") +} + +// ListPasswords implements storage.Storage. +func (m *mockStorage) ListPasswords(ctx context.Context) ([]storage.Password, error) { + panic("unimplemented") +} + +// ListRefreshTokens implements storage.Storage. +func (m *mockStorage) ListRefreshTokens(ctx context.Context) ([]storage.RefreshToken, error) { + panic("unimplemented") +} + +// UpdateAuthRequest implements storage.Storage. +func (m *mockStorage) UpdateAuthRequest(ctx context.Context, id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { + panic("unimplemented") +} + +// UpdateClient implements storage.Storage. +func (m *mockStorage) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error { + panic("unimplemented") +} + +// UpdateConnector implements storage.Storage. +func (m *mockStorage) UpdateConnector(ctx context.Context, id string, updater func(c storage.Connector) (storage.Connector, error)) error { + panic("unimplemented") +} + +// UpdateDeviceToken implements storage.Storage. +func (m *mockStorage) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(t storage.DeviceToken) (storage.DeviceToken, error)) error { + panic("unimplemented") +} + +// UpdateKeys implements storage.Storage. +func (m *mockStorage) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) error { + panic("unimplemented") +} + +// UpdateOfflineSessions implements storage.Storage. +func (m *mockStorage) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { + panic("unimplemented") +} + +// UpdatePassword implements storage.Storage. +func (m *mockStorage) UpdatePassword(ctx context.Context, email string, updater func(p storage.Password) (storage.Password, error)) error { + panic("unimplemented") +} + +// UpdateRefreshToken implements storage.Storage. +func (m *mockStorage) UpdateRefreshToken(ctx context.Context, id string, updater func(r storage.RefreshToken) (storage.RefreshToken, error)) error { + panic("unimplemented") +} + +func (m *mockStorage) GetKeys(ctx context.Context) (storage.Keys, error) { + return m.keys, m.err +} + +func (m *mockStorage) Close() error { return nil } + +// mockSessionStorage implements storage.ActiveSessionStorage for testing. +type mockSessionStorage struct { + sessions map[string]storage.ActiveSession + getErr error + createErr error + gcErr error +} + +// GarbageCollect implements storage.ActiveSessionStorage. +func (m *mockSessionStorage) GarbageCollect(ctx context.Context, now time.Time) (storage.GCResult, error) { + panic("unimplemented") +} + +func (m *mockSessionStorage) GetSession(ctx context.Context, id string) (storage.ActiveSession, error) { + if m.getErr != nil { + return storage.ActiveSession{}, m.getErr + } + session, ok := m.sessions[id] + if !ok { + return storage.ActiveSession{}, storage.ErrNotFound + } + return session, nil +} + +func (m *mockSessionStorage) CreateSession(ctx context.Context, id string, session storage.ActiveSession) error { + if m.createErr != nil { + return m.createErr + } + m.sessions[id] = session + return nil +} + +// noOpLogger is a silent logger for tests (level higher than Error to suppress output). +var noOpLogger = slog.New(slog.NewTextHandler(nil, &slog.HandlerOptions{Level: slog.Level(slog.LevelError) + 1})) + +// fixedNow returns a fixed time for deterministic testing. +func fixedNow() time.Time { + return time.Date(2025, 10, 17, 18, 0, 0, 0, time.UTC) +} + +// generateTestKey creates a test ECDSA signing key. +func generateTestKey() *ecdsa.PrivateKey { + key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + return key +} + +// generateSignedHash creates a signed hash for a given identity. +func generateSignedHash(identity connector.Identity, signingKey *ecdsa.PrivateKey) (string, error) { + h := sha3.New512() + h.Write([]byte(identity.Email)) + for _, g := range identity.Groups { + h.Write([]byte(g)) + } + h.Write([]byte(identity.UserID)) + h.Write([]byte(identity.Username)) + h.Write([]byte(identity.PreferredUsername)) + hash := fmt.Sprintf("%x", h.Sum(nil)) + + signAlg, _ := jwt.SignatureAlgorithm(&jose.JSONWebKey{Key: signingKey}) + signedBytes, err := jwt.SignPayload(&jose.JSONWebKey{Key: signingKey}, signAlg, []byte(hash)) + if err != nil { + return "", err + } + // Explicitly encode as []byte. + return base64.RawURLEncoding.EncodeToString([]byte(signedBytes)), nil +} + +func TestHandleRememberMe(t *testing.T) { + // Common test fixtures. + ctx := context.Background() + connectorName := "test-connector" + expiryDuration := 24 * time.Hour + testKey := generateTestKey() + mockStore := &mockStorage{ + keys: storage.Keys{ + SigningKey: &jose.JSONWebKey{Key: testKey}, + }, + } + mockSessionStore := &mockSessionStorage{ + sessions: make(map[string]storage.ActiveSession), + } + + // Helper to create a request with optional cookie. + newRequest := func(cookieValue string) *http.Request { + req := httptest.NewRequest("GET", "/auth", nil) + if cookieValue != "" { + cookieName := connector_cookie_name(connectorName) + req.AddCookie(&http.Cookie{Name: cookieName, Value: cookieValue}) + } + return req + } + + // Sample identity. + identity := connector.Identity{ + UserID: "user123", + Username: "testuser", + PreferredUsername: "testuser", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"group1"}, + } + + t.Run("No cookie, anonymous context", func(t *testing.T) { + req := newRequest("") // No cookie. + data := NewAnonymousAuthContext(connectorName, expiryDuration) + ctx, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + if !errors.Is(err, storage.ErrNotFound) { + t.Errorf("Expected ErrNotFound, got: %v", err) + } + if ctx != nil { + t.Error("Expected nil context") + } + }) + + t.Run("No cookie, with identity: Create new session and set cookie", func(t *testing.T) { + req := newRequest("") // No cookie. + data := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) + rmCtx, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !rmCtx.IsValid() { + t.Error("Expected valid session") + } + if rmCtx.Session.Identity.UserID != identity.UserID { + t.Errorf("Expected UserID %s, got %s", identity.UserID, rmCtx.Session.Identity.UserID) + } + if rmCtx.Cookie.Empty() || rmCtx.Cookie.unset { + t.Error("Expected cookie to be set") + } + cookie, _ := rmCtx.Cookie.Get() + if cookie.Name != connector_cookie_name(connectorName) { + t.Errorf("Unexpected cookie name: %s", cookie.Name) + } + if cookie.Value == "" { + t.Error("Expected non-empty cookie value") + } + if cookie.Expires.Before(fixedNow().Add(expiryDuration - time.Minute)) { // Allow some leeway. + t.Errorf("Unexpected expiry: %v", cookie.Expires) + } + // Verify session was stored. + storedSession, err := mockSessionStore.GetSession(ctx, cookie.Value) + if err != nil || storedSession.Identity.UserID != identity.UserID { + t.Error("Session not stored correctly") + } + }) + + t.Run("Valid cookie with active session", func(t *testing.T) { + // Pre-populate a session. + signedHash, err := generateSignedHash(identity, testKey) + if err != nil { + t.Fatal(err) + } + session := storage.ActiveSession{ + Identity: identity, + Expiry: fixedNow().Add(expiryDuration), + } + mockSessionStore.sessions[signedHash] = session + + req := newRequest(signedHash) + data := NewAnonymousAuthContext(connectorName, expiryDuration) + rmCtx, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !rmCtx.IsValid() { + t.Error("Expected valid session") + } + if rmCtx.Session.Identity.UserID != identity.UserID { + t.Errorf("Expected UserID %s, got %s", identity.UserID, rmCtx.Session.Identity.UserID) + } + if !rmCtx.Cookie.Empty() { + t.Error("Expected no cookie change (empty GetOrUnsetCookie)") + } + }) + + t.Run("Cookie present, expired session: Unset cookie", func(t *testing.T) { + // Pre-populate an expired session. + signedHash, err := generateSignedHash(identity, testKey) + if err != nil { + t.Fatal(err) + } + session := storage.ActiveSession{ + Identity: identity, + Expiry: fixedNow().Add(-time.Hour), // Expired. + } + mockSessionStore.sessions[signedHash] = session + + req := newRequest(signedHash) + data := NewAnonymousAuthContext(connectorName, expiryDuration) + rmCtx, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if rmCtx.IsValid() { + t.Error("Expected invalid (expired) session") + } + if rmCtx.Cookie.Empty() || !rmCtx.Cookie.unset { + t.Error("Expected cookie to be unset") + } + cookie, unset := rmCtx.Cookie.Get() + if !unset || cookie.Name != connector_cookie_name(connectorName) || cookie.MaxAge != -1 { + t.Error("Unexpected unset cookie") + } + }) + + t.Run("Cookie present, session not found: Unset cookie", func(t *testing.T) { + signedHash := "invalid-hash" + req := newRequest(signedHash) + data := NewAnonymousAuthContext(connectorName, expiryDuration) + rmCtx, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if rmCtx.IsValid() { + t.Error("Expected invalid session") + } + if rmCtx.Cookie.Empty() || !rmCtx.Cookie.unset { + t.Error("Expected cookie to be unset") + } + }) + + t.Run("Error: Failed to get keys", func(t *testing.T) { + mockStore.err = errors.New("key error") + req := newRequest("") + data := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) + _, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + if err == nil || !errors.Is(err, mockStore.err) { + t.Errorf("Expected key error, got: %v", err) + } + mockStore.err = nil // Reset. + }) + + t.Run("Error: Failed to create session", func(t *testing.T) { + mockSessionStore.createErr = errors.New("create error") + req := newRequest("") + data := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) + _, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + if err == nil || !errors.Is(err, mockSessionStore.createErr) { + t.Errorf("Expected create error, got: %v", err) + } + mockSessionStore.createErr = nil // Reset. + }) + + t.Run("Error: Failed to get session", func(t *testing.T) { + mockSessionStore.getErr = errors.New("get error") + signedHash, _ := generateSignedHash(identity, testKey) + req := newRequest(signedHash) + data := NewAnonymousAuthContext(connectorName, expiryDuration) + _, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + if err == nil || !errors.Is(err, mockSessionStore.getErr) { + t.Errorf("Expected get error, got: %v", err) + } + mockSessionStore.getErr = nil // Reset. + }) +} + +func TestExtractCookie(t *testing.T) { + connectorName := "test-connector" + req := httptest.NewRequest("GET", "/auth", nil) + req.AddCookie(&http.Cookie{Name: connector_cookie_name(connectorName), Value: "test-value"}) + req.AddCookie(&http.Cookie{Name: "other-cookie", Value: "ignored"}) + + value, found := extractCookie(req, connectorName) + if !found || value != "test-value" { + t.Errorf("Expected value 'test-value', got: %s (found: %v)", value, found) + } + + // No cookie. + req = httptest.NewRequest("GET", "/auth", nil) + value, found = extractCookie(req, connectorName) + if found || value != "" { + t.Error("Expected not found") + } +} + +func TestConnectorCookieName(t *testing.T) { + if name := connector_cookie_name("test"); name != "dex_active_session_cookie_test" { + t.Errorf("Unexpected name: %s", name) + } +} diff --git a/server/handlers.go b/server/handlers.go index f8d0ed64..c227251f 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -7,6 +7,7 @@ import ( "crypto/subtle" "encoding/base64" "encoding/json" + "errors" "fmt" "html/template" "net/http" @@ -22,6 +23,7 @@ import ( "github.com/gorilla/mux" "github.com/dexidp/dex/connector" + rememberme "github.com/dexidp/dex/remember-me" "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" ) @@ -321,6 +323,11 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { s.renderError(r, w, http.StatusBadRequest, "User session error.") return } + if s.enableRememberMe { + s.logger.InfoContext(r.Context(), "enableRememberMe enabled") + } else { + s.logger.InfoContext(r.Context(), "enableRememberMe disabled") + } backLink := r.URL.Query().Get("back") @@ -363,6 +370,48 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: + if s.enableRememberMe { + rememberData, err := rememberme.HandleRememberMe(ctx, s.logger, r, rememberme.NewAnonymousAuthContext(authReq.ConnectorID, s.idTokensValidFor), s.storage, s.sessionStorage) + if err != nil { + if !errors.Is(err, storage.ErrNotFound) { + s.logger.ErrorContext(r.Context(), "failed to call HandleRememberMe handler", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "TODO") + return + } + } + if !errors.Is(err, storage.ErrNotFound) { + s.logger.DebugContext(r.Context(), "returning user session was found") + if !rememberData.Cookie.Empty() { + cookie, unset := rememberData.Cookie.Get() + if unset { + s.logger.DebugContext(r.Context(), "unsetting cookie", "cookie", *cookie) + http.SetCookie(w, cookie) + } + } + if rememberData.IsValid() { + redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, rememberData.Session.Identity, authReq, conn) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to finalize login using rememberme", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Login error.") + return + } + + if canSkipApproval { + authReq, err = s.storage.GetAuthRequest(ctx, authReq.ID) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to get finalized auth request using rememberme", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Login error.") + return + } + s.sendCodeResponse(w, r, authReq) + return + } + + http.Redirect(w, r, redirectURL, http.StatusSeeOther) + return + } + } + } if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(pwConn), false, backLink); err != nil { s.logger.ErrorContext(r.Context(), "server template error", "err", err) } @@ -390,6 +439,23 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { s.renderError(r, w, http.StatusInternalServerError, "Login error.") return } + if s.enableRememberMe { + rememberData, err := rememberme.HandleRememberMe(ctx, s.logger, r, rememberme.NewAuthContextWithIdentity(authReq.ConnectorID, identity, s.idTokensValidFor), s.storage, s.sessionStorage) + if err != nil { + if !errors.Is(err, storage.ErrNotFound) { + s.logger.ErrorContext(r.Context(), "failed to call HandleRememberMe handler", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "TODO") + return + } + } + if !errors.Is(err, storage.ErrNotFound) { + s.logger.ErrorContext(r.Context(), "did find returning user") + if !rememberData.Cookie.Empty() { + cookie, _ := rememberData.Cookie.Get() + http.SetCookie(w, cookie) // this sets or unsets the cookie based on it's content + } + } + } if canSkipApproval { authReq, err = s.storage.GetAuthRequest(ctx, authReq.ID) diff --git a/server/oauth2.go b/server/oauth2.go index 18cc3dd4..13fb5fc4 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -17,6 +17,7 @@ import ( "net" "net/http" "net/url" + "slices" "strconv" "strings" "time" @@ -259,15 +260,6 @@ func accessTokenHash(alg jose.SignatureAlgorithm, accessToken string) (string, e type audience []string -func (a audience) contains(aud string) bool { - for _, e := range a { - if aud == e { - return true - } - } - return false -} - func (a audience) MarshalJSON() ([]byte, error) { if len(a) == 1 { return json.Marshal(a[0]) @@ -333,7 +325,7 @@ func getAudience(clientID string, scopes []string) audience { aud = audience{clientID} // Client asked for cross client audience: // if the current client was not requested explicitly - } else if !aud.contains(clientID) { + } else if !slices.Contains(aud, clientID) { // by default it becomes one of entries in Audience aud = append(aud, clientID) } diff --git a/server/server.go b/server/server.go index 70e8ae75..35513542 100644 --- a/server/server.go +++ b/server/server.go @@ -46,6 +46,7 @@ import ( "github.com/dexidp/dex/connector/openshift" "github.com/dexidp/dex/connector/saml" "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/memory" "github.com/dexidp/dex/web" ) @@ -65,6 +66,9 @@ type Connector struct { type Config struct { Issuer string + // Use cookies and keep active sessions in storage + EnableRememberMe bool + // The backing persistence layer. Storage storage.Storage @@ -172,7 +176,8 @@ type Server struct { // Map of connector IDs to connectors. connectors map[string]Connector - storage storage.Storage + storage storage.Storage + sessionStorage storage.ActiveSessionStorage mux http.Handler @@ -181,6 +186,8 @@ type Server struct { // If enabled, don't prompt user for approval after logging in through connector. skipApproval bool + enableRememberMe bool + // If enabled, show the connector selection screen even if there's only one alwaysShowLogin bool @@ -316,6 +323,10 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) templates: tmpls, passwordConnector: c.PasswordConnector, logger: c.Logger, + enableRememberMe: c.EnableRememberMe, + } + if c.EnableRememberMe { + s.sessionStorage = memory.NewSessionStore(s.logger) } // Retrieves connector objects in backend storage. This list includes the static connectors @@ -647,6 +658,14 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura "requests", r.AuthRequests, "auth_codes", r.AuthCodes, "device_requests", r.DeviceRequests, "device_tokens", r.DeviceTokens) } + if s.sessionStorage != nil { + if r, err := s.sessionStorage.GarbageCollect(ctx, now()); err != nil { + s.logger.ErrorContext(ctx, "garbage collection for session storage failed", "err", err) + } else if !r.IsEmpty() { + s.logger.InfoContext(ctx, "garbage collection for session storage run", + "sessions", r.Sessions) + } + } } } }() diff --git a/storage/conformance/gen_jwks.go b/storage/conformance/gen_jwks.go index 0029b9b8..b5affcc8 100644 --- a/storage/conformance/gen_jwks.go +++ b/storage/conformance/gen_jwks.go @@ -105,7 +105,7 @@ func main() { if err != nil { log.Fatalf("gofmt failed: %v", err) } - if err := os.WriteFile("jwks.go", out, 0644); err != nil { + if err := os.WriteFile("jwks.go", out, 0o644); err != nil { log.Fatal(err) } } diff --git a/storage/memory/memory.go b/storage/memory/memory.go index eff75e71..0461d6dd 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -12,6 +12,7 @@ import ( ) var _ storage.Storage = (*memStorage)(nil) +var _ storage.ActiveSessionStorage = (*memStorage)(nil) // New returns an in memory storage. func New(logger *slog.Logger) storage.Storage { @@ -25,10 +26,18 @@ func New(logger *slog.Logger) storage.Storage { connectors: make(map[string]storage.Connector), deviceRequests: make(map[string]storage.DeviceRequest), deviceTokens: make(map[string]storage.DeviceToken), + sessions: make(map[string]storage.ActiveSession), logger: logger, } } +func NewSessionStore(logger *slog.Logger) *memStorage { + return &memStorage{ + sessions: make(map[string]storage.ActiveSession), + logger: logger, + } +} + // Config is an implementation of a storage configuration. // // TODO(ericchiang): Actually define a storage config interface and have registration. @@ -52,12 +61,37 @@ type memStorage struct { connectors map[string]storage.Connector deviceRequests map[string]storage.DeviceRequest deviceTokens map[string]storage.DeviceToken + sessions map[string]storage.ActiveSession keys storage.Keys logger *slog.Logger } +// CreateSession implements storage.ActiveSessionStorage. +func (s *memStorage) CreateSession(ctx context.Context, identifier string, data storage.ActiveSession) (err error) { + s.tx(func() { + if _, ok := s.sessions[identifier]; ok { + err = storage.ErrAlreadyExists + } else { + s.sessions[identifier] = data + } + }) + return +} + +// GetSession implements storage.ActiveSessionStorage. +func (s *memStorage) GetSession(ctx context.Context, identifier string) (session storage.ActiveSession, err error) { + s.tx(func() { + var ok bool + if session, ok = s.sessions[identifier]; !ok { + err = storage.ErrNotFound + return + } + }) + return +} + type offlineSessionID struct { userID string connID string @@ -97,6 +131,12 @@ func (s *memStorage) GarbageCollect(ctx context.Context, now time.Time) (result result.DeviceTokens++ } } + for id, a := range s.sessions { + if now.After(a.Expiry) { + delete(s.sessions, id) + result.Sessions++ + } + } }) return result, nil } diff --git a/storage/storage.go b/storage/storage.go index 574b0a5a..9f0b5f63 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -12,6 +12,8 @@ import ( "time" "github.com/go-jose/go-jose/v4" + + "github.com/dexidp/dex/connector" ) var ( @@ -60,6 +62,7 @@ type GCResult struct { AuthCodes int64 DeviceRequests int64 DeviceTokens int64 + Sessions int64 } // IsEmpty returns whether the garbage collection result is empty or not. @@ -67,7 +70,8 @@ func (g *GCResult) IsEmpty() bool { return g.AuthRequests == 0 && g.AuthCodes == 0 && g.DeviceRequests == 0 && - g.DeviceTokens == 0 + g.DeviceTokens == 0 && + g.Sessions == 0 } // Storage is the storage interface used by the server. Implementations are @@ -142,6 +146,19 @@ type Storage interface { GarbageCollect(ctx context.Context, now time.Time) (GCResult, error) } +type ActiveSessionStorage interface { + GetSession(ctx context.Context, identifier string) (ActiveSession, error) + CreateSession(ctx context.Context, identifier string, data ActiveSession) error + GarbageCollect(ctx context.Context, now time.Time) (GCResult, error) +} + +type ActiveSession struct { + // TODO(juf): Think about storing only claim/identity data or reference to OfflineSession + // and create a new token every time instead + Expiry time.Time + Identity connector.Identity +} + // Client represents an OAuth2 client. // // For further reading see: From 1885d50e809a3d214ab70fe5e92ff489eadeb1ea Mon Sep 17 00:00:00 2001 From: Julius Foitzik Date: Sat, 18 Oct 2025 11:17:19 +0200 Subject: [PATCH 2/9] snapshot Signed-off-by: Julius Foitzik --- internal/jwt/keyset.go | 56 ++++++++++++++++++++++++++++++++++ remember-me/handler.go | 16 +++++++++- remember-me/handler_test.go | 29 +++++++++--------- server/handlers.go | 9 +++++- server/introspectionhandler.go | 3 +- server/oauth2.go | 38 ----------------------- server/oauth2_test.go | 3 +- storage/memory/memory.go | 6 ++-- 8 files changed, 102 insertions(+), 58 deletions(-) create mode 100644 internal/jwt/keyset.go diff --git a/internal/jwt/keyset.go b/internal/jwt/keyset.go new file mode 100644 index 00000000..0d6ddcaa --- /dev/null +++ b/internal/jwt/keyset.go @@ -0,0 +1,56 @@ +package jwt + +import ( + "context" + "errors" + + "github.com/go-jose/go-jose/v4" + + "github.com/dexidp/dex/storage" +) + +var ErrFailedVerify = errors.New("failed to verify id token signature") + +// StorageKeySet implements the oidc.KeySet interface backed by Dex storage +type StorageKeySet struct { + storage.Storage +} + +func NewStorageKeySet(store storage.Storage) *StorageKeySet { + return &StorageKeySet{ + store, + } +} + +func (s *StorageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { + jws, err := jose.ParseSigned(jwt, []jose.SignatureAlgorithm{jose.RS256, jose.RS384, jose.RS512, jose.ES256, jose.ES384, jose.ES512}) + if err != nil { + return nil, err + } + + keyID := "" + for _, sig := range jws.Signatures { + keyID = sig.Header.KeyID + break + } + + skeys, err := s.Storage.GetKeys(ctx) + if err != nil { + return nil, err + } + + keys := []*jose.JSONWebKey{skeys.SigningKeyPub} + for _, vk := range skeys.VerificationKeys { + keys = append(keys, vk.PublicKey) + } + + for _, key := range keys { + if keyID == "" || key.KeyID == keyID { + if payload, err := jws.Verify(key); err == nil { + return payload, nil + } + } + } + + return nil, ErrFailedVerify +} diff --git a/remember-me/handler.go b/remember-me/handler.go index ab790a9b..f2247a4c 100644 --- a/remember-me/handler.go +++ b/remember-me/handler.go @@ -16,6 +16,8 @@ import ( const ACTIVE_SESSION_COOKIE_NAME = "dex_active_session_cookie" +var emptySession = storage.ActiveSession{} + type AuthContext struct { connectorName string identity *connector.Identity @@ -72,6 +74,10 @@ func connector_cookie_name(connName string) string { return fmt.Sprintf("%s_%s", ACTIVE_SESSION_COOKIE_NAME, connName) } +// HandleRememberMe either retrieves or creates a Session based on the cookie for the respective connector present in the http.Request. +// It is also responsible for issuing the unsetting / expiration of either an invalid or expired cookie. +// +// The current "design" of the cookie is a sha3 hash of the connector.Identity object as JWK signed payload. func HandleRememberMe(ctx context.Context, logger *slog.Logger, req *http.Request, data AuthContext, store storage.Storage, sessionStore storage.ActiveSessionStorage) (*RememberMeCtx, error) { keys, err := store.GetKeys(ctx) if err != nil { @@ -86,6 +92,15 @@ func HandleRememberMe(ctx context.Context, logger *slog.Logger, req *http.Reques if val, found := extractCookie(req, data.connectorName); found { cookieName := connector_cookie_name(data.connectorName) logger.DebugContext(req.Context(), "returning user cookie found, checking for active session", "connectorName", data.connectorName) + keyset := jwt.NewStorageKeySet(store) + logger.DebugContext(req.Context(), "verifying cookie", "connectorName", data.connectorName) + _, err := keyset.VerifySignature(ctx, val) + if err != nil { + return &RememberMeCtx{ + Session: emptySession, + Cookie: RequestUnsetCookie(cookieName), + }, err + } session, err := sessionStore.GetSession(ctx, val) if err != nil { if errors.Is(err, storage.ErrNotFound) { @@ -141,7 +156,6 @@ func HandleRememberMe(ctx context.Context, logger *slog.Logger, req *http.Reques return nil, err } - // TODO(juf): SET COOKIE return &RememberMeCtx{ Session: session, Cookie: RequestSetCookie(http.Cookie{ diff --git a/remember-me/handler_test.go b/remember-me/handler_test.go index 88398609..c63586a8 100644 --- a/remember-me/handler_test.go +++ b/remember-me/handler_test.go @@ -15,14 +15,17 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v4" + "github.com/dexidp/dex/connector" "github.com/dexidp/dex/internal/jwt" "github.com/dexidp/dex/storage" - "github.com/go-jose/go-jose/v4" ) -var _ storage.Storage = (*mockStorage)(nil) -var _ storage.ActiveSessionStorage = (*mockStorage)(nil) +var ( + _ storage.Storage = (*mockStorage)(nil) + _ storage.ActiveSessionStorage = (*mockStorage)(nil) +) // mockStorage implements storage.Storage for testing key retrieval. type mockStorage struct { @@ -241,7 +244,6 @@ type mockSessionStorage struct { sessions map[string]storage.ActiveSession getErr error createErr error - gcErr error } // GarbageCollect implements storage.ActiveSessionStorage. @@ -268,8 +270,7 @@ func (m *mockSessionStorage) CreateSession(ctx context.Context, id string, sessi return nil } -// noOpLogger is a silent logger for tests (level higher than Error to suppress output). -var noOpLogger = slog.New(slog.NewTextHandler(nil, &slog.HandlerOptions{Level: slog.Level(slog.LevelError) + 1})) +var testLogger = slog.New(slog.NewTextHandler(nil, &slog.HandlerOptions{Level: slog.LevelError})) // fixedNow returns a fixed time for deterministic testing. func fixedNow() time.Time { @@ -341,7 +342,7 @@ func TestHandleRememberMe(t *testing.T) { t.Run("No cookie, anonymous context", func(t *testing.T) { req := newRequest("") // No cookie. data := NewAnonymousAuthContext(connectorName, expiryDuration) - ctx, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + ctx, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) if !errors.Is(err, storage.ErrNotFound) { t.Errorf("Expected ErrNotFound, got: %v", err) } @@ -353,7 +354,7 @@ func TestHandleRememberMe(t *testing.T) { t.Run("No cookie, with identity: Create new session and set cookie", func(t *testing.T) { req := newRequest("") // No cookie. data := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) - rmCtx, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + rmCtx, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -397,7 +398,7 @@ func TestHandleRememberMe(t *testing.T) { req := newRequest(signedHash) data := NewAnonymousAuthContext(connectorName, expiryDuration) - rmCtx, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + rmCtx, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -426,7 +427,7 @@ func TestHandleRememberMe(t *testing.T) { req := newRequest(signedHash) data := NewAnonymousAuthContext(connectorName, expiryDuration) - rmCtx, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + rmCtx, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -446,7 +447,7 @@ func TestHandleRememberMe(t *testing.T) { signedHash := "invalid-hash" req := newRequest(signedHash) data := NewAnonymousAuthContext(connectorName, expiryDuration) - rmCtx, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + rmCtx, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -462,7 +463,7 @@ func TestHandleRememberMe(t *testing.T) { mockStore.err = errors.New("key error") req := newRequest("") data := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) - _, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + _, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) if err == nil || !errors.Is(err, mockStore.err) { t.Errorf("Expected key error, got: %v", err) } @@ -473,7 +474,7 @@ func TestHandleRememberMe(t *testing.T) { mockSessionStore.createErr = errors.New("create error") req := newRequest("") data := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) - _, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + _, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) if err == nil || !errors.Is(err, mockSessionStore.createErr) { t.Errorf("Expected create error, got: %v", err) } @@ -485,7 +486,7 @@ func TestHandleRememberMe(t *testing.T) { signedHash, _ := generateSignedHash(identity, testKey) req := newRequest(signedHash) data := NewAnonymousAuthContext(connectorName, expiryDuration) - _, err := HandleRememberMe(ctx, noOpLogger, req, data, mockStore, mockSessionStore) + _, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) if err == nil || !errors.Is(err, mockSessionStore.getErr) { t.Errorf("Expected get error, got: %v", err) } diff --git a/server/handlers.go b/server/handlers.go index c227251f..47590bde 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -23,6 +23,7 @@ import ( "github.com/gorilla/mux" "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/internal/jwt" rememberme "github.com/dexidp/dex/remember-me" "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" @@ -374,6 +375,12 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { rememberData, err := rememberme.HandleRememberMe(ctx, s.logger, r, rememberme.NewAnonymousAuthContext(authReq.ConnectorID, s.idTokensValidFor), s.storage, s.sessionStorage) if err != nil { if !errors.Is(err, storage.ErrNotFound) { + if rememberData != nil && !rememberData.Cookie.Empty() { + // Overwrite or unset the cookie in certain error cases to allow for "natural" + // recovery, e.g., the cookie is present but malformatted, then it should be unset. + cookie, _ := rememberData.Cookie.Get() + http.SetCookie(w, cookie) + } s.logger.ErrorContext(r.Context(), "failed to call HandleRememberMe handler", "err", err) s.renderError(r, w, http.StatusInternalServerError, "TODO") return @@ -1165,7 +1172,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { } rawIDToken := auth[len(prefix):] - verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true}) + verifier := oidc.NewVerifier(s.issuerURL.String(), jwt.NewStorageKeySet(s.storage), &oidc.Config{SkipClientIDCheck: true}) idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden) diff --git a/server/introspectionhandler.go b/server/introspectionhandler.go index 42ad1b3c..747dd735 100644 --- a/server/introspectionhandler.go +++ b/server/introspectionhandler.go @@ -9,6 +9,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" + "github.com/dexidp/dex/internal/jwt" "github.com/dexidp/dex/server/internal" ) @@ -245,7 +246,7 @@ func (s *Server) introspectRefreshToken(ctx context.Context, token string) (*Int } func (s *Server) introspectAccessToken(ctx context.Context, token string) (*Introspection, error) { - verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true}) + verifier := oidc.NewVerifier(s.issuerURL.String(), jwt.NewStorageKeySet(s.storage), &oidc.Config{SkipClientIDCheck: true}) idToken, err := verifier.Verify(ctx, token) if err != nil { return nil, newIntrospectInactiveTokenError() diff --git a/server/oauth2.go b/server/oauth2.go index 13fb5fc4..a74cf78c 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -694,41 +694,3 @@ func validateConnectorID(connectors []storage.Connector, connectorID string) boo } return false } - -// storageKeySet implements the oidc.KeySet interface backed by Dex storage -type storageKeySet struct { - storage.Storage -} - -func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { - jws, err := jose.ParseSigned(jwt, []jose.SignatureAlgorithm{jose.RS256, jose.RS384, jose.RS512, jose.ES256, jose.ES384, jose.ES512}) - if err != nil { - return nil, err - } - - keyID := "" - for _, sig := range jws.Signatures { - keyID = sig.Header.KeyID - break - } - - skeys, err := s.Storage.GetKeys(ctx) - if err != nil { - return nil, err - } - - keys := []*jose.JSONWebKey{skeys.SigningKeyPub} - for _, vk := range skeys.VerificationKeys { - keys = append(keys, vk.PublicKey) - } - - for _, key := range keys { - if keyID == "" || key.KeyID == keyID { - if payload, err := jws.Verify(key); err == nil { - return payload, nil - } - } - } - - return nil, errors.New("failed to verify id token signature") -} diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 3dff30d6..e9aecf4f 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -12,6 +12,7 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/stretchr/testify/require" + jwt_lib "github.com/dexidp/dex/internal/jwt" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/memory" ) @@ -668,7 +669,7 @@ func TestStorageKeySet(t *testing.T) { t.Fatal(err) } - keySet := &storageKeySet{s} + keySet := jwt_lib.NewStorageKeySet(s) _, err = keySet.VerifySignature(t.Context(), jwt) if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) { diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 0461d6dd..9eb7ea25 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -11,8 +11,10 @@ import ( "github.com/dexidp/dex/storage" ) -var _ storage.Storage = (*memStorage)(nil) -var _ storage.ActiveSessionStorage = (*memStorage)(nil) +var ( + _ storage.Storage = (*memStorage)(nil) + _ storage.ActiveSessionStorage = (*memStorage)(nil) +) // New returns an in memory storage. func New(logger *slog.Logger) storage.Storage { From a080b3367013347a1bc0634b4bf4ad9a2ff1bc82 Mon Sep 17 00:00:00 2001 From: Julius Foitzik Date: Sat, 18 Oct 2025 11:44:25 +0200 Subject: [PATCH 3/9] snapshot Signed-off-by: Julius Foitzik --- remember-me/handler_test.go | 757 +++++++++++++++--------------------- 1 file changed, 305 insertions(+), 452 deletions(-) diff --git a/remember-me/handler_test.go b/remember-me/handler_test.go index c63586a8..640c94f1 100644 --- a/remember-me/handler_test.go +++ b/remember-me/handler_test.go @@ -5,10 +5,7 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" - "crypto/sha3" - "encoding/base64" "errors" - "fmt" "log/slog" "net/http" "net/http/httptest" @@ -16,505 +13,361 @@ import ( "time" "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/require" "github.com/dexidp/dex/connector" - "github.com/dexidp/dex/internal/jwt" "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/memory" ) -var ( - _ storage.Storage = (*mockStorage)(nil) - _ storage.ActiveSessionStorage = (*mockStorage)(nil) -) - -// mockStorage implements storage.Storage for testing key retrieval. -type mockStorage struct { - keys storage.Keys - err error -} - -// CreateSession implements storage.ActiveSessionStorage. -func (m *mockStorage) CreateSession(ctx context.Context, identifier string, data storage.ActiveSession) error { - panic("unimplemented") -} - -// GetSession implements storage.ActiveSessionStorage. -func (m *mockStorage) GetSession(ctx context.Context, identifier string) (storage.ActiveSession, error) { - panic("unimplemented") -} - -// CreateAuthCode implements storage.Storage. -func (m *mockStorage) CreateAuthCode(ctx context.Context, c storage.AuthCode) error { - panic("unimplemented") -} - -// CreateAuthRequest implements storage.Storage. -func (m *mockStorage) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) error { - panic("unimplemented") -} - -// CreateClient implements storage.Storage. -func (m *mockStorage) CreateClient(ctx context.Context, c storage.Client) error { - panic("unimplemented") -} - -// CreateConnector implements storage.Storage. -func (m *mockStorage) CreateConnector(ctx context.Context, c storage.Connector) error { - panic("unimplemented") -} - -// CreateDeviceRequest implements storage.Storage. -func (m *mockStorage) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) error { - panic("unimplemented") -} - -// CreateDeviceToken implements storage.Storage. -func (m *mockStorage) CreateDeviceToken(ctx context.Context, d storage.DeviceToken) error { - panic("unimplemented") -} - -// CreateOfflineSessions implements storage.Storage. -func (m *mockStorage) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessions) error { - panic("unimplemented") -} - -// CreatePassword implements storage.Storage. -func (m *mockStorage) CreatePassword(ctx context.Context, p storage.Password) error { - panic("unimplemented") -} - -// CreateRefresh implements storage.Storage. -func (m *mockStorage) CreateRefresh(ctx context.Context, r storage.RefreshToken) error { - panic("unimplemented") -} - -// DeleteAuthCode implements storage.Storage. -func (m *mockStorage) DeleteAuthCode(ctx context.Context, code string) error { - panic("unimplemented") -} - -// DeleteAuthRequest implements storage.Storage. -func (m *mockStorage) DeleteAuthRequest(ctx context.Context, id string) error { - panic("unimplemented") -} - -// DeleteClient implements storage.Storage. -func (m *mockStorage) DeleteClient(ctx context.Context, id string) error { - panic("unimplemented") -} - -// DeleteConnector implements storage.Storage. -func (m *mockStorage) DeleteConnector(ctx context.Context, id string) error { - panic("unimplemented") -} - -// DeleteOfflineSessions implements storage.Storage. -func (m *mockStorage) DeleteOfflineSessions(ctx context.Context, userID string, connID string) error { - panic("unimplemented") -} - -// DeletePassword implements storage.Storage. -func (m *mockStorage) DeletePassword(ctx context.Context, email string) error { - panic("unimplemented") -} - -// DeleteRefresh implements storage.Storage. -func (m *mockStorage) DeleteRefresh(ctx context.Context, id string) error { - panic("unimplemented") -} - -// GarbageCollect implements storage.Storage. -func (m *mockStorage) GarbageCollect(ctx context.Context, now time.Time) (storage.GCResult, error) { - panic("unimplemented") -} - -// GetAuthCode implements storage.Storage. -func (m *mockStorage) GetAuthCode(ctx context.Context, id string) (storage.AuthCode, error) { - panic("unimplemented") -} - -// GetAuthRequest implements storage.Storage. -func (m *mockStorage) GetAuthRequest(ctx context.Context, id string) (storage.AuthRequest, error) { - panic("unimplemented") -} - -// GetClient implements storage.Storage. -func (m *mockStorage) GetClient(ctx context.Context, id string) (storage.Client, error) { - panic("unimplemented") -} - -// GetConnector implements storage.Storage. -func (m *mockStorage) GetConnector(ctx context.Context, id string) (storage.Connector, error) { - panic("unimplemented") -} - -// GetDeviceRequest implements storage.Storage. -func (m *mockStorage) GetDeviceRequest(ctx context.Context, userCode string) (storage.DeviceRequest, error) { - panic("unimplemented") -} - -// GetDeviceToken implements storage.Storage. -func (m *mockStorage) GetDeviceToken(ctx context.Context, deviceCode string) (storage.DeviceToken, error) { - panic("unimplemented") -} - -// GetOfflineSessions implements storage.Storage. -func (m *mockStorage) GetOfflineSessions(ctx context.Context, userID string, connID string) (storage.OfflineSessions, error) { - panic("unimplemented") -} - -// GetPassword implements storage.Storage. -func (m *mockStorage) GetPassword(ctx context.Context, email string) (storage.Password, error) { - panic("unimplemented") -} - -// GetRefresh implements storage.Storage. -func (m *mockStorage) GetRefresh(ctx context.Context, id string) (storage.RefreshToken, error) { - panic("unimplemented") -} - -// ListClients implements storage.Storage. -func (m *mockStorage) ListClients(ctx context.Context) ([]storage.Client, error) { - panic("unimplemented") -} - -// ListConnectors implements storage.Storage. -func (m *mockStorage) ListConnectors(ctx context.Context) ([]storage.Connector, error) { - panic("unimplemented") -} - -// ListPasswords implements storage.Storage. -func (m *mockStorage) ListPasswords(ctx context.Context) ([]storage.Password, error) { - panic("unimplemented") -} - -// ListRefreshTokens implements storage.Storage. -func (m *mockStorage) ListRefreshTokens(ctx context.Context) ([]storage.RefreshToken, error) { - panic("unimplemented") -} - -// UpdateAuthRequest implements storage.Storage. -func (m *mockStorage) UpdateAuthRequest(ctx context.Context, id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { - panic("unimplemented") -} - -// UpdateClient implements storage.Storage. -func (m *mockStorage) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error { - panic("unimplemented") -} - -// UpdateConnector implements storage.Storage. -func (m *mockStorage) UpdateConnector(ctx context.Context, id string, updater func(c storage.Connector) (storage.Connector, error)) error { - panic("unimplemented") -} - -// UpdateDeviceToken implements storage.Storage. -func (m *mockStorage) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(t storage.DeviceToken) (storage.DeviceToken, error)) error { - panic("unimplemented") -} - -// UpdateKeys implements storage.Storage. -func (m *mockStorage) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) error { - panic("unimplemented") -} - -// UpdateOfflineSessions implements storage.Storage. -func (m *mockStorage) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { - panic("unimplemented") -} - -// UpdatePassword implements storage.Storage. -func (m *mockStorage) UpdatePassword(ctx context.Context, email string, updater func(p storage.Password) (storage.Password, error)) error { - panic("unimplemented") -} - -// UpdateRefreshToken implements storage.Storage. -func (m *mockStorage) UpdateRefreshToken(ctx context.Context, id string, updater func(r storage.RefreshToken) (storage.RefreshToken, error)) error { - panic("unimplemented") -} +// setupTestEnvironment creates a realistic test environment following dex patterns +func setupTestEnvironment(t *testing.T) (storage.Storage, storage.ActiveSessionStorage, *slog.Logger) { + logger := slog.New(slog.NewTextHandler(nil, &slog.HandlerOptions{Level: slog.LevelError})) -func (m *mockStorage) GetKeys(ctx context.Context) (storage.Keys, error) { - return m.keys, m.err -} + // Use dex's standard in-memory storage + store := memory.New(logger) + sessionStore := memory.NewSessionStore(logger) -func (m *mockStorage) Close() error { return nil } + // Initialize with real keys like dex does + ctx := context.Background() + err := store.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) -// mockSessionStorage implements storage.ActiveSessionStorage for testing. -type mockSessionStorage struct { - sessions map[string]storage.ActiveSession - getErr error - createErr error -} + signingKey := &jose.JSONWebKey{Key: key} + signingKeyPub := &jose.JSONWebKey{Key: &key.PublicKey} -// GarbageCollect implements storage.ActiveSessionStorage. -func (m *mockSessionStorage) GarbageCollect(ctx context.Context, now time.Time) (storage.GCResult, error) { - panic("unimplemented") -} + return storage.Keys{ + SigningKey: signingKey, + SigningKeyPub: signingKeyPub, + }, nil + }) + require.NoError(t, err) -func (m *mockSessionStorage) GetSession(ctx context.Context, id string) (storage.ActiveSession, error) { - if m.getErr != nil { - return storage.ActiveSession{}, m.getErr - } - session, ok := m.sessions[id] - if !ok { - return storage.ActiveSession{}, storage.ErrNotFound - } - return session, nil + return store, sessionStore, logger } -func (m *mockSessionStorage) CreateSession(ctx context.Context, id string, session storage.ActiveSession) error { - if m.createErr != nil { - return m.createErr +// createTestRequest creates an HTTP request with optional cookie +func createTestRequest(connectorName string, cookieValue string) *http.Request { + req := httptest.NewRequest("GET", "/auth", nil) + if cookieValue != "" { + cookieName := connector_cookie_name(connectorName) + req.AddCookie(&http.Cookie{Name: cookieName, Value: cookieValue}) } - m.sessions[id] = session - return nil + return req } -var testLogger = slog.New(slog.NewTextHandler(nil, &slog.HandlerOptions{Level: slog.LevelError})) - -// fixedNow returns a fixed time for deterministic testing. -func fixedNow() time.Time { - return time.Date(2025, 10, 17, 18, 0, 0, 0, time.UTC) -} - -// generateTestKey creates a test ECDSA signing key. -func generateTestKey() *ecdsa.PrivateKey { - key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - return key -} - -// generateSignedHash creates a signed hash for a given identity. -func generateSignedHash(identity connector.Identity, signingKey *ecdsa.PrivateKey) (string, error) { - h := sha3.New512() - h.Write([]byte(identity.Email)) - for _, g := range identity.Groups { - h.Write([]byte(g)) - } - h.Write([]byte(identity.UserID)) - h.Write([]byte(identity.Username)) - h.Write([]byte(identity.PreferredUsername)) - hash := fmt.Sprintf("%x", h.Sum(nil)) - - signAlg, _ := jwt.SignatureAlgorithm(&jose.JSONWebKey{Key: signingKey}) - signedBytes, err := jwt.SignPayload(&jose.JSONWebKey{Key: signingKey}, signAlg, []byte(hash)) - if err != nil { - return "", err +// createTestIdentity creates a sample identity for testing +func createTestIdentity() connector.Identity { + return connector.Identity{ + UserID: "user123", + Username: "testuser", + PreferredUsername: "testuser", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"group1", "group2"}, } - // Explicitly encode as []byte. - return base64.RawURLEncoding.EncodeToString([]byte(signedBytes)), nil } -func TestHandleRememberMe(t *testing.T) { - // Common test fixtures. +func TestHandleRememberMe_Integration(t *testing.T) { + store, sessionStore, logger := setupTestEnvironment(t) ctx := context.Background() connectorName := "test-connector" expiryDuration := 24 * time.Hour - testKey := generateTestKey() - mockStore := &mockStorage{ - keys: storage.Keys{ - SigningKey: &jose.JSONWebKey{Key: testKey}, + identity := createTestIdentity() + + tests := []struct { + name string + setup func() (*http.Request, AuthContext) + want func(t *testing.T, result *RememberMeCtx, err error) + }{ + { + name: "no cookie with anonymous context returns ErrNotFound", + setup: func() (*http.Request, AuthContext) { + req := createTestRequest(connectorName, "") + authCtx := NewAnonymousAuthContext(connectorName, expiryDuration) + return req, authCtx + }, + want: func(t *testing.T, result *RememberMeCtx, err error) { + require.Error(t, err) + require.True(t, errors.Is(err, storage.ErrNotFound)) + require.Nil(t, result) + }, + }, + { + name: "no cookie with identity creates new session and sets cookie", + setup: func() (*http.Request, AuthContext) { + req := createTestRequest(connectorName, "") + authCtx := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) + return req, authCtx + }, + want: func(t *testing.T, result *RememberMeCtx, err error) { + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.IsValid()) + + // Verify session details + require.Equal(t, identity.UserID, result.Session.Identity.UserID) + require.Equal(t, identity.Email, result.Session.Identity.Email) + require.Equal(t, identity.Groups, result.Session.Identity.Groups) + require.True(t, result.Session.Expiry.After(time.Now())) + + // Verify cookie is set + require.False(t, result.Cookie.Empty()) + cookie, unset := result.Cookie.Get() + require.False(t, unset) + require.Equal(t, connector_cookie_name(connectorName), cookie.Name) + require.NotEmpty(t, cookie.Value) + require.True(t, cookie.Secure) + require.True(t, cookie.HttpOnly) + require.Equal(t, http.SameSiteStrictMode, cookie.SameSite) + + // Verify session was stored + storedSession, err := sessionStore.GetSession(ctx, cookie.Value) + require.NoError(t, err) + require.Equal(t, identity.UserID, storedSession.Identity.UserID) + }, }, - } - mockSessionStore := &mockSessionStorage{ - sessions: make(map[string]storage.ActiveSession), - } - - // Helper to create a request with optional cookie. - newRequest := func(cookieValue string) *http.Request { - req := httptest.NewRequest("GET", "/auth", nil) - if cookieValue != "" { - cookieName := connector_cookie_name(connectorName) - req.AddCookie(&http.Cookie{Name: cookieName, Value: cookieValue}) - } - return req } - // Sample identity. - identity := connector.Identity{ - UserID: "user123", - Username: "testuser", - PreferredUsername: "testuser", - Email: "test@example.com", - EmailVerified: true, - Groups: []string{"group1"}, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, authCtx := tt.setup() + result, err := HandleRememberMe(ctx, logger, req, authCtx, store, sessionStore) + tt.want(t, result, err) + }) } +} - t.Run("No cookie, anonymous context", func(t *testing.T) { - req := newRequest("") // No cookie. - data := NewAnonymousAuthContext(connectorName, expiryDuration) - ctx, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) - if !errors.Is(err, storage.ErrNotFound) { - t.Errorf("Expected ErrNotFound, got: %v", err) - } - if ctx != nil { - t.Error("Expected nil context") - } - }) - - t.Run("No cookie, with identity: Create new session and set cookie", func(t *testing.T) { - req := newRequest("") // No cookie. - data := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) - rmCtx, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if !rmCtx.IsValid() { - t.Error("Expected valid session") - } - if rmCtx.Session.Identity.UserID != identity.UserID { - t.Errorf("Expected UserID %s, got %s", identity.UserID, rmCtx.Session.Identity.UserID) - } - if rmCtx.Cookie.Empty() || rmCtx.Cookie.unset { - t.Error("Expected cookie to be set") - } - cookie, _ := rmCtx.Cookie.Get() - if cookie.Name != connector_cookie_name(connectorName) { - t.Errorf("Unexpected cookie name: %s", cookie.Name) - } - if cookie.Value == "" { - t.Error("Expected non-empty cookie value") - } - if cookie.Expires.Before(fixedNow().Add(expiryDuration - time.Minute)) { // Allow some leeway. - t.Errorf("Unexpected expiry: %v", cookie.Expires) - } - // Verify session was stored. - storedSession, err := mockSessionStore.GetSession(ctx, cookie.Value) - if err != nil || storedSession.Identity.UserID != identity.UserID { - t.Error("Session not stored correctly") - } +func TestHandleRememberMe_WithExistingSessions(t *testing.T) { + store, sessionStore, logger := setupTestEnvironment(t) + ctx := context.Background() + connectorName := "test-connector" + expiryDuration := 24 * time.Hour + identity := createTestIdentity() + + // First create a session to test retrieval + req := createTestRequest(connectorName, "") + authCtx := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) + initialResult, err := HandleRememberMe(ctx, logger, req, authCtx, store, sessionStore) + require.NoError(t, err) + require.NotNil(t, initialResult) + + cookie, _ := initialResult.Cookie.Get() + cookieValue := cookie.Value + + t.Run("valid cookie with active session returns session without new cookie", func(t *testing.T) { + req := createTestRequest(connectorName, cookieValue) + authCtx := NewAnonymousAuthContext(connectorName, expiryDuration) + + result, err := HandleRememberMe(ctx, logger, req, authCtx, store, sessionStore) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.IsValid()) + require.Equal(t, identity.UserID, result.Session.Identity.UserID) + require.True(t, result.Cookie.Empty()) // No cookie change needed }) - t.Run("Valid cookie with active session", func(t *testing.T) { - // Pre-populate a session. - signedHash, err := generateSignedHash(identity, testKey) - if err != nil { - t.Fatal(err) - } - session := storage.ActiveSession{ - Identity: identity, - Expiry: fixedNow().Add(expiryDuration), - } - mockSessionStore.sessions[signedHash] = session + t.Run("expired session unsets cookie", func(t *testing.T) { + // Create a fresh session store to avoid ID conflicts + freshStore, freshSessionStore, freshLogger := setupTestEnvironment(t) - req := newRequest(signedHash) - data := NewAnonymousAuthContext(connectorName, expiryDuration) - rmCtx, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if !rmCtx.IsValid() { - t.Error("Expected valid session") - } - if rmCtx.Session.Identity.UserID != identity.UserID { - t.Errorf("Expected UserID %s, got %s", identity.UserID, rmCtx.Session.Identity.UserID) - } - if !rmCtx.Cookie.Empty() { - t.Error("Expected no cookie change (empty GetOrUnsetCookie)") + expiredIdentity := connector.Identity{ + UserID: "expired-user", + Username: "expireduser", + Email: "expired@example.com", + Groups: []string{"expired-group"}, } - }) - t.Run("Cookie present, expired session: Unset cookie", func(t *testing.T) { - // Pre-populate an expired session. - signedHash, err := generateSignedHash(identity, testKey) - if err != nil { - t.Fatal(err) - } - session := storage.ActiveSession{ - Identity: identity, - Expiry: fixedNow().Add(-time.Hour), // Expired. + // Create an expired session directly with a known identifier + expiredSession := storage.ActiveSession{ + Identity: expiredIdentity, + Expiry: time.Now().Add(-time.Hour), } - mockSessionStore.sessions[signedHash] = session - req := newRequest(signedHash) - data := NewAnonymousAuthContext(connectorName, expiryDuration) - rmCtx, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if rmCtx.IsValid() { - t.Error("Expected invalid (expired) session") - } - if rmCtx.Cookie.Empty() || !rmCtx.Cookie.unset { - t.Error("Expected cookie to be unset") - } - cookie, unset := rmCtx.Cookie.Get() - if !unset || cookie.Name != connector_cookie_name(connectorName) || cookie.MaxAge != -1 { - t.Error("Unexpected unset cookie") - } + // Create the session using a predictable signed identifier + // First get a signed identifier by creating a valid session + tempReq := createTestRequest(connectorName, "") + tempAuthCtx := NewAuthContextWithIdentity(connectorName, expiredIdentity, time.Hour) // short duration + tempResult, err := HandleRememberMe(ctx, freshLogger, tempReq, tempAuthCtx, freshStore, freshSessionStore) + require.NoError(t, err) + + tempCookie, _ := tempResult.Cookie.Get() + sessionID := tempCookie.Value + + // Wait a moment and directly update the session in storage to be expired + // by creating a new session store and directly setting expired session + testSessionStore := memory.NewSessionStore(freshLogger) + err = testSessionStore.CreateSession(ctx, sessionID, expiredSession) + require.NoError(t, err) + + // Test with the expired session + req := createTestRequest(connectorName, sessionID) + authCtx := NewAnonymousAuthContext(connectorName, expiryDuration) + result, err := HandleRememberMe(ctx, freshLogger, req, authCtx, freshStore, testSessionStore) + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsValid()) + + // Cookie should be unset + require.False(t, result.Cookie.Empty()) + resultCookie, unset := result.Cookie.Get() + require.True(t, unset) + require.Equal(t, -1, resultCookie.MaxAge) }) - t.Run("Cookie present, session not found: Unset cookie", func(t *testing.T) { - signedHash := "invalid-hash" - req := newRequest(signedHash) - data := NewAnonymousAuthContext(connectorName, expiryDuration) - rmCtx, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if rmCtx.IsValid() { - t.Error("Expected invalid session") - } - if rmCtx.Cookie.Empty() || !rmCtx.Cookie.unset { - t.Error("Expected cookie to be unset") - } + t.Run("invalid cookie signature unsets cookie", func(t *testing.T) { + // Use an invalid JWT format + invalidCookie := "invalid.jwt.signature" + req := createTestRequest(connectorName, invalidCookie) + authCtx := NewAnonymousAuthContext(connectorName, expiryDuration) + + result, err := HandleRememberMe(ctx, logger, req, authCtx, store, sessionStore) + require.Error(t, err) + require.NotNil(t, result) + require.False(t, result.IsValid()) + + // Cookie should be unset + require.False(t, result.Cookie.Empty()) + cookie, unset := result.Cookie.Get() + require.True(t, unset) + require.Equal(t, connector_cookie_name(connectorName), cookie.Name) }) +} - t.Run("Error: Failed to get keys", func(t *testing.T) { - mockStore.err = errors.New("key error") - req := newRequest("") - data := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) - _, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) - if err == nil || !errors.Is(err, mockStore.err) { - t.Errorf("Expected key error, got: %v", err) - } - mockStore.err = nil // Reset. +func TestHandleRememberMe_EndToEndWorkflow(t *testing.T) { + store, sessionStore, logger := setupTestEnvironment(t) + ctx := context.Background() + connectorName := "test-connector" + expiryDuration := 24 * time.Hour + identity := createTestIdentity() + + t.Run("complete login workflow", func(t *testing.T) { + // Step 1: Initial login - no cookie present + req1 := createTestRequest(connectorName, "") + authCtx1 := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) + + result1, err := HandleRememberMe(ctx, logger, req1, authCtx1, store, sessionStore) + require.NoError(t, err) + require.True(t, result1.IsValid()) + + cookie1, unset1 := result1.Cookie.Get() + require.False(t, unset1) + require.NotEmpty(t, cookie1.Value) + + // Step 2: Return visit with cookie - should recognize user + req2 := createTestRequest(connectorName, cookie1.Value) + authCtx2 := NewAnonymousAuthContext(connectorName, expiryDuration) + + result2, err := HandleRememberMe(ctx, logger, req2, authCtx2, store, sessionStore) + require.NoError(t, err) + require.True(t, result2.IsValid()) + require.Equal(t, identity.UserID, result2.Session.Identity.UserID) + require.True(t, result2.Cookie.Empty()) // No cookie change needed + + // Step 3: Verify session consistency + require.Equal(t, result1.Session.Identity.UserID, result2.Session.Identity.UserID) + require.Equal(t, result1.Session.Identity.Email, result2.Session.Identity.Email) }) - t.Run("Error: Failed to create session", func(t *testing.T) { - mockSessionStore.createErr = errors.New("create error") - req := newRequest("") - data := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) - _, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) - if err == nil || !errors.Is(err, mockSessionStore.createErr) { - t.Errorf("Expected create error, got: %v", err) - } - mockSessionStore.createErr = nil // Reset. + t.Run("multiple connectors isolation", func(t *testing.T) { + connector1 := "connector-1" + connector2 := "connector-2" + + // Create session for connector1 + req1 := createTestRequest(connector1, "") + authCtx1 := NewAuthContextWithIdentity(connector1, identity, expiryDuration) + result1, err := HandleRememberMe(ctx, logger, req1, authCtx1, store, sessionStore) + require.NoError(t, err) + + cookie1, _ := result1.Cookie.Get() + + // Create a request for connector2 but with connector1's cookie + // This simulates having both cookies in the browser + req2 := httptest.NewRequest("GET", "/auth", nil) + req2.AddCookie(cookie1) // connector1's cookie + + // Since connector2 looks for its own cookie name, it won't find connector1's cookie + authCtx2 := NewAnonymousAuthContext(connector2, expiryDuration) + result2, err := HandleRememberMe(ctx, logger, req2, authCtx2, store, sessionStore) + require.Error(t, err) + require.True(t, errors.Is(err, storage.ErrNotFound)) + require.Nil(t, result2) }) +} - t.Run("Error: Failed to get session", func(t *testing.T) { - mockSessionStore.getErr = errors.New("get error") - signedHash, _ := generateSignedHash(identity, testKey) - req := newRequest(signedHash) - data := NewAnonymousAuthContext(connectorName, expiryDuration) - _, err := HandleRememberMe(ctx, testLogger, req, data, mockStore, mockSessionStore) - if err == nil || !errors.Is(err, mockSessionStore.getErr) { - t.Errorf("Expected get error, got: %v", err) - } - mockSessionStore.getErr = nil // Reset. +func TestHandleRememberMe_ErrorHandling(t *testing.T) { + store, sessionStore, logger := setupTestEnvironment(t) + ctx := context.Background() + connectorName := "test-connector" + expiryDuration := 24 * time.Hour + identity := createTestIdentity() + + t.Run("session not found after valid signature verification", func(t *testing.T) { + // Create a session first to get a valid signed cookie + req := createTestRequest(connectorName, "") + authCtx := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) + result, err := HandleRememberMe(ctx, logger, req, authCtx, store, sessionStore) + require.NoError(t, err) + + cookie, _ := result.Cookie.Get() + + // Create a different session store that doesn't have this session + emptySessionStore := memory.NewSessionStore(logger) + + // Try to retrieve with valid cookie but empty session store + req2 := createTestRequest(connectorName, cookie.Value) + authCtx2 := NewAnonymousAuthContext(connectorName, expiryDuration) + result2, err := HandleRememberMe(ctx, logger, req2, authCtx2, store, emptySessionStore) + + require.NoError(t, err) + require.NotNil(t, result2) + require.False(t, result2.IsValid()) + + // Should unset cookie when session not found + require.False(t, result2.Cookie.Empty()) + unsetCookie, unset := result2.Cookie.Get() + require.True(t, unset) + require.Equal(t, connector_cookie_name(connectorName), unsetCookie.Name) }) } func TestExtractCookie(t *testing.T) { connectorName := "test-connector" - req := httptest.NewRequest("GET", "/auth", nil) - req.AddCookie(&http.Cookie{Name: connector_cookie_name(connectorName), Value: "test-value"}) - req.AddCookie(&http.Cookie{Name: "other-cookie", Value: "ignored"}) - value, found := extractCookie(req, connectorName) - if !found || value != "test-value" { - t.Errorf("Expected value 'test-value', got: %s (found: %v)", value, found) - } + t.Run("cookie present", func(t *testing.T) { + req := httptest.NewRequest("GET", "/auth", nil) + req.AddCookie(&http.Cookie{Name: connector_cookie_name(connectorName), Value: "test-value"}) + req.AddCookie(&http.Cookie{Name: "other-cookie", Value: "ignored"}) - // No cookie. - req = httptest.NewRequest("GET", "/auth", nil) - value, found = extractCookie(req, connectorName) - if found || value != "" { - t.Error("Expected not found") - } + value, found := extractCookie(req, connectorName) + require.True(t, found) + require.Equal(t, "test-value", value) + }) + + t.Run("cookie not present", func(t *testing.T) { + req := httptest.NewRequest("GET", "/auth", nil) + value, found := extractCookie(req, connectorName) + require.False(t, found) + require.Equal(t, "", value) + }) } func TestConnectorCookieName(t *testing.T) { - if name := connector_cookie_name("test"); name != "dex_active_session_cookie_test" { - t.Errorf("Unexpected name: %s", name) + tests := []struct { + connector string + expected string + }{ + {"test", "dex_active_session_cookie_test"}, + {"google", "dex_active_session_cookie_google"}, + {"ldap-local", "dex_active_session_cookie_ldap-local"}, + } + + for _, tt := range tests { + t.Run(tt.connector, func(t *testing.T) { + result := connector_cookie_name(tt.connector) + require.Equal(t, tt.expected, result) + }) } } From a91ffa010479cae11bc1e1e3a1dac59cec4b8b3e Mon Sep 17 00:00:00 2001 From: Julius Foitzik Date: Sat, 18 Oct 2025 13:00:47 +0200 Subject: [PATCH 4/9] move Signed-off-by: Julius Foitzik --- {remember-me => internal/remember-me}/handler.go | 0 {remember-me => internal/remember-me}/handler_test.go | 0 server/handlers.go | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename {remember-me => internal/remember-me}/handler.go (100%) rename {remember-me => internal/remember-me}/handler_test.go (100%) diff --git a/remember-me/handler.go b/internal/remember-me/handler.go similarity index 100% rename from remember-me/handler.go rename to internal/remember-me/handler.go diff --git a/remember-me/handler_test.go b/internal/remember-me/handler_test.go similarity index 100% rename from remember-me/handler_test.go rename to internal/remember-me/handler_test.go diff --git a/server/handlers.go b/server/handlers.go index 47590bde..0c05bd00 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -24,7 +24,7 @@ import ( "github.com/dexidp/dex/connector" "github.com/dexidp/dex/internal/jwt" - rememberme "github.com/dexidp/dex/remember-me" + rememberme "github.com/dexidp/dex/internal/remember-me" "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" ) From 2d5805c04853654192a81edc2c4bca4f60589863 Mon Sep 17 00:00:00 2001 From: Julius Foitzik Date: Sat, 18 Oct 2025 18:15:05 +0200 Subject: [PATCH 5/9] add DEP Dex Enhancement Proposal doc Signed-off-by: Julius Foitzik --- .../remember-me-2025-10-19-#32.md | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 docs/enhancements/remember-me-2025-10-19-#32.md diff --git a/docs/enhancements/remember-me-2025-10-19-#32.md b/docs/enhancements/remember-me-2025-10-19-#32.md new file mode 100644 index 00000000..63db767f --- /dev/null +++ b/docs/enhancements/remember-me-2025-10-19-#32.md @@ -0,0 +1,93 @@ +# Dex Enhancement Proposal (DEP) <#32> - <2025-10-19> - Remembe Me + +## Table of Contents + +- [Summary](#summary) +- [Motivation](#motivation) + - [Goals/Pain](#goals) + - [Non-Goals](#non-goals) +- [Proposal](#proposal) + - [User Experience](#user-experience) + - [Implementation Details/Notes/Constraints](#implementation-detailsnotesconstraints) + - [Risks and Mitigations](#risks-and-mitigations) + - [Alternatives](#alternatives) +- [Future Improvements](#future-improvements) + +## Summary + +Avoid repeated re-authentications when using password-based (sessionless) connectors by +storing a server-side (dex) session of the user login and re-use it instead. + +## Context + +https://github.com/dexidp/dex/issues/32 + +## Motivation + +### Goals/Pain + +- Minimal viable implementation of remember me functionality scoped to only password-based connectors +- If the same user is authenticating through dex using n>1 applications (clients) during a session (predefined timeframe), the user should not be prompted to log in again +- Avoid bad UX where each application (client) triggers a new login with the password connector +- Implement for the in-memory storage backend + +### Non-goals + +- Implement for any other storage backend +- Implement for any non-password connector + +## Proposal + +### User Experience + +- When the user logs in once using the password-based connector he is never prompted to login again until his session expires +- Once a session has been obtained the authflow is frictionless and mostly automatic + +### Implementation Details/Notes/Constraints + +- Implementation is in a separate package to separate concerns and keep code isolated +- Add new specific interface for storage to avoid bloating the already huge storage (`storage.Storage`) interface +- Each connector has a specific cookie to allow having more than one password-based connector (also for security purposes) +- Cookies are signed just as JWT are and verified each time to ensure authenticity + +Regular password-based connector flow but with active sessions (no session found case). + +```mermaid +sequenceDiagram + User->>+Client: Start Auth Flow + Client-->>-User: Redirect to dex + User-->>+Dex: Auth Flow + Dex->>+Dex: Check for Cookie and Session + Dex-->>-User: Redirect to Login Page + User->>+Dex: Send Credentials + Dex->>+Connector: Forward Credentials + Connector-->>-Dex: Return Identity + Dex->>+Dex: Persist Session + Dex -->>- User: Redirect to client + +``` + +Improved UX flow with active session (session found). + +```mermaid +sequenceDiagram + User->>+Client: Start Auth Flow + Client-->>-User: Redirect to dex + User-->>+Dex: Auth Flow + Dex->>+Dex: Check for Cookie and Session + Dex->>+Dex: Retrieve Session + Dex -->>- User: Redirect to client +``` + +### Risks and Mitigations + +- I am not absolutely sure whether this introduces any attack vectors that could be exploited. +- This DEP does not introduce any breaking changes. + +### Alternatives + +- None. We can declare this out of scope, but other than developing. + +## Future Improvements + +- None From 478c7a50f29a6ff64f34b5f9200b23bdb1846e99 Mon Sep 17 00:00:00 2001 From: Julius Foitzik Date: Mon, 20 Oct 2025 16:55:57 +0200 Subject: [PATCH 6/9] update ldap example to include remember me ... Signed-off-by: Julius Foitzik --- examples/ldap/config-ldap.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/ldap/config-ldap.yaml b/examples/ldap/config-ldap.yaml index 05d16618..f37fe11a 100644 --- a/examples/ldap/config-ldap.yaml +++ b/examples/ldap/config-ldap.yaml @@ -6,6 +6,11 @@ storage: web: http: 0.0.0.0:5556 +# Uncomment this section if you want to use the remember me / action sessions +# feature. For more details search for remember-me-2025-10-19-#32.md +#sessions: +# enable: true + connectors: - type: ldap name: OpenLDAP From a7232fddffa09e21244026c37ad247fbd02c10f7 Mon Sep 17 00:00:00 2001 From: Julius Foitzik Date: Mon, 20 Oct 2025 17:30:30 +0200 Subject: [PATCH 7/9] update dist.yaml Signed-off-by: Julius Foitzik --- config.yaml.dist | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/config.yaml.dist b/config.yaml.dist index b7e1410f..e5b843ae 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -3,6 +3,11 @@ # path is provided, Dex's HTTP service will listen at a non-root URL. issuer: http://127.0.0.1:5556/dex +# Uncomment this section if you want to use the remember me / action sessions +# feature. For more details search for remember-me-2025-10-19-#32.md +#sessions: +# enable: true + # The storage configuration determines where Dex stores its state. # Supported options include: # - SQL flavors From 4e3b220afd2f74f6a3f3ffa628961947a2e15a46 Mon Sep 17 00:00:00 2001 From: Julius Foitzik Date: Wed, 22 Oct 2025 18:10:41 +0200 Subject: [PATCH 8/9] snapshot" Signed-off-by: Julius Foitzik --- internal/remember-me/handler_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/remember-me/handler_test.go b/internal/remember-me/handler_test.go index 640c94f1..08504cb9 100644 --- a/internal/remember-me/handler_test.go +++ b/internal/remember-me/handler_test.go @@ -24,11 +24,11 @@ import ( func setupTestEnvironment(t *testing.T) (storage.Storage, storage.ActiveSessionStorage, *slog.Logger) { logger := slog.New(slog.NewTextHandler(nil, &slog.HandlerOptions{Level: slog.LevelError})) - // Use dex's standard in-memory storage + // Use in-memory storage store := memory.New(logger) sessionStore := memory.NewSessionStore(logger) - // Initialize with real keys like dex does + // Initialize with real keys ctx := context.Background() err := store.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) From 48728cc8f24a94acbcbe29e0c61c635f0acdeb05 Mon Sep 17 00:00:00 2001 From: Julius Foitzik Date: Wed, 22 Oct 2025 18:17:12 +0200 Subject: [PATCH 9/9] update test Signed-off-by: Julius Foitzik --- internal/remember-me/handler_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/remember-me/handler_test.go b/internal/remember-me/handler_test.go index 08504cb9..3cd3b863 100644 --- a/internal/remember-me/handler_test.go +++ b/internal/remember-me/handler_test.go @@ -360,7 +360,7 @@ func TestConnectorCookieName(t *testing.T) { expected string }{ {"test", "dex_active_session_cookie_test"}, - {"google", "dex_active_session_cookie_google"}, + {"google", "dex_active_session_cookie_google"}, // just an example, google would not be use-case for this feature {"ldap-local", "dex_active_session_cookie_ldap-local"}, }