diff --git a/FEATURE-NOTES.md b/FEATURE-NOTES.md new file mode 100644 index 00000000..153af150 --- /dev/null +++ b/FEATURE-NOTES.md @@ -0,0 +1,27 @@ +Goal is to come up with a compact and minimal design for https://github.com/dexidp/dex/issues/32 + + +# Notes + + - Use cookies to identify a returning user + - Sign cookie using one of the internal private keys + - Verify cookie when present + - Generate and store cookie or cookie encrypted value on Store + - Requires extension of storage.Storage interface + - Feature Flag + - Only introduce code in code-path of Password-based login providers + - Write cookie to response on success login (see Ref(1)) + - Cookie ExpiresIn should be less or equal to the minted JWT ExpiresIn + - I think simple store the storage.Claims+identity.ConnectorData and inject them into finalizeLogin, if the user is already logged in + - An ActiveSession is only valid once it has an associated AccessToken + - The ActiveSession should expire after a configurable amount of time + + + +Ref(1): + +probably here, and only in the case of http.MethodPost. + +```go +func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { +``` diff --git a/cmd/dex/config.go b/cmd/dex/config.go index c76ff030..fc28e486 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -33,6 +33,9 @@ type Config struct { Expiry Expiry `json:"expiry"` Logger Logger `json:"logger"` + // Experimental Feature to remember users when they authenticated using a password-based connector + Sessions Sessions `json:"sessions"` + Frontend server.WebConfig `json:"frontend"` // StaticConnectors are user defined connectors specified in the ConfigMap @@ -51,6 +54,10 @@ type Config struct { // querying the storage. Cannot be specified without enabling a passwords // database. StaticPasswords []password `json:"staticPasswords"` + + // If enabled, the server will maintain a active sessions for password connectors + // to identify returning users to avoid re-entering credentials if the session is still valid. + ExperimentalEnableRememberMe bool `json:"experimentalEnableRememberMe"` } // Validate the configuration @@ -169,6 +176,10 @@ type Web struct { ClientRemoteIP ClientRemoteIP `json:"clientRemoteIP"` } +type Sessions struct { + Enable bool `json:"enable"` +} + type ClientRemoteIP struct { Header string `json:"header"` TrustedProxies []string `json:"trustedProxies"` diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index ac715e60..9a3dc481 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -307,6 +307,10 @@ func runServe(options serveOptions) error { PrometheusRegistry: prometheusRegistry, HealthChecker: healthChecker, ContinueOnConnectorFailure: featureflags.ContinueOnConnectorFailure.Enabled(), + EnableRememberMe: c.Sessions.Enable, + } + if c.Sessions.Enable { + logger.Info("remember me experimental feature enabled") } if c.Expiry.SigningKeys != "" { signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) diff --git a/config.yaml.dist b/config.yaml.dist index c187ca3c..131722b2 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -3,6 +3,11 @@ # path is provided, Dex's HTTP service will listen at a non-root URL. issuer: http://127.0.0.1:5556/dex +# Uncomment this section if you want to use the remember me / action sessions +# feature. For more details search for remember-me-2025-10-19-#32.md +#sessions: +# enable: true + # The storage configuration determines where Dex stores its state. # Supported options include: # - SQL flavors diff --git a/docs/enhancements/remember-me-2025-10-19-#32.md b/docs/enhancements/remember-me-2025-10-19-#32.md new file mode 100644 index 00000000..63db767f --- /dev/null +++ b/docs/enhancements/remember-me-2025-10-19-#32.md @@ -0,0 +1,93 @@ +# Dex Enhancement Proposal (DEP) <#32> - <2025-10-19> - Remembe Me + +## Table of Contents + +- [Summary](#summary) +- [Motivation](#motivation) + - [Goals/Pain](#goals) + - [Non-Goals](#non-goals) +- [Proposal](#proposal) + - [User Experience](#user-experience) + - [Implementation Details/Notes/Constraints](#implementation-detailsnotesconstraints) + - [Risks and Mitigations](#risks-and-mitigations) + - [Alternatives](#alternatives) +- [Future Improvements](#future-improvements) + +## Summary + +Avoid repeated re-authentications when using password-based (sessionless) connectors by +storing a server-side (dex) session of the user login and re-use it instead. + +## Context + +https://github.com/dexidp/dex/issues/32 + +## Motivation + +### Goals/Pain + +- Minimal viable implementation of remember me functionality scoped to only password-based connectors +- If the same user is authenticating through dex using n>1 applications (clients) during a session (predefined timeframe), the user should not be prompted to log in again +- Avoid bad UX where each application (client) triggers a new login with the password connector +- Implement for the in-memory storage backend + +### Non-goals + +- Implement for any other storage backend +- Implement for any non-password connector + +## Proposal + +### User Experience + +- When the user logs in once using the password-based connector he is never prompted to login again until his session expires +- Once a session has been obtained the authflow is frictionless and mostly automatic + +### Implementation Details/Notes/Constraints + +- Implementation is in a separate package to separate concerns and keep code isolated +- Add new specific interface for storage to avoid bloating the already huge storage (`storage.Storage`) interface +- Each connector has a specific cookie to allow having more than one password-based connector (also for security purposes) +- Cookies are signed just as JWT are and verified each time to ensure authenticity + +Regular password-based connector flow but with active sessions (no session found case). + +```mermaid +sequenceDiagram + User->>+Client: Start Auth Flow + Client-->>-User: Redirect to dex + User-->>+Dex: Auth Flow + Dex->>+Dex: Check for Cookie and Session + Dex-->>-User: Redirect to Login Page + User->>+Dex: Send Credentials + Dex->>+Connector: Forward Credentials + Connector-->>-Dex: Return Identity + Dex->>+Dex: Persist Session + Dex -->>- User: Redirect to client + +``` + +Improved UX flow with active session (session found). + +```mermaid +sequenceDiagram + User->>+Client: Start Auth Flow + Client-->>-User: Redirect to dex + User-->>+Dex: Auth Flow + Dex->>+Dex: Check for Cookie and Session + Dex->>+Dex: Retrieve Session + Dex -->>- User: Redirect to client +``` + +### Risks and Mitigations + +- I am not absolutely sure whether this introduces any attack vectors that could be exploited. +- This DEP does not introduce any breaking changes. + +### Alternatives + +- None. We can declare this out of scope, but other than developing. + +## Future Improvements + +- None diff --git a/examples/ldap/config-ldap.yaml b/examples/ldap/config-ldap.yaml index 05d16618..f37fe11a 100644 --- a/examples/ldap/config-ldap.yaml +++ b/examples/ldap/config-ldap.yaml @@ -6,6 +6,11 @@ storage: web: http: 0.0.0.0:5556 +# Uncomment this section if you want to use the remember me / action sessions +# feature. For more details search for remember-me-2025-10-19-#32.md +#sessions: +# enable: true + connectors: - type: ldap name: OpenLDAP diff --git a/internal/jwt/keyset.go b/internal/jwt/keyset.go new file mode 100644 index 00000000..0d6ddcaa --- /dev/null +++ b/internal/jwt/keyset.go @@ -0,0 +1,56 @@ +package jwt + +import ( + "context" + "errors" + + "github.com/go-jose/go-jose/v4" + + "github.com/dexidp/dex/storage" +) + +var ErrFailedVerify = errors.New("failed to verify id token signature") + +// StorageKeySet implements the oidc.KeySet interface backed by Dex storage +type StorageKeySet struct { + storage.Storage +} + +func NewStorageKeySet(store storage.Storage) *StorageKeySet { + return &StorageKeySet{ + store, + } +} + +func (s *StorageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { + jws, err := jose.ParseSigned(jwt, []jose.SignatureAlgorithm{jose.RS256, jose.RS384, jose.RS512, jose.ES256, jose.ES384, jose.ES512}) + if err != nil { + return nil, err + } + + keyID := "" + for _, sig := range jws.Signatures { + keyID = sig.Header.KeyID + break + } + + skeys, err := s.Storage.GetKeys(ctx) + if err != nil { + return nil, err + } + + keys := []*jose.JSONWebKey{skeys.SigningKeyPub} + for _, vk := range skeys.VerificationKeys { + keys = append(keys, vk.PublicKey) + } + + for _, key := range keys { + if keyID == "" || key.KeyID == keyID { + if payload, err := jws.Verify(key); err == nil { + return payload, nil + } + } + } + + return nil, ErrFailedVerify +} diff --git a/internal/jwt/signature.go b/internal/jwt/signature.go new file mode 100644 index 00000000..ad6d2afd --- /dev/null +++ b/internal/jwt/signature.go @@ -0,0 +1,59 @@ +package jwt + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "errors" + "fmt" + + "github.com/go-jose/go-jose/v4" +) + +// Determine the signature algorithm for a JWT. +func SignatureAlgorithm(jwk *jose.JSONWebKey) (alg jose.SignatureAlgorithm, err error) { + if jwk.Key == nil { + return alg, errors.New("no signing key") + } + switch key := jwk.Key.(type) { + case *rsa.PrivateKey: + // Because OIDC mandates that we support RS256, we always return that + // value. In the future, we might want to make this configurable on a + // per client basis. For example allowing PS256 or ECDSA variants. + // + // See https://github.com/dexidp/dex/issues/692 + return jose.RS256, nil + case *ecdsa.PrivateKey: + // We don't actually support ECDSA keys yet, but they're tested for + // in case we want to in the future. + // + // These values are prescribed depending on the ECDSA key type. We + // can't return different values. + switch key.Params() { + case elliptic.P256().Params(): + return jose.ES256, nil + case elliptic.P384().Params(): + return jose.ES384, nil + case elliptic.P521().Params(): + return jose.ES512, nil + default: + return alg, errors.New("unsupported ecdsa curve") + } + default: + return alg, fmt.Errorf("unsupported signing key type %T", key) + } +} + +func SignPayload(key *jose.JSONWebKey, alg jose.SignatureAlgorithm, payload []byte) (jws string, err error) { + signingKey := jose.SigningKey{Key: key, Algorithm: alg} + + signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) + if err != nil { + return "", fmt.Errorf("new signer: %v", err) + } + signature, err := signer.Sign(payload) + if err != nil { + return "", fmt.Errorf("signing payload: %v", err) + } + return signature.CompactSerialize() +} diff --git a/internal/remember-me/handler.go b/internal/remember-me/handler.go new file mode 100644 index 00000000..f2247a4c --- /dev/null +++ b/internal/remember-me/handler.go @@ -0,0 +1,187 @@ +package rememberme + +import ( + "context" + "crypto/sha3" + "errors" + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/internal/jwt" + "github.com/dexidp/dex/storage" +) + +const ACTIVE_SESSION_COOKIE_NAME = "dex_active_session_cookie" + +var emptySession = storage.ActiveSession{} + +type AuthContext struct { + connectorName string + identity *connector.Identity + configuredExpiryDuration time.Duration +} + +func NewAnonymousAuthContext(connectorName string, configuredExpiryDuration time.Duration) AuthContext { + return AuthContext{connectorName, nil, configuredExpiryDuration} +} + +func NewAuthContextWithIdentity(connectorName string, identity connector.Identity, configuredExpiryDuration time.Duration) AuthContext { + return AuthContext{connectorName, &identity, configuredExpiryDuration} +} + +type GetOrUnsetCookie struct { + cookie *http.Cookie + unset bool +} + +func (c GetOrUnsetCookie) Empty() bool { + return c.unset == false && c.cookie == nil +} + +func (c GetOrUnsetCookie) Get() (*http.Cookie, bool) { + // TODO(juf): would prefer to not return internal pointer + return c.cookie, c.unset +} + +func RequestUnsetCookie(cookieName string) GetOrUnsetCookie { + return GetOrUnsetCookie{ + &http.Cookie{Name: cookieName, Path: "/", MaxAge: -1, Secure: true, HttpOnly: true, SameSite: http.SameSiteStrictMode}, true, + } +} + +func RequestSetCookie(cookie http.Cookie) GetOrUnsetCookie { + return GetOrUnsetCookie{ + &cookie, false, + } +} + +type RememberMeCtx struct { + Session storage.ActiveSession + Cookie GetOrUnsetCookie +} + +func (ctx RememberMeCtx) IsValid() bool { + return ctx.Session.Expiry.After(time.Now()) +} + +// connector_cookie_name creates a string which is used to identify the cookie that matches the given connector. +// The purpose is to avoid having one cookie for multiple providers where you only authenticate once and suddenly would have +// access to other connectors. +func connector_cookie_name(connName string) string { + return fmt.Sprintf("%s_%s", ACTIVE_SESSION_COOKIE_NAME, connName) +} + +// HandleRememberMe either retrieves or creates a Session based on the cookie for the respective connector present in the http.Request. +// It is also responsible for issuing the unsetting / expiration of either an invalid or expired cookie. +// +// The current "design" of the cookie is a sha3 hash of the connector.Identity object as JWK signed payload. +func HandleRememberMe(ctx context.Context, logger *slog.Logger, req *http.Request, data AuthContext, store storage.Storage, sessionStore storage.ActiveSessionStorage) (*RememberMeCtx, error) { + keys, err := store.GetKeys(ctx) + if err != nil { + logger.ErrorContext(req.Context(), "failed to get keys", "err", err) + return nil, err + } + signAlg, err := jwt.SignatureAlgorithm(keys.SigningKey) + if err != nil { + logger.ErrorContext(req.Context(), "failed to get signAlg", "err", err) + return nil, err + } + if val, found := extractCookie(req, data.connectorName); found { + cookieName := connector_cookie_name(data.connectorName) + logger.DebugContext(req.Context(), "returning user cookie found, checking for active session", "connectorName", data.connectorName) + keyset := jwt.NewStorageKeySet(store) + logger.DebugContext(req.Context(), "verifying cookie", "connectorName", data.connectorName) + _, err := keyset.VerifySignature(ctx, val) + if err != nil { + return &RememberMeCtx{ + Session: emptySession, + Cookie: RequestUnsetCookie(cookieName), + }, err + } + session, err := sessionStore.GetSession(ctx, val) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return &RememberMeCtx{ + Session: session, + Cookie: RequestUnsetCookie(cookieName), + }, nil + } + logger.ErrorContext(req.Context(), "failed to get active session", "err", err, "connectorName", data.connectorName) + return nil, err + } + cookie := GetOrUnsetCookie{nil, false} + if session.Expiry.Before(time.Now()) { + logger.DebugContext(req.Context(), "session expired unsetting cookie", "connectorName", data.connectorName) + cookie = RequestUnsetCookie(cookieName) + } + return &RememberMeCtx{ + Session: session, + Cookie: cookie, + }, nil + } else { + if data.identity == nil { + logger.DebugContext(req.Context(), "identity is empty, returning early", "connectorName", data.connectorName) + return nil, storage.ErrNotFound + } + h := sha3.New512() + h.Write([]byte(data.identity.Email)) + for _, g := range data.identity.Groups { + h.Write([]byte(g)) + } + h.Write([]byte(data.identity.UserID)) + h.Write([]byte(data.identity.Username)) + h.Write([]byte(data.identity.PreferredUsername)) + hash := fmt.Sprintf("%x", h.Sum(nil)) + signedHash, err := jwt.SignPayload(keys.SigningKey, signAlg, []byte(hash)) + if err != nil { + logger.ErrorContext(req.Context(), "failed to get sign payload", "err", err, "connectorName", data.connectorName) + return nil, err + } + // TODO(juf): Double check what we need to persist and are given + // in the context of whether we need to make an "auto-redirect" + // Because technically we do not return the ID nor RefreshToken to the user + // instead we redirect him back to the caller with an authCode + session := storage.ActiveSession{ + Identity: *data.identity, // TODO(juf): Avoid nil pointer + // TODO(juf): Think about changing to use Token IssuedAt date instead of now to have + // alignment with the token + Expiry: time.Now().Add(data.configuredExpiryDuration), + } + logger.DebugContext(req.Context(), "creating active session for user", "connectorName", data.connectorName) + if err := sessionStore.CreateSession(ctx, signedHash, session); err != nil { + logger.ErrorContext(req.Context(), "failed to store active session", "err", err, "connectorName", data.connectorName) + return nil, err + } + + return &RememberMeCtx{ + Session: session, + Cookie: RequestSetCookie(http.Cookie{ + Name: connector_cookie_name(data.connectorName), + Value: signedHash, + Path: "/", + Domain: "", // TODO(juf): Check if we need to set this + Expires: session.Expiry, + MaxAge: int(time.Until(session.Expiry).Seconds()), + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + }), + }, nil + } +} + +func extractCookie(req *http.Request, connName string) (value string, found bool) { + cookies := req.Cookies() + if len(cookies) > 0 { + for _, ck := range cookies { + if ck.Name != connector_cookie_name(connName) { + continue + } + return ck.Value, true + } + } + return "", false +} diff --git a/internal/remember-me/handler_test.go b/internal/remember-me/handler_test.go new file mode 100644 index 00000000..3cd3b863 --- /dev/null +++ b/internal/remember-me/handler_test.go @@ -0,0 +1,373 @@ +package rememberme + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "errors" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/require" + + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/memory" +) + +// setupTestEnvironment creates a realistic test environment following dex patterns +func setupTestEnvironment(t *testing.T) (storage.Storage, storage.ActiveSessionStorage, *slog.Logger) { + logger := slog.New(slog.NewTextHandler(nil, &slog.HandlerOptions{Level: slog.LevelError})) + + // Use in-memory storage + store := memory.New(logger) + sessionStore := memory.NewSessionStore(logger) + + // Initialize with real keys + ctx := context.Background() + err := store.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + signingKey := &jose.JSONWebKey{Key: key} + signingKeyPub := &jose.JSONWebKey{Key: &key.PublicKey} + + return storage.Keys{ + SigningKey: signingKey, + SigningKeyPub: signingKeyPub, + }, nil + }) + require.NoError(t, err) + + return store, sessionStore, logger +} + +// createTestRequest creates an HTTP request with optional cookie +func createTestRequest(connectorName string, cookieValue string) *http.Request { + req := httptest.NewRequest("GET", "/auth", nil) + if cookieValue != "" { + cookieName := connector_cookie_name(connectorName) + req.AddCookie(&http.Cookie{Name: cookieName, Value: cookieValue}) + } + return req +} + +// createTestIdentity creates a sample identity for testing +func createTestIdentity() connector.Identity { + return connector.Identity{ + UserID: "user123", + Username: "testuser", + PreferredUsername: "testuser", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"group1", "group2"}, + } +} + +func TestHandleRememberMe_Integration(t *testing.T) { + store, sessionStore, logger := setupTestEnvironment(t) + ctx := context.Background() + connectorName := "test-connector" + expiryDuration := 24 * time.Hour + identity := createTestIdentity() + + tests := []struct { + name string + setup func() (*http.Request, AuthContext) + want func(t *testing.T, result *RememberMeCtx, err error) + }{ + { + name: "no cookie with anonymous context returns ErrNotFound", + setup: func() (*http.Request, AuthContext) { + req := createTestRequest(connectorName, "") + authCtx := NewAnonymousAuthContext(connectorName, expiryDuration) + return req, authCtx + }, + want: func(t *testing.T, result *RememberMeCtx, err error) { + require.Error(t, err) + require.True(t, errors.Is(err, storage.ErrNotFound)) + require.Nil(t, result) + }, + }, + { + name: "no cookie with identity creates new session and sets cookie", + setup: func() (*http.Request, AuthContext) { + req := createTestRequest(connectorName, "") + authCtx := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) + return req, authCtx + }, + want: func(t *testing.T, result *RememberMeCtx, err error) { + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.IsValid()) + + // Verify session details + require.Equal(t, identity.UserID, result.Session.Identity.UserID) + require.Equal(t, identity.Email, result.Session.Identity.Email) + require.Equal(t, identity.Groups, result.Session.Identity.Groups) + require.True(t, result.Session.Expiry.After(time.Now())) + + // Verify cookie is set + require.False(t, result.Cookie.Empty()) + cookie, unset := result.Cookie.Get() + require.False(t, unset) + require.Equal(t, connector_cookie_name(connectorName), cookie.Name) + require.NotEmpty(t, cookie.Value) + require.True(t, cookie.Secure) + require.True(t, cookie.HttpOnly) + require.Equal(t, http.SameSiteStrictMode, cookie.SameSite) + + // Verify session was stored + storedSession, err := sessionStore.GetSession(ctx, cookie.Value) + require.NoError(t, err) + require.Equal(t, identity.UserID, storedSession.Identity.UserID) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, authCtx := tt.setup() + result, err := HandleRememberMe(ctx, logger, req, authCtx, store, sessionStore) + tt.want(t, result, err) + }) + } +} + +func TestHandleRememberMe_WithExistingSessions(t *testing.T) { + store, sessionStore, logger := setupTestEnvironment(t) + ctx := context.Background() + connectorName := "test-connector" + expiryDuration := 24 * time.Hour + identity := createTestIdentity() + + // First create a session to test retrieval + req := createTestRequest(connectorName, "") + authCtx := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) + initialResult, err := HandleRememberMe(ctx, logger, req, authCtx, store, sessionStore) + require.NoError(t, err) + require.NotNil(t, initialResult) + + cookie, _ := initialResult.Cookie.Get() + cookieValue := cookie.Value + + t.Run("valid cookie with active session returns session without new cookie", func(t *testing.T) { + req := createTestRequest(connectorName, cookieValue) + authCtx := NewAnonymousAuthContext(connectorName, expiryDuration) + + result, err := HandleRememberMe(ctx, logger, req, authCtx, store, sessionStore) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.IsValid()) + require.Equal(t, identity.UserID, result.Session.Identity.UserID) + require.True(t, result.Cookie.Empty()) // No cookie change needed + }) + + t.Run("expired session unsets cookie", func(t *testing.T) { + // Create a fresh session store to avoid ID conflicts + freshStore, freshSessionStore, freshLogger := setupTestEnvironment(t) + + expiredIdentity := connector.Identity{ + UserID: "expired-user", + Username: "expireduser", + Email: "expired@example.com", + Groups: []string{"expired-group"}, + } + + // Create an expired session directly with a known identifier + expiredSession := storage.ActiveSession{ + Identity: expiredIdentity, + Expiry: time.Now().Add(-time.Hour), + } + + // Create the session using a predictable signed identifier + // First get a signed identifier by creating a valid session + tempReq := createTestRequest(connectorName, "") + tempAuthCtx := NewAuthContextWithIdentity(connectorName, expiredIdentity, time.Hour) // short duration + tempResult, err := HandleRememberMe(ctx, freshLogger, tempReq, tempAuthCtx, freshStore, freshSessionStore) + require.NoError(t, err) + + tempCookie, _ := tempResult.Cookie.Get() + sessionID := tempCookie.Value + + // Wait a moment and directly update the session in storage to be expired + // by creating a new session store and directly setting expired session + testSessionStore := memory.NewSessionStore(freshLogger) + err = testSessionStore.CreateSession(ctx, sessionID, expiredSession) + require.NoError(t, err) + + // Test with the expired session + req := createTestRequest(connectorName, sessionID) + authCtx := NewAnonymousAuthContext(connectorName, expiryDuration) + result, err := HandleRememberMe(ctx, freshLogger, req, authCtx, freshStore, testSessionStore) + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsValid()) + + // Cookie should be unset + require.False(t, result.Cookie.Empty()) + resultCookie, unset := result.Cookie.Get() + require.True(t, unset) + require.Equal(t, -1, resultCookie.MaxAge) + }) + + t.Run("invalid cookie signature unsets cookie", func(t *testing.T) { + // Use an invalid JWT format + invalidCookie := "invalid.jwt.signature" + req := createTestRequest(connectorName, invalidCookie) + authCtx := NewAnonymousAuthContext(connectorName, expiryDuration) + + result, err := HandleRememberMe(ctx, logger, req, authCtx, store, sessionStore) + require.Error(t, err) + require.NotNil(t, result) + require.False(t, result.IsValid()) + + // Cookie should be unset + require.False(t, result.Cookie.Empty()) + cookie, unset := result.Cookie.Get() + require.True(t, unset) + require.Equal(t, connector_cookie_name(connectorName), cookie.Name) + }) +} + +func TestHandleRememberMe_EndToEndWorkflow(t *testing.T) { + store, sessionStore, logger := setupTestEnvironment(t) + ctx := context.Background() + connectorName := "test-connector" + expiryDuration := 24 * time.Hour + identity := createTestIdentity() + + t.Run("complete login workflow", func(t *testing.T) { + // Step 1: Initial login - no cookie present + req1 := createTestRequest(connectorName, "") + authCtx1 := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) + + result1, err := HandleRememberMe(ctx, logger, req1, authCtx1, store, sessionStore) + require.NoError(t, err) + require.True(t, result1.IsValid()) + + cookie1, unset1 := result1.Cookie.Get() + require.False(t, unset1) + require.NotEmpty(t, cookie1.Value) + + // Step 2: Return visit with cookie - should recognize user + req2 := createTestRequest(connectorName, cookie1.Value) + authCtx2 := NewAnonymousAuthContext(connectorName, expiryDuration) + + result2, err := HandleRememberMe(ctx, logger, req2, authCtx2, store, sessionStore) + require.NoError(t, err) + require.True(t, result2.IsValid()) + require.Equal(t, identity.UserID, result2.Session.Identity.UserID) + require.True(t, result2.Cookie.Empty()) // No cookie change needed + + // Step 3: Verify session consistency + require.Equal(t, result1.Session.Identity.UserID, result2.Session.Identity.UserID) + require.Equal(t, result1.Session.Identity.Email, result2.Session.Identity.Email) + }) + + t.Run("multiple connectors isolation", func(t *testing.T) { + connector1 := "connector-1" + connector2 := "connector-2" + + // Create session for connector1 + req1 := createTestRequest(connector1, "") + authCtx1 := NewAuthContextWithIdentity(connector1, identity, expiryDuration) + result1, err := HandleRememberMe(ctx, logger, req1, authCtx1, store, sessionStore) + require.NoError(t, err) + + cookie1, _ := result1.Cookie.Get() + + // Create a request for connector2 but with connector1's cookie + // This simulates having both cookies in the browser + req2 := httptest.NewRequest("GET", "/auth", nil) + req2.AddCookie(cookie1) // connector1's cookie + + // Since connector2 looks for its own cookie name, it won't find connector1's cookie + authCtx2 := NewAnonymousAuthContext(connector2, expiryDuration) + result2, err := HandleRememberMe(ctx, logger, req2, authCtx2, store, sessionStore) + require.Error(t, err) + require.True(t, errors.Is(err, storage.ErrNotFound)) + require.Nil(t, result2) + }) +} + +func TestHandleRememberMe_ErrorHandling(t *testing.T) { + store, sessionStore, logger := setupTestEnvironment(t) + ctx := context.Background() + connectorName := "test-connector" + expiryDuration := 24 * time.Hour + identity := createTestIdentity() + + t.Run("session not found after valid signature verification", func(t *testing.T) { + // Create a session first to get a valid signed cookie + req := createTestRequest(connectorName, "") + authCtx := NewAuthContextWithIdentity(connectorName, identity, expiryDuration) + result, err := HandleRememberMe(ctx, logger, req, authCtx, store, sessionStore) + require.NoError(t, err) + + cookie, _ := result.Cookie.Get() + + // Create a different session store that doesn't have this session + emptySessionStore := memory.NewSessionStore(logger) + + // Try to retrieve with valid cookie but empty session store + req2 := createTestRequest(connectorName, cookie.Value) + authCtx2 := NewAnonymousAuthContext(connectorName, expiryDuration) + result2, err := HandleRememberMe(ctx, logger, req2, authCtx2, store, emptySessionStore) + + require.NoError(t, err) + require.NotNil(t, result2) + require.False(t, result2.IsValid()) + + // Should unset cookie when session not found + require.False(t, result2.Cookie.Empty()) + unsetCookie, unset := result2.Cookie.Get() + require.True(t, unset) + require.Equal(t, connector_cookie_name(connectorName), unsetCookie.Name) + }) +} + +func TestExtractCookie(t *testing.T) { + connectorName := "test-connector" + + t.Run("cookie present", func(t *testing.T) { + req := httptest.NewRequest("GET", "/auth", nil) + req.AddCookie(&http.Cookie{Name: connector_cookie_name(connectorName), Value: "test-value"}) + req.AddCookie(&http.Cookie{Name: "other-cookie", Value: "ignored"}) + + value, found := extractCookie(req, connectorName) + require.True(t, found) + require.Equal(t, "test-value", value) + }) + + t.Run("cookie not present", func(t *testing.T) { + req := httptest.NewRequest("GET", "/auth", nil) + value, found := extractCookie(req, connectorName) + require.False(t, found) + require.Equal(t, "", value) + }) +} + +func TestConnectorCookieName(t *testing.T) { + tests := []struct { + connector string + expected string + }{ + {"test", "dex_active_session_cookie_test"}, + {"google", "dex_active_session_cookie_google"}, // just an example, google would not be use-case for this feature + {"ldap-local", "dex_active_session_cookie_ldap-local"}, + } + + for _, tt := range tests { + t.Run(tt.connector, func(t *testing.T) { + result := connector_cookie_name(tt.connector) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/server/handlers.go b/server/handlers.go index e46c7b8f..d2654b9e 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -7,6 +7,7 @@ import ( "crypto/subtle" "encoding/base64" "encoding/json" + "errors" "fmt" "html/template" "net/http" @@ -22,6 +23,8 @@ import ( "github.com/gorilla/mux" "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/internal/jwt" + rememberme "github.com/dexidp/dex/internal/remember-me" "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" ) @@ -321,6 +324,11 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { s.renderError(r, w, http.StatusBadRequest, "User session error.") return } + if s.enableRememberMe { + s.logger.InfoContext(r.Context(), "enableRememberMe enabled") + } else { + s.logger.InfoContext(r.Context(), "enableRememberMe disabled") + } backLink := r.URL.Query().Get("back") @@ -363,6 +371,54 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: + if s.enableRememberMe { + rememberData, err := rememberme.HandleRememberMe(ctx, s.logger, r, rememberme.NewAnonymousAuthContext(authReq.ConnectorID, s.idTokensValidFor), s.storage, s.sessionStorage) + if err != nil { + if !errors.Is(err, storage.ErrNotFound) { + if rememberData != nil && !rememberData.Cookie.Empty() { + // Overwrite or unset the cookie in certain error cases to allow for "natural" + // recovery, e.g., the cookie is present but malformatted, then it should be unset. + cookie, _ := rememberData.Cookie.Get() + http.SetCookie(w, cookie) + } + s.logger.ErrorContext(r.Context(), "failed to call HandleRememberMe handler", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "TODO") + return + } + } + if !errors.Is(err, storage.ErrNotFound) { + s.logger.DebugContext(r.Context(), "returning user session was found") + if !rememberData.Cookie.Empty() { + cookie, unset := rememberData.Cookie.Get() + if unset { + s.logger.DebugContext(r.Context(), "unsetting cookie", "cookie", *cookie) + http.SetCookie(w, cookie) + } + } + if rememberData.IsValid() { + redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, rememberData.Session.Identity, authReq, conn) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to finalize login using rememberme", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Login error.") + return + } + + if canSkipApproval { + authReq, err = s.storage.GetAuthRequest(ctx, authReq.ID) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to get finalized auth request using rememberme", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Login error.") + return + } + s.sendCodeResponse(w, r, authReq) + return + } + + http.Redirect(w, r, redirectURL, http.StatusSeeOther) + return + } + } + } if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(pwConn), false, backLink); err != nil { s.logger.ErrorContext(r.Context(), "server template error", "err", err) } @@ -390,6 +446,23 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { s.renderError(r, w, http.StatusInternalServerError, "Login error.") return } + if s.enableRememberMe { + rememberData, err := rememberme.HandleRememberMe(ctx, s.logger, r, rememberme.NewAuthContextWithIdentity(authReq.ConnectorID, identity, s.idTokensValidFor), s.storage, s.sessionStorage) + if err != nil { + if !errors.Is(err, storage.ErrNotFound) { + s.logger.ErrorContext(r.Context(), "failed to call HandleRememberMe handler", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "TODO") + return + } + } + if !errors.Is(err, storage.ErrNotFound) { + s.logger.ErrorContext(r.Context(), "did find returning user") + if !rememberData.Cookie.Empty() { + cookie, _ := rememberData.Cookie.Get() + http.SetCookie(w, cookie) // this sets or unsets the cookie based on it's content + } + } + } if canSkipApproval { authReq, err = s.storage.GetAuthRequest(ctx, authReq.ID) @@ -1099,7 +1172,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { } rawIDToken := auth[len(prefix):] - verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true}) + verifier := oidc.NewVerifier(s.issuerURL.String(), jwt.NewStorageKeySet(s.storage), &oidc.Config{SkipClientIDCheck: true}) idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { s.logger.ErrorContext(r.Context(), "failed to verify ID token", "err", err) diff --git a/server/introspectionhandler.go b/server/introspectionhandler.go index 4b0073db..32ca166d 100644 --- a/server/introspectionhandler.go +++ b/server/introspectionhandler.go @@ -9,6 +9,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" + "github.com/dexidp/dex/internal/jwt" "github.com/dexidp/dex/server/internal" ) @@ -245,7 +246,7 @@ func (s *Server) introspectRefreshToken(ctx context.Context, token string) (*Int } func (s *Server) introspectAccessToken(ctx context.Context, token string) (*Introspection, error) { - verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true}) + verifier := oidc.NewVerifier(s.issuerURL.String(), jwt.NewStorageKeySet(s.storage), &oidc.Config{SkipClientIDCheck: true}) idToken, err := verifier.Verify(ctx, token) if err != nil { return nil, newIntrospectInactiveTokenError() diff --git a/server/oauth2.go b/server/oauth2.go index 7268bcfd..7951bc68 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -17,6 +17,7 @@ import ( "net" "net/http" "net/url" + "slices" "strconv" "strings" "time" @@ -259,15 +260,6 @@ func accessTokenHash(alg jose.SignatureAlgorithm, accessToken string) (string, e type audience []string -func (a audience) contains(aud string) bool { - for _, e := range a { - if aud == e { - return true - } - } - return false -} - func (a audience) MarshalJSON() ([]byte, error) { if len(a) == 1 { return json.Marshal(a[0]) @@ -333,7 +325,7 @@ func getAudience(clientID string, scopes []string) audience { aud = audience{clientID} // Client asked for cross client audience: // if the current client was not requested explicitly - } else if !aud.contains(clientID) { + } else if !slices.Contains(aud, clientID) { // by default it becomes one of entries in Audience aud = append(aud, clientID) } @@ -704,41 +696,3 @@ func validateConnectorID(connectors []storage.Connector, connectorID string) boo } return false } - -// storageKeySet implements the oidc.KeySet interface backed by Dex storage -type storageKeySet struct { - storage.Storage -} - -func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { - jws, err := jose.ParseSigned(jwt, []jose.SignatureAlgorithm{jose.RS256, jose.RS384, jose.RS512, jose.ES256, jose.ES384, jose.ES512}) - if err != nil { - return nil, err - } - - keyID := "" - for _, sig := range jws.Signatures { - keyID = sig.Header.KeyID - break - } - - skeys, err := s.Storage.GetKeys(ctx) - if err != nil { - return nil, err - } - - keys := []*jose.JSONWebKey{skeys.SigningKeyPub} - for _, vk := range skeys.VerificationKeys { - keys = append(keys, vk.PublicKey) - } - - for _, key := range keys { - if keyID == "" || key.KeyID == keyID { - if payload, err := jws.Verify(key); err == nil { - return payload, nil - } - } - } - - return nil, errors.New("failed to verify id token signature") -} diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 3dff30d6..e9aecf4f 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -12,6 +12,7 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/stretchr/testify/require" + jwt_lib "github.com/dexidp/dex/internal/jwt" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/memory" ) @@ -668,7 +669,7 @@ func TestStorageKeySet(t *testing.T) { t.Fatal(err) } - keySet := &storageKeySet{s} + keySet := jwt_lib.NewStorageKeySet(s) _, err = keySet.VerifySignature(t.Context(), jwt) if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) { diff --git a/server/server.go b/server/server.go index d81a0f71..ac3f5ba5 100644 --- a/server/server.go +++ b/server/server.go @@ -46,6 +46,7 @@ import ( "github.com/dexidp/dex/connector/openshift" "github.com/dexidp/dex/connector/saml" "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/memory" "github.com/dexidp/dex/web" ) @@ -65,6 +66,9 @@ type Connector struct { type Config struct { Issuer string + // Use cookies and keep active sessions in storage + EnableRememberMe bool + // The backing persistence layer. Storage storage.Storage @@ -172,7 +176,8 @@ type Server struct { // Map of connector IDs to connectors. connectors map[string]Connector - storage storage.Storage + storage storage.Storage + sessionStorage storage.ActiveSessionStorage mux http.Handler @@ -181,6 +186,8 @@ type Server struct { // If enabled, don't prompt user for approval after logging in through connector. skipApproval bool + enableRememberMe bool + // If enabled, show the connector selection screen even if there's only one alwaysShowLogin bool @@ -316,6 +323,10 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) templates: tmpls, passwordConnector: c.PasswordConnector, logger: c.Logger, + enableRememberMe: c.EnableRememberMe, + } + if c.EnableRememberMe { + s.sessionStorage = memory.NewSessionStore(s.logger) } // Retrieves connector objects in backend storage. This list includes the static connectors @@ -651,6 +662,14 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura "requests", r.AuthRequests, "auth_codes", r.AuthCodes, "device_requests", r.DeviceRequests, "device_tokens", r.DeviceTokens) } + if s.sessionStorage != nil { + if r, err := s.sessionStorage.GarbageCollect(ctx, now()); err != nil { + s.logger.ErrorContext(ctx, "garbage collection for session storage failed", "err", err) + } else if !r.IsEmpty() { + s.logger.InfoContext(ctx, "garbage collection for session storage run", + "sessions", r.Sessions) + } + } } } }() diff --git a/storage/conformance/gen_jwks.go b/storage/conformance/gen_jwks.go index 0029b9b8..b5affcc8 100644 --- a/storage/conformance/gen_jwks.go +++ b/storage/conformance/gen_jwks.go @@ -105,7 +105,7 @@ func main() { if err != nil { log.Fatalf("gofmt failed: %v", err) } - if err := os.WriteFile("jwks.go", out, 0644); err != nil { + if err := os.WriteFile("jwks.go", out, 0o644); err != nil { log.Fatal(err) } } diff --git a/storage/memory/memory.go b/storage/memory/memory.go index eff75e71..9eb7ea25 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -11,7 +11,10 @@ import ( "github.com/dexidp/dex/storage" ) -var _ storage.Storage = (*memStorage)(nil) +var ( + _ storage.Storage = (*memStorage)(nil) + _ storage.ActiveSessionStorage = (*memStorage)(nil) +) // New returns an in memory storage. func New(logger *slog.Logger) storage.Storage { @@ -25,10 +28,18 @@ func New(logger *slog.Logger) storage.Storage { connectors: make(map[string]storage.Connector), deviceRequests: make(map[string]storage.DeviceRequest), deviceTokens: make(map[string]storage.DeviceToken), + sessions: make(map[string]storage.ActiveSession), logger: logger, } } +func NewSessionStore(logger *slog.Logger) *memStorage { + return &memStorage{ + sessions: make(map[string]storage.ActiveSession), + logger: logger, + } +} + // Config is an implementation of a storage configuration. // // TODO(ericchiang): Actually define a storage config interface and have registration. @@ -52,12 +63,37 @@ type memStorage struct { connectors map[string]storage.Connector deviceRequests map[string]storage.DeviceRequest deviceTokens map[string]storage.DeviceToken + sessions map[string]storage.ActiveSession keys storage.Keys logger *slog.Logger } +// CreateSession implements storage.ActiveSessionStorage. +func (s *memStorage) CreateSession(ctx context.Context, identifier string, data storage.ActiveSession) (err error) { + s.tx(func() { + if _, ok := s.sessions[identifier]; ok { + err = storage.ErrAlreadyExists + } else { + s.sessions[identifier] = data + } + }) + return +} + +// GetSession implements storage.ActiveSessionStorage. +func (s *memStorage) GetSession(ctx context.Context, identifier string) (session storage.ActiveSession, err error) { + s.tx(func() { + var ok bool + if session, ok = s.sessions[identifier]; !ok { + err = storage.ErrNotFound + return + } + }) + return +} + type offlineSessionID struct { userID string connID string @@ -97,6 +133,12 @@ func (s *memStorage) GarbageCollect(ctx context.Context, now time.Time) (result result.DeviceTokens++ } } + for id, a := range s.sessions { + if now.After(a.Expiry) { + delete(s.sessions, id) + result.Sessions++ + } + } }) return result, nil } diff --git a/storage/storage.go b/storage/storage.go index 79a2fca3..d547870f 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -12,6 +12,8 @@ import ( "time" "github.com/go-jose/go-jose/v4" + + "github.com/dexidp/dex/connector" ) var ( @@ -60,6 +62,7 @@ type GCResult struct { AuthCodes int64 DeviceRequests int64 DeviceTokens int64 + Sessions int64 } // IsEmpty returns whether the garbage collection result is empty or not. @@ -67,7 +70,8 @@ func (g *GCResult) IsEmpty() bool { return g.AuthRequests == 0 && g.AuthCodes == 0 && g.DeviceRequests == 0 && - g.DeviceTokens == 0 + g.DeviceTokens == 0 && + g.Sessions == 0 } // Storage is the storage interface used by the server. Implementations are @@ -142,6 +146,19 @@ type Storage interface { GarbageCollect(ctx context.Context, now time.Time) (GCResult, error) } +type ActiveSessionStorage interface { + GetSession(ctx context.Context, identifier string) (ActiveSession, error) + CreateSession(ctx context.Context, identifier string, data ActiveSession) error + GarbageCollect(ctx context.Context, now time.Time) (GCResult, error) +} + +type ActiveSession struct { + // TODO(juf): Think about storing only claim/identity data or reference to OfflineSession + // and create a new token every time instead + Expiry time.Time + Identity connector.Identity +} + // Client represents an OAuth2 client. // // For further reading see: