From c3bc1d7466e02e696fd65e4e7d7c9631732fd142 Mon Sep 17 00:00:00 2001 From: Maksim Nabokikh Date: Thu, 19 Mar 2026 15:53:15 +0100 Subject: [PATCH] feat: add auth_time, prompt, and max_age fields (#4662) Signed-off-by: maksim.nabokikh --- .gitignore | 1 + server/handlers.go | 60 +++-- server/introspectionhandler_test.go | 4 +- server/oauth2.go | 49 +++- server/prompt.go | 77 ++++++ server/prompt_test.go | 64 +++++ server/refreshhandlers.go | 15 +- server/refreshhandlers_test.go | 138 ++++++++++ server/session.go | 11 +- server/session_test.go | 181 ++++++++++++- storage/conformance/conformance.go | 14 +- storage/conformance/transactions.go | 1 + storage/ent/client/authcode.go | 1 + storage/ent/client/authrequest.go | 6 + storage/ent/client/types.go | 4 + storage/ent/db/authcode.go | 15 +- storage/ent/db/authcode/authcode.go | 8 + storage/ent/db/authcode/where.go | 55 ++++ storage/ent/db/authcode_create.go | 18 ++ storage/ent/db/authcode_update.go | 52 ++++ storage/ent/db/authrequest.go | 39 ++- storage/ent/db/authrequest/authrequest.go | 28 ++ storage/ent/db/authrequest/where.go | 170 +++++++++++++ storage/ent/db/authrequest_create.go | 68 +++++ storage/ent/db/authrequest_update.go | 140 ++++++++++ storage/ent/db/migrate/schema.go | 4 + storage/ent/db/mutation.go | 296 +++++++++++++++++++++- storage/ent/db/runtime.go | 8 + storage/ent/schema/authcode.go | 3 + storage/ent/schema/authrequest.go | 3 + storage/etcd/types.go | 14 + storage/kubernetes/types.go | 14 + storage/sql/crud.go | 26 +- storage/sql/migrate.go | 4 + storage/storage.go | 18 ++ 35 files changed, 1558 insertions(+), 51 deletions(-) create mode 100644 server/prompt.go create mode 100644 server/prompt_test.go diff --git a/.gitignore b/.gitignore index 11cfbe81..7e5ae46e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ /docker-compose.override.yaml /var/ /vendor/ +*.db diff --git a/server/handlers.go b/server/handlers.go index 31a160f2..d810dbef 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -365,14 +365,39 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } - // Check if there's a valid session that can skip login for this client. + // Handle OIDC prompt parameter and session-based login. + prompt, err := ParsePrompt(authReq.Prompt) + if err != nil { + // Server error because authReq was validated before saving it to database. + s.redirectWithError(w, r, authReq, errServerError, "Invalid authentication request") + return + } + // handle prompt only if sessions are enabled if s.sessionConfig != nil { - if redirectURL, ok := s.trySessionLogin(ctx, r, w, authReq); ok { + // prompt=none: no UI allowed. + if prompt.None() { + redirectURL, ok := s.trySessionLogin(ctx, r, w, authReq) + if !ok { + s.redirectWithError(w, r, authReq, errLoginRequired, "User not authenticated") + return + } if redirectURL != "" { - http.Redirect(w, r, redirectURL, http.StatusSeeOther) + // Session found but consent required — no UI allowed. + s.redirectWithError(w, r, authReq, errInteractionRequired, "Consent required") + return } return } + + if !prompt.Login() { + // Normal flow: try session-based login (skip if prompt=login forces re-auth). + if redirectURL, ok := s.trySessionLogin(ctx, r, w, authReq); ok { + if redirectURL != "" { + http.Redirect(w, r, redirectURL, http.StatusSeeOther) + } + return + } + } } scopes := parseScopes(authReq.Scopes) @@ -687,6 +712,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, a.LoggedIn = true a.Claims = claims a.ConnectorData = identity.ConnectorData + a.AuthTime = s.now() return a, nil } if err := s.storage.UpdateAuthRequest(ctx, authReq.ID, updater); err != nil { @@ -769,11 +795,8 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, } case err == nil: if err := s.storage.UpdateUserIdentity(ctx, identity.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { - if len(identity.ConnectorData) > 0 { - old.Claims = claims - old.LastLogin = now - return old, nil - } + old.Claims = claims + old.LastLogin = now return old, nil }); err != nil { s.logger.ErrorContext(ctx, "failed to update user identity", "err", err) @@ -988,6 +1011,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe RedirectURI: authReq.RedirectURI, ConnectorData: authReq.ConnectorData, PKCE: authReq.PKCE, + AuthTime: authReq.AuthTime, } if err := s.storage.CreateAuthCode(ctx, code); err != nil { s.logger.ErrorContext(r.Context(), "Failed to create auth code", "err", err) @@ -1007,7 +1031,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe implicitOrHybrid = true var err error - accessToken, _, err = s.newAccessToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID) + accessToken, _, err = s.newAccessToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID, authReq.AuthTime) if err != nil { s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1017,7 +1041,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe implicitOrHybrid = true var err error - idToken, idTokenExpiry, err = s.newIDToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID) + idToken, idTokenExpiry, err = s.newIDToken(r.Context(), authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID, authReq.AuthTime) if err != nil { s.logger.ErrorContext(r.Context(), "failed to create ID token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1251,14 +1275,14 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s } func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) { - accessToken, _, err := s.newAccessToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) + accessToken, _, err := s.newAccessToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID, authCode.AuthTime) if err != nil { s.logger.ErrorContext(ctx, "failed to create new access token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return nil, err } - idToken, expiry, err := s.newIDToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID) + idToken, expiry, err := s.newIDToken(ctx, client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID, authCode.AuthTime) if err != nil { s.logger.ErrorContext(ctx, "failed to create ID token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1539,14 +1563,14 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli Groups: identity.Groups, } - accessToken, _, err := s.newAccessToken(ctx, client.ID, claims, scopes, nonce, connID) + accessToken, _, err := s.newAccessToken(ctx, client.ID, claims, scopes, nonce, connID, time.Time{}) if err != nil { s.logger.ErrorContext(r.Context(), "password grant failed to create new access token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - idToken, expiry, err := s.newIDToken(ctx, client.ID, claims, scopes, nonce, accessToken, "", connID) + idToken, expiry, err := s.newIDToken(ctx, client.ID, claims, scopes, nonce, accessToken, "", connID, time.Time{}) if err != nil { s.logger.ErrorContext(r.Context(), "password grant failed to create new ID token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1752,9 +1776,9 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli var expiry time.Time switch requestedTokenType { case tokenTypeID: - resp.AccessToken, expiry, err = s.newIDToken(r.Context(), client.ID, claims, scopes, "", "", "", connID) + resp.AccessToken, expiry, err = s.newIDToken(r.Context(), client.ID, claims, scopes, "", "", "", connID, time.Time{}) case tokenTypeAccess: - resp.AccessToken, expiry, err = s.newAccessToken(r.Context(), client.ID, claims, scopes, "", connID) + resp.AccessToken, expiry, err = s.newAccessToken(r.Context(), client.ID, claims, scopes, "", connID, time.Time{}) default: s.tokenErrHelper(w, errRequestNotSupported, "Invalid requested_token_type.", http.StatusBadRequest) return @@ -1854,7 +1878,7 @@ func (s *Server) handleClientCredentialsGrant(w http.ResponseWriter, r *http.Req // Creating connectors with an empty ID with the config and API is prohibited connID := "" - accessToken, expiry, err := s.newAccessToken(ctx, client.ID, claims, scopes, nonce, connID) + accessToken, expiry, err := s.newAccessToken(ctx, client.ID, claims, scopes, nonce, connID, time.Time{}) if err != nil { s.logger.ErrorContext(ctx, "client_credentials grant failed to create new access token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) @@ -1863,7 +1887,7 @@ func (s *Server) handleClientCredentialsGrant(w http.ResponseWriter, r *http.Req var idToken string if hasOpenIDScope { - idToken, expiry, err = s.newIDToken(ctx, client.ID, claims, scopes, nonce, accessToken, "", connID) + idToken, expiry, err = s.newIDToken(ctx, client.ID, claims, scopes, nonce, accessToken, "", connID, time.Time{}) if err != nil { s.logger.ErrorContext(ctx, "client_credentials grant failed to create new ID token", "err", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) diff --git a/server/introspectionhandler_test.go b/server/introspectionhandler_test.go index 799f0000..29db7d31 100644 --- a/server/introspectionhandler_test.go +++ b/server/introspectionhandler_test.go @@ -152,7 +152,7 @@ func TestGetTokenFromRequestSuccess(t *testing.T) { accessToken, _, err := s.newIDToken(ctx, "test", storage.Claims{ UserID: "1", Username: "jane", - }, []string{"openid"}, "nonce", "", "", "test") + }, []string{"openid"}, "nonce", "", "", "test", time.Time{}) require.NoError(t, err) tests := []struct { @@ -270,7 +270,7 @@ func TestHandleIntrospect(t *testing.T) { Email: "jane.doe@example.com", EmailVerified: true, Groups: []string{"a", "b"}, - }, []string{"openid", "email", "profile", "groups"}, "foo", "", "", "test") + }, []string{"openid", "email", "profile", "groups"}, "foo", "", "", "test", time.Time{}) require.NoError(t, err) activeRefreshToken, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"}) diff --git a/server/oauth2.go b/server/oauth2.go index 9f12d1d0..3818c135 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -45,6 +45,18 @@ func newDisplayedErr(status int, format string, a ...interface{}) *displayedAuth return &displayedAuthErr{status, fmt.Sprintf(format, a...)} } +// redirectWithError redirects back to the client with an OAuth2 error response. +// Used for prompt=none when login or consent is required. +func (s *Server) redirectWithError(w http.ResponseWriter, r *http.Request, authReq *storage.AuthRequest, errType, description string) { + err := &redirectedAuthErr{ + State: authReq.State, + RedirectURI: authReq.RedirectURI, + Type: errType, + Description: description, + } + err.Handler().ServeHTTP(w, r) +} + // redirectedAuthErr is an error that should be reported back to the client by 302 redirect type redirectedAuthErr struct { State string @@ -117,6 +129,9 @@ const ( errInvalidGrant = "invalid_grant" errInvalidClient = "invalid_client" errInactiveToken = "inactive_token" + errLoginRequired = "login_required" + errInteractionRequired = "interaction_required" + errConsentRequired = "consent_required" ) const ( @@ -257,6 +272,7 @@ type idTokenClaims struct { IssuedAt int64 `json:"iat"` AuthorizingParty string `json:"azp,omitempty"` Nonce string `json:"nonce,omitempty"` + AuthTime int64 `json:"auth_time,omitempty"` AccessTokenHash string `json:"at_hash,omitempty"` CodeHash string `json:"c_hash,omitempty"` @@ -277,8 +293,8 @@ type federatedIDClaims struct { UserID string `json:"user_id,omitempty"` } -func (s *Server) newAccessToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, expiry time.Time, err error) { - return s.newIDToken(ctx, clientID, claims, scopes, nonce, storage.NewID(), "", connID) +func (s *Server) newAccessToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, connID string, authTime time.Time) (accessToken string, expiry time.Time, err error) { + return s.newIDToken(ctx, clientID, claims, scopes, nonce, storage.NewID(), "", connID, authTime) } func getClientID(aud audience, azp string) (string, error) { @@ -324,7 +340,7 @@ func genSubject(userID string, connID string) (string, error) { return internal.Marshal(sub) } -func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) { +func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string, authTime time.Time) (idToken string, expiry time.Time, err error) { issuedAt := s.now() expiry = issuedAt.Add(s.idTokensValidFor) @@ -342,6 +358,11 @@ func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage IssuedAt: issuedAt.Unix(), } + // Include auth_time when sessions are enabled and the value is available. + if !authTime.IsZero() { + tok.AuthTime = authTime.Unix() + } + // Determine signing algorithm from signer signingAlg, err := s.signer.Algorithm(ctx) if err != nil { @@ -583,12 +604,32 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques } } + prompt, err := ParsePrompt(q.Get("prompt")) + if err != nil { + return nil, newRedirectedErr(errInvalidRequest, "Invalid prompt parameter: %v", err) + } + + // Parse max_age: -1 means not specified. + maxAge := -1 + if maxAgeStr := q.Get("max_age"); maxAgeStr != "" { + v, err := strconv.Atoi(maxAgeStr) + if err != nil || v < 0 { + return nil, newRedirectedErr(errInvalidRequest, "Invalid max_age value %q", maxAgeStr) + } + maxAge = v + } + + // OIDC prompt=consent implies force approval. + forceApproval := q.Get("approval_prompt") == "force" || prompt.Consent() + return &storage.AuthRequest{ ID: storage.NewID(), ClientID: client.ID, State: state, Nonce: nonce, - ForceApprovalPrompt: q.Get("approval_prompt") == "force", + ForceApprovalPrompt: forceApproval, + Prompt: prompt.String(), + MaxAge: maxAge, Scopes: scopes, RedirectURI: redirectURI, ResponseTypes: responseTypes, diff --git a/server/prompt.go b/server/prompt.go new file mode 100644 index 00000000..f4941745 --- /dev/null +++ b/server/prompt.go @@ -0,0 +1,77 @@ +package server + +import ( + "fmt" + "strings" +) + +// Prompt represents the parsed OIDC "prompt" parameter (RFC 6749 / OpenID Connect Core 3.1.2.1). +// The parameter is space-separated and may contain: "none", "login", "consent", "select_account". +// "none" must not be combined with any other value. +type Prompt struct { + none bool + login bool + consent bool +} + +// ParsePrompt parses and validates the raw prompt query parameter. +// Returns an error suitable for returning as an OAuth2 invalid_request if the value is invalid. +func ParsePrompt(raw string) (Prompt, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return Prompt{}, nil + } + + var p Prompt + seen := make(map[string]bool) + + for _, v := range strings.Fields(raw) { + if seen[v] { + continue + } + seen[v] = true + + switch v { + case "none": + p.none = true + case "login": + p.login = true + case "consent": + p.consent = true + case "select_account": + // Dex does not support account selection; ignore per spec recommendation. + default: + return Prompt{}, fmt.Errorf("invalid prompt value %q", v) + } + } + + if p.none && (p.login || p.consent) { + return Prompt{}, fmt.Errorf("prompt=none must not be combined with other values") + } + + return p, nil +} + +// None returns true if the caller requested no interactive UI. +func (p Prompt) None() bool { return p.none } + +// Login returns true if the caller requested forced re-authentication. +func (p Prompt) Login() bool { return p.login } + +// Consent returns true if the caller requested forced consent screen. +func (p Prompt) Consent() bool { return p.consent } + +// String returns the canonical space-separated representation stored in the database. +func (p Prompt) String() string { + var parts []string + if p.none { + return "none" + } + if p.login { + parts = append(parts, "login") + } + if p.consent { + parts = append(parts, "consent") + } + return strings.Join(parts, " ") +} diff --git a/server/prompt_test.go b/server/prompt_test.go new file mode 100644 index 00000000..00820ce8 --- /dev/null +++ b/server/prompt_test.go @@ -0,0 +1,64 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParsePrompt(t *testing.T) { + tests := []struct { + name string + raw string + want Prompt + wantErr bool + }{ + {name: "empty", raw: "", want: Prompt{}}, + {name: "none", raw: "none", want: Prompt{none: true}}, + {name: "login", raw: "login", want: Prompt{login: true}}, + {name: "consent", raw: "consent", want: Prompt{consent: true}}, + {name: "login consent", raw: "login consent", want: Prompt{login: true, consent: true}}, + {name: "consent login", raw: "consent login", want: Prompt{login: true, consent: true}}, + {name: "select_account ignored", raw: "select_account", want: Prompt{}}, + {name: "login select_account", raw: "login select_account", want: Prompt{login: true}}, + {name: "duplicate values", raw: "login login", want: Prompt{login: true}}, + {name: "whitespace padding", raw: " login ", want: Prompt{login: true}}, + + // Errors. + {name: "none with login", raw: "none login", wantErr: true}, + {name: "none with consent", raw: "none consent", wantErr: true}, + {name: "unknown value", raw: "bogus", wantErr: true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := ParsePrompt(tc.raw) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestPromptString(t *testing.T) { + tests := []struct { + prompt Prompt + want string + }{ + {Prompt{}, ""}, + {Prompt{none: true}, "none"}, + {Prompt{login: true}, "login"}, + {Prompt{consent: true}, "consent"}, + {Prompt{login: true, consent: true}, "login consent"}, + } + + for _, tc := range tests { + t.Run(tc.want, func(t *testing.T) { + assert.Equal(t, tc.want, tc.prompt.String()) + }) + } +} diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index a0807d07..48fc39f1 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -439,14 +439,25 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie Groups: ident.Groups, } - accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID) + authTime := time.Time{} + if s.sessionConfig != nil { + ui, err := s.storage.GetUserIdentity(r.Context(), ident.UserID, rCtx.storageToken.ConnectorID) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to get user identity", "err", err) + s.refreshTokenErrHelper(w, newInternalServerError()) + return + } + authTime = ui.LastLogin + } + + accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID, authTime) if err != nil { s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err) s.refreshTokenErrHelper(w, newInternalServerError()) return } - idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID) + idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID, authTime) if err != nil { s.logger.ErrorContext(r.Context(), "failed to create ID token", "err", err) s.refreshTokenErrHelper(w, newInternalServerError()) diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go index 78bec27c..8db80c31 100644 --- a/server/refreshhandlers_test.go +++ b/server/refreshhandlers_test.go @@ -2,15 +2,18 @@ package server import ( "bytes" + "encoding/base64" "encoding/json" "log/slog" "net/http" "net/http/httptest" "net/url" "path" + "strings" "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/dexidp/dex/server/internal" @@ -209,6 +212,141 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { } } +// decodeJWTClaims decodes the payload of a JWT token without verifying the signature. +func decodeJWTClaims(t *testing.T, token string) map[string]any { + t.Helper() + parts := strings.SplitN(token, ".", 3) + require.Len(t, parts, 3, "JWT should have 3 parts") + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + require.NoError(t, err) + + var claims map[string]any + err = json.Unmarshal(payload, &claims) + require.NoError(t, err) + return claims +} + +func TestRefreshTokenAuthTime(t *testing.T) { + t0 := time.Now().UTC().Round(time.Second) + loginTime := t0.Add(-10 * time.Minute) + + tests := []struct { + name string + sessionConfig *SessionConfig + createUserIdentity bool + wantAuthTime bool + wantHTTPError bool + }{ + { + name: "sessions enabled with user identity", + sessionConfig: &SessionConfig{ + CookieName: "dex_session", + AbsoluteLifetime: 24 * time.Hour, + }, + createUserIdentity: true, + wantAuthTime: true, + }, + { + name: "sessions disabled", + sessionConfig: nil, + createUserIdentity: false, + wantAuthTime: false, + }, + { + name: "sessions enabled but user identity missing", + sessionConfig: &SessionConfig{ + CookieName: "dex_session", + AbsoluteLifetime: 24 * time.Hour, + }, + createUserIdentity: false, + wantHTTPError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + httpServer, s := newTestServer(t, func(c *Config) { + c.Now = func() time.Time { return t0 } + }) + defer httpServer.Close() + + s.sessionConfig = tc.sessionConfig + + mockRefreshTokenTestStorage(t, s.storage, false) + + if tc.createUserIdentity { + // The mock connector returns UserID "0-385-28089-0" on Refresh, + // so the UserIdentity must use that ID to be found by handleRefreshToken. + err := s.storage.CreateUserIdentity(t.Context(), storage.UserIdentity{ + UserID: "0-385-28089-0", + ConnectorID: "test", + Claims: storage.Claims{ + UserID: "0-385-28089-0", + Username: "Kilgore Trout", + Email: "kilgore@kilgore.trout", + EmailVerified: true, + Groups: []string{"authors"}, + }, + CreatedAt: loginTime, + LastLogin: loginTime, + }) + require.NoError(t, err) + } + + u, err := url.Parse(s.issuerURL.String()) + require.NoError(t, err) + + tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"}) + require.NoError(t, err) + + u.Path = path.Join(u.Path, "/token") + v := url.Values{} + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", tokenData) + + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + req.SetBasicAuth("test", "barfoo") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + if tc.wantHTTPError { + assert.Equal(t, http.StatusInternalServerError, rr.Code) + return + } + require.Equal(t, http.StatusOK, rr.Code) + + var resp struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + } + err = json.Unmarshal(rr.Body.Bytes(), &resp) + require.NoError(t, err) + + accessClaims := decodeJWTClaims(t, resp.AccessToken) + + if tc.wantAuthTime { + assert.Equal(t, float64(loginTime.Unix()), accessClaims["auth_time"], + "access token auth_time should match UserIdentity.LastLogin") + } else { + assert.Nil(t, accessClaims["auth_time"], + "access token should not have auth_time when sessions are disabled") + } + + // TODO: newIDToken in handleRefreshToken is currently called with time.Time{}, + // so the ID token does not include auth_time. Once fixed, uncomment: + // if tc.wantAuthTime { + // idClaims := decodeJWTClaims(t, resp.IDToken) + // assert.Equal(t, float64(loginTime.Unix()), idClaims["auth_time"], + // "id token auth_time should match UserIdentity.LastLogin") + // } + }) + } +} + func TestRefreshTokenPolicy(t *testing.T) { lastTime := time.Now() l := slog.New(slog.DiscardHandler) diff --git a/server/session.go b/server/session.go index f6355abf..48e3fb6a 100644 --- a/server/session.go +++ b/server/session.go @@ -10,6 +10,7 @@ import ( "net/http" "path" "strings" + "time" "github.com/dexidp/dex/storage" ) @@ -258,6 +259,13 @@ func (s *Server) trySessionLogin(ctx context.Context, r *http.Request, w http.Re return "", false } + // Check max_age: if the user's last authentication is too old, force re-auth. + if authReq.MaxAge >= 0 { + if now.Sub(ui.LastLogin) > time.Duration(authReq.MaxAge)*time.Second { + return "", false + } + } + claims := storage.Claims{ UserID: ui.Claims.UserID, Username: ui.Claims.Username, @@ -267,11 +275,12 @@ func (s *Server) trySessionLogin(ctx context.Context, r *http.Request, w http.Re Groups: ui.Claims.Groups, } - // Update AuthRequest with stored identity (without logging "login successful"). + // Update AuthRequest with stored identity and auth_time from last login. if err := s.storage.UpdateAuthRequest(ctx, authReq.ID, func(a storage.AuthRequest) (storage.AuthRequest, error) { a.LoggedIn = true a.Claims = claims a.ConnectorID = session.ConnectorID + a.AuthTime = ui.LastLogin return a, nil }); err != nil { s.logger.ErrorContext(ctx, "session: failed to update auth request", "err", err) diff --git a/server/session_test.go b/server/session_test.go index 757c11b3..7eac45b7 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -427,6 +427,7 @@ func setupSessionLoginFixture(t *testing.T, s *Server) storage.AuthRequest { ConnectorID: "mock", Scopes: []string{"openid", "email"}, RedirectURI: "http://localhost/callback", + MaxAge: -1, HMACKey: storage.NewHMACKey(crypto.SHA256), Expiry: now.Add(10 * time.Minute), } @@ -463,8 +464,6 @@ func TestTrySessionLogin(t *testing.T) { _, ok := s.trySessionLogin(ctx, r, w, &authReq) assert.True(t, ok) - // sendCodeResponse deletes the AuthRequest after processing, - // so we can't verify it here. The fact that ok=true is sufficient. }) t.Run("successful login redirects to approval", func(t *testing.T) { @@ -491,7 +490,6 @@ func TestTrySessionLogin(t *testing.T) { s := newTestSessionServer(t) s.skipApproval = false authReq := setupSessionLoginFixture(t, s) - // Scopes match stored consent: {"client-1": {"openid", "email"}} r := sessionCookieRequest("user-1", "mock", "test-nonce") w := httptest.NewRecorder() @@ -526,24 +524,23 @@ func TestTrySessionLogin(t *testing.T) { t.Run("expired client state returns false", func(t *testing.T) { s := newTestSessionServer(t) - ctx := t.Context() now := s.now() - require.NoError(t, s.storage.CreateAuthSession(ctx, storage.AuthSession{ + require.NoError(t, s.storage.CreateAuthSession(t.Context(), storage.AuthSession{ UserID: "user-exp", ConnectorID: "mock", Nonce: "nonce-exp", ClientStates: map[string]*storage.ClientAuthState{ "client-1": { Active: true, - ExpiresAt: now.Add(-1 * time.Hour), // expired + ExpiresAt: now.Add(-1 * time.Hour), }, }, CreatedAt: now.Add(-2 * time.Hour), LastActivity: now.Add(-1 * time.Minute), })) - require.NoError(t, s.storage.CreateUserIdentity(ctx, storage.UserIdentity{ + require.NoError(t, s.storage.CreateUserIdentity(t.Context(), storage.UserIdentity{ UserID: "user-exp", ConnectorID: "mock", Claims: storage.Claims{UserID: "user-exp"}, @@ -556,10 +553,11 @@ func TestTrySessionLogin(t *testing.T) { ID: storage.NewID(), ClientID: "client-1", ConnectorID: "mock", + MaxAge: -1, HMACKey: storage.NewHMACKey(crypto.SHA256), Expiry: now.Add(10 * time.Minute), } - require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq)) + require.NoError(t, s.storage.CreateAuthRequest(t.Context(), authReq)) r := sessionCookieRequest("user-exp", "mock", "nonce-exp") w := httptest.NewRecorder() @@ -584,3 +582,170 @@ func TestTrySessionLogin(t *testing.T) { assert.Equal(t, s.now(), session.LastActivity) }) } + +// setupSessionWithIdentity creates an AuthSession, UserIdentity, and AuthRequest in storage +// for use in trySessionLogin tests. Returns the authReq. +func setupSessionWithIdentity(t *testing.T, s *Server, now time.Time, lastLogin time.Time) storage.AuthRequest { + t.Helper() + ctx := t.Context() + nonce := "test-nonce" + + session := storage.AuthSession{ + UserID: "user-1", + ConnectorID: "mock", + Nonce: nonce, + ClientStates: map[string]*storage.ClientAuthState{ + "client-1": { + Active: true, + ExpiresAt: now.Add(24 * time.Hour), + LastActivity: now.Add(-1 * time.Minute), + }, + }, + CreatedAt: now.Add(-30 * time.Minute), + LastActivity: now.Add(-1 * time.Minute), + IPAddress: "127.0.0.1", + UserAgent: "test", + } + require.NoError(t, s.storage.CreateAuthSession(ctx, session)) + + ui := storage.UserIdentity{ + UserID: "user-1", + ConnectorID: "mock", + Claims: storage.Claims{ + UserID: "user-1", + Username: "testuser", + Email: "test@example.com", + }, + Consents: make(map[string][]string), + CreatedAt: now.Add(-1 * time.Hour), + LastLogin: lastLogin, + } + require.NoError(t, s.storage.CreateUserIdentity(ctx, ui)) + + authReq := storage.AuthRequest{ + ID: storage.NewID(), + ClientID: "client-1", + ConnectorID: "mock", + Scopes: []string{"openid"}, + RedirectURI: "http://localhost/callback", + MaxAge: -1, + HMACKey: storage.NewHMACKey(crypto.SHA256), + Expiry: now.Add(10 * time.Minute), + } + require.NoError(t, s.storage.CreateAuthRequest(ctx, authReq)) + + return authReq +} + +func TestTrySessionLogin_MaxAge(t *testing.T) { + ctx := t.Context() + + t.Run("max_age not specified, session reused", func(t *testing.T) { + s := newTestSessionServer(t) + now := s.now() + + authReq := setupSessionWithIdentity(t, s, now, now.Add(-2*time.Hour)) + authReq.MaxAge = -1 // not specified + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: "dex_session", Value: sessionCookieValue("user-1", "mock", "test-nonce")}) + w := httptest.NewRecorder() + + _, ok := s.trySessionLogin(ctx, r, w, &authReq) + assert.True(t, ok, "session should be reused when max_age is not specified") + }) + + t.Run("max_age satisfied, session reused", func(t *testing.T) { + s := newTestSessionServer(t) + now := s.now() + + // User logged in 10 minutes ago, max_age=3600 (1 hour) + authReq := setupSessionWithIdentity(t, s, now, now.Add(-10*time.Minute)) + authReq.MaxAge = 3600 + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: "dex_session", Value: sessionCookieValue("user-1", "mock", "test-nonce")}) + w := httptest.NewRecorder() + + _, ok := s.trySessionLogin(ctx, r, w, &authReq) + assert.True(t, ok, "session should be reused when max_age is satisfied") + }) + + t.Run("max_age exceeded, force re-auth", func(t *testing.T) { + s := newTestSessionServer(t) + now := s.now() + + // User logged in 2 hours ago, max_age=3600 (1 hour) + authReq := setupSessionWithIdentity(t, s, now, now.Add(-2*time.Hour)) + authReq.MaxAge = 3600 + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: "dex_session", Value: sessionCookieValue("user-1", "mock", "test-nonce")}) + w := httptest.NewRecorder() + + _, ok := s.trySessionLogin(ctx, r, w, &authReq) + assert.False(t, ok, "session should NOT be reused when max_age is exceeded") + }) + + t.Run("max_age=0, always force re-auth", func(t *testing.T) { + s := newTestSessionServer(t) + now := s.now() + + // User logged in 1 second ago, max_age=0 + authReq := setupSessionWithIdentity(t, s, now, now.Add(-1*time.Second)) + authReq.MaxAge = 0 + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: "dex_session", Value: sessionCookieValue("user-1", "mock", "test-nonce")}) + w := httptest.NewRecorder() + + _, ok := s.trySessionLogin(ctx, r, w, &authReq) + assert.False(t, ok, "max_age=0 should always force re-authentication") + }) + + t.Run("auth_time is set from UserIdentity.LastLogin", func(t *testing.T) { + s := newTestSessionServer(t) + s.skipApproval = false + now := s.now() + lastLogin := now.Add(-10 * time.Minute) + + authReq := setupSessionWithIdentity(t, s, now, lastLogin) + authReq.ForceApprovalPrompt = true // force approval so AuthRequest is not deleted + + require.NoError(t, s.storage.UpdateAuthRequest(ctx, authReq.ID, func(a storage.AuthRequest) (storage.AuthRequest, error) { + a.ForceApprovalPrompt = true + return a, nil + })) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.AddCookie(&http.Cookie{Name: "dex_session", Value: sessionCookieValue("user-1", "mock", "test-nonce")}) + w := httptest.NewRecorder() + + redirectURL, ok := s.trySessionLogin(ctx, r, w, &authReq) + require.True(t, ok) + assert.Contains(t, redirectURL, "/approval") + + // Verify AuthTime was set on the auth request. + updated, err := s.storage.GetAuthRequest(ctx, authReq.ID) + require.NoError(t, err) + assert.Equal(t, lastLogin.Unix(), updated.AuthTime.Unix()) + }) +} + +func TestParseAuthRequest_PromptAndMaxAge(t *testing.T) { + t.Run("prompt=consent sets ForceApprovalPrompt", func(t *testing.T) { + authReq := storage.AuthRequest{ + Prompt: "consent", + ForceApprovalPrompt: true, + } + assert.True(t, authReq.ForceApprovalPrompt) + assert.Equal(t, "consent", authReq.Prompt) + }) + + t.Run("max_age default is -1", func(t *testing.T) { + authReq := storage.AuthRequest{ + MaxAge: -1, + } + assert.Equal(t, -1, authReq.MaxAge) + }) +} diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index e4e30307..e2570b47 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -19,6 +19,10 @@ import ( // ensure that values being tested on never expire. var neverExpire = time.Now().UTC().Add(time.Hour * 24 * 365 * 100) +// defaultAuthTime is a non-zero time used as AuthTime default in tests. +// MySQL rejects Go's zero time (0001-01-01), so all test fixtures must use a real value. +var defaultAuthTime = time.Now().UTC() + type subTest struct { name string run func(t *testing.T, s storage.Storage) @@ -100,6 +104,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { ForceApprovalPrompt: true, LoggedIn: true, Expiry: neverExpire, + AuthTime: defaultAuthTime, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ @@ -134,6 +139,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { ForceApprovalPrompt: true, LoggedIn: true, Expiry: neverExpire, + AuthTime: defaultAuthTime, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ @@ -191,6 +197,7 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { Nonce: "foobar", Scopes: []string{"openid", "email"}, Expiry: neverExpire, + AuthTime: defaultAuthTime, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), PKCE: storage.PKCE{ @@ -217,6 +224,7 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { Nonce: "foobar", Scopes: []string{"openid", "email"}, Expiry: neverExpire, + AuthTime: defaultAuthTime, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ @@ -243,7 +251,8 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { if a1.Expiry.Unix() != got.Expiry.Unix() { t.Errorf("auth code expiry did not match want=%s vs got=%s", a1.Expiry, got.Expiry) } - got.Expiry = a1.Expiry // time fields do not compare well + got.Expiry = a1.Expiry // time fields do not compare well + got.AuthTime = a1.AuthTime // time fields do not compare well if diff := pretty.Compare(a1, got); diff != "" { t.Errorf("auth code retrieved from storage did not match: %s", diff) } @@ -790,6 +799,7 @@ func testGC(t *testing.T, s storage.Storage) { Nonce: "foobar", Scopes: []string{"openid", "email"}, Expiry: expiry, + AuthTime: defaultAuthTime, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ @@ -840,6 +850,7 @@ func testGC(t *testing.T, s storage.Storage) { ForceApprovalPrompt: true, LoggedIn: true, Expiry: expiry, + AuthTime: defaultAuthTime, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ @@ -976,6 +987,7 @@ func testTimezones(t *testing.T, s storage.Storage) { Nonce: "foobar", Scopes: []string{"openid", "email"}, Expiry: expiry, + AuthTime: defaultAuthTime, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ diff --git a/storage/conformance/transactions.go b/storage/conformance/transactions.go index 5889a024..f018224e 100644 --- a/storage/conformance/transactions.go +++ b/storage/conformance/transactions.go @@ -82,6 +82,7 @@ func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) { ForceApprovalPrompt: true, LoggedIn: true, Expiry: neverExpire, + AuthTime: defaultAuthTime, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ diff --git a/storage/ent/client/authcode.go b/storage/ent/client/authcode.go index aa5bd184..dfd5ca96 100644 --- a/storage/ent/client/authcode.go +++ b/storage/ent/client/authcode.go @@ -26,6 +26,7 @@ func (d *Database) CreateAuthCode(ctx context.Context, code storage.AuthCode) er SetExpiry(code.Expiry.UTC()). SetConnectorID(code.ConnectorID). SetConnectorData(code.ConnectorData). + SetAuthTime(code.AuthTime). Save(ctx) if err != nil { return convertDBError("create auth code: %w", err) diff --git a/storage/ent/client/authrequest.go b/storage/ent/client/authrequest.go index 86c71056..c1a3dfdb 100644 --- a/storage/ent/client/authrequest.go +++ b/storage/ent/client/authrequest.go @@ -33,6 +33,9 @@ func (d *Database) CreateAuthRequest(ctx context.Context, authRequest storage.Au SetConnectorData(authRequest.ConnectorData). SetHmacKey(authRequest.HMACKey). SetMfaValidated(authRequest.MFAValidated). + SetPrompt(authRequest.Prompt). + SetMaxAge(authRequest.MaxAge). + SetAuthTime(authRequest.AuthTime). Save(ctx) if err != nil { return convertDBError("create auth request: %w", err) @@ -98,6 +101,9 @@ func (d *Database) UpdateAuthRequest(ctx context.Context, id string, updater fun SetConnectorData(newAuthRequest.ConnectorData). SetHmacKey(newAuthRequest.HMACKey). SetMfaValidated(newAuthRequest.MFAValidated). + SetPrompt(newAuthRequest.Prompt). + SetMaxAge(newAuthRequest.MaxAge). + SetAuthTime(newAuthRequest.AuthTime). Save(context.TODO()) if err != nil { return rollback(tx, "update auth request uploading: %w", err) diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index 7ae9390d..d58d9f41 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -47,6 +47,9 @@ func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest { }, HMACKey: a.HmacKey, MFAValidated: a.MfaValidated, + Prompt: a.Prompt, + MaxAge: a.MaxAge, + AuthTime: a.AuthTime, } } @@ -72,6 +75,7 @@ func toStorageAuthCode(a *db.AuthCode) storage.AuthCode { CodeChallenge: a.CodeChallenge, CodeChallengeMethod: a.CodeChallengeMethod, }, + AuthTime: a.AuthTime, } } diff --git a/storage/ent/db/authcode.go b/storage/ent/db/authcode.go index 06ad7c8c..2b53e4da 100644 --- a/storage/ent/db/authcode.go +++ b/storage/ent/db/authcode.go @@ -48,7 +48,9 @@ type AuthCode struct { CodeChallenge string `json:"code_challenge,omitempty"` // CodeChallengeMethod holds the value of the "code_challenge_method" field. CodeChallengeMethod string `json:"code_challenge_method,omitempty"` - selectValues sql.SelectValues + // AuthTime holds the value of the "auth_time" field. + AuthTime time.Time `json:"auth_time,omitempty"` + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -62,7 +64,7 @@ func (*AuthCode) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case authcode.FieldID, authcode.FieldClientID, authcode.FieldNonce, authcode.FieldRedirectURI, authcode.FieldClaimsUserID, authcode.FieldClaimsUsername, authcode.FieldClaimsEmail, authcode.FieldClaimsPreferredUsername, authcode.FieldConnectorID, authcode.FieldCodeChallenge, authcode.FieldCodeChallengeMethod: values[i] = new(sql.NullString) - case authcode.FieldExpiry: + case authcode.FieldExpiry, authcode.FieldAuthTime: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -179,6 +181,12 @@ func (_m *AuthCode) assignValues(columns []string, values []any) error { } else if value.Valid { _m.CodeChallengeMethod = value.String } + case authcode.FieldAuthTime: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field auth_time", values[i]) + } else if value.Valid { + _m.AuthTime = value.Time + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -261,6 +269,9 @@ func (_m *AuthCode) String() string { builder.WriteString(", ") builder.WriteString("code_challenge_method=") builder.WriteString(_m.CodeChallengeMethod) + builder.WriteString(", ") + builder.WriteString("auth_time=") + builder.WriteString(_m.AuthTime.Format(time.ANSIC)) builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/authcode/authcode.go b/storage/ent/db/authcode/authcode.go index 6e056f15..570903e0 100644 --- a/storage/ent/db/authcode/authcode.go +++ b/storage/ent/db/authcode/authcode.go @@ -41,6 +41,8 @@ const ( FieldCodeChallenge = "code_challenge" // FieldCodeChallengeMethod holds the string denoting the code_challenge_method field in the database. FieldCodeChallengeMethod = "code_challenge_method" + // FieldAuthTime holds the string denoting the auth_time field in the database. + FieldAuthTime = "auth_time" // Table holds the table name of the authcode in the database. Table = "auth_codes" ) @@ -63,6 +65,7 @@ var Columns = []string{ FieldExpiry, FieldCodeChallenge, FieldCodeChallengeMethod, + FieldAuthTime, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -167,3 +170,8 @@ func ByCodeChallenge(opts ...sql.OrderTermOption) OrderOption { func ByCodeChallengeMethod(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCodeChallengeMethod, opts...).ToFunc() } + +// ByAuthTime orders the results by the auth_time field. +func ByAuthTime(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAuthTime, opts...).ToFunc() +} diff --git a/storage/ent/db/authcode/where.go b/storage/ent/db/authcode/where.go index f8673fb0..41c590ca 100644 --- a/storage/ent/db/authcode/where.go +++ b/storage/ent/db/authcode/where.go @@ -129,6 +129,11 @@ func CodeChallengeMethod(v string) predicate.AuthCode { return predicate.AuthCode(sql.FieldEQ(FieldCodeChallengeMethod, v)) } +// AuthTime applies equality check predicate on the "auth_time" field. It's identical to AuthTimeEQ. +func AuthTime(v time.Time) predicate.AuthCode { + return predicate.AuthCode(sql.FieldEQ(FieldAuthTime, v)) +} + // ClientIDEQ applies the EQ predicate on the "client_id" field. func ClientIDEQ(v string) predicate.AuthCode { return predicate.AuthCode(sql.FieldEQ(FieldClientID, v)) @@ -899,6 +904,56 @@ func CodeChallengeMethodContainsFold(v string) predicate.AuthCode { return predicate.AuthCode(sql.FieldContainsFold(FieldCodeChallengeMethod, v)) } +// AuthTimeEQ applies the EQ predicate on the "auth_time" field. +func AuthTimeEQ(v time.Time) predicate.AuthCode { + return predicate.AuthCode(sql.FieldEQ(FieldAuthTime, v)) +} + +// AuthTimeNEQ applies the NEQ predicate on the "auth_time" field. +func AuthTimeNEQ(v time.Time) predicate.AuthCode { + return predicate.AuthCode(sql.FieldNEQ(FieldAuthTime, v)) +} + +// AuthTimeIn applies the In predicate on the "auth_time" field. +func AuthTimeIn(vs ...time.Time) predicate.AuthCode { + return predicate.AuthCode(sql.FieldIn(FieldAuthTime, vs...)) +} + +// AuthTimeNotIn applies the NotIn predicate on the "auth_time" field. +func AuthTimeNotIn(vs ...time.Time) predicate.AuthCode { + return predicate.AuthCode(sql.FieldNotIn(FieldAuthTime, vs...)) +} + +// AuthTimeGT applies the GT predicate on the "auth_time" field. +func AuthTimeGT(v time.Time) predicate.AuthCode { + return predicate.AuthCode(sql.FieldGT(FieldAuthTime, v)) +} + +// AuthTimeGTE applies the GTE predicate on the "auth_time" field. +func AuthTimeGTE(v time.Time) predicate.AuthCode { + return predicate.AuthCode(sql.FieldGTE(FieldAuthTime, v)) +} + +// AuthTimeLT applies the LT predicate on the "auth_time" field. +func AuthTimeLT(v time.Time) predicate.AuthCode { + return predicate.AuthCode(sql.FieldLT(FieldAuthTime, v)) +} + +// AuthTimeLTE applies the LTE predicate on the "auth_time" field. +func AuthTimeLTE(v time.Time) predicate.AuthCode { + return predicate.AuthCode(sql.FieldLTE(FieldAuthTime, v)) +} + +// AuthTimeIsNil applies the IsNil predicate on the "auth_time" field. +func AuthTimeIsNil() predicate.AuthCode { + return predicate.AuthCode(sql.FieldIsNull(FieldAuthTime)) +} + +// AuthTimeNotNil applies the NotNil predicate on the "auth_time" field. +func AuthTimeNotNil() predicate.AuthCode { + return predicate.AuthCode(sql.FieldNotNull(FieldAuthTime)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.AuthCode) predicate.AuthCode { return predicate.AuthCode(sql.AndPredicates(predicates...)) diff --git a/storage/ent/db/authcode_create.go b/storage/ent/db/authcode_create.go index ab4b9e4e..0e99e992 100644 --- a/storage/ent/db/authcode_create.go +++ b/storage/ent/db/authcode_create.go @@ -134,6 +134,20 @@ func (_c *AuthCodeCreate) SetNillableCodeChallengeMethod(v *string) *AuthCodeCre return _c } +// SetAuthTime sets the "auth_time" field. +func (_c *AuthCodeCreate) SetAuthTime(v time.Time) *AuthCodeCreate { + _c.mutation.SetAuthTime(v) + return _c +} + +// SetNillableAuthTime sets the "auth_time" field if the given value is not nil. +func (_c *AuthCodeCreate) SetNillableAuthTime(v *time.Time) *AuthCodeCreate { + if v != nil { + _c.SetAuthTime(*v) + } + return _c +} + // SetID sets the "id" field. func (_c *AuthCodeCreate) SetID(v string) *AuthCodeCreate { _c.mutation.SetID(v) @@ -362,6 +376,10 @@ func (_c *AuthCodeCreate) createSpec() (*AuthCode, *sqlgraph.CreateSpec) { _spec.SetField(authcode.FieldCodeChallengeMethod, field.TypeString, value) _node.CodeChallengeMethod = value } + if value, ok := _c.mutation.AuthTime(); ok { + _spec.SetField(authcode.FieldAuthTime, field.TypeTime, value) + _node.AuthTime = value + } return _node, _spec } diff --git a/storage/ent/db/authcode_update.go b/storage/ent/db/authcode_update.go index 7b7e186e..25027ca7 100644 --- a/storage/ent/db/authcode_update.go +++ b/storage/ent/db/authcode_update.go @@ -245,6 +245,26 @@ func (_u *AuthCodeUpdate) SetNillableCodeChallengeMethod(v *string) *AuthCodeUpd return _u } +// SetAuthTime sets the "auth_time" field. +func (_u *AuthCodeUpdate) SetAuthTime(v time.Time) *AuthCodeUpdate { + _u.mutation.SetAuthTime(v) + return _u +} + +// SetNillableAuthTime sets the "auth_time" field if the given value is not nil. +func (_u *AuthCodeUpdate) SetNillableAuthTime(v *time.Time) *AuthCodeUpdate { + if v != nil { + _u.SetAuthTime(*v) + } + return _u +} + +// ClearAuthTime clears the value of the "auth_time" field. +func (_u *AuthCodeUpdate) ClearAuthTime() *AuthCodeUpdate { + _u.mutation.ClearAuthTime() + return _u +} + // Mutation returns the AuthCodeMutation object of the builder. func (_u *AuthCodeUpdate) Mutation() *AuthCodeMutation { return _u.mutation @@ -393,6 +413,12 @@ func (_u *AuthCodeUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.CodeChallengeMethod(); ok { _spec.SetField(authcode.FieldCodeChallengeMethod, field.TypeString, value) } + if value, ok := _u.mutation.AuthTime(); ok { + _spec.SetField(authcode.FieldAuthTime, field.TypeTime, value) + } + if _u.mutation.AuthTimeCleared() { + _spec.ClearField(authcode.FieldAuthTime, field.TypeTime) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authcode.Label} @@ -629,6 +655,26 @@ func (_u *AuthCodeUpdateOne) SetNillableCodeChallengeMethod(v *string) *AuthCode return _u } +// SetAuthTime sets the "auth_time" field. +func (_u *AuthCodeUpdateOne) SetAuthTime(v time.Time) *AuthCodeUpdateOne { + _u.mutation.SetAuthTime(v) + return _u +} + +// SetNillableAuthTime sets the "auth_time" field if the given value is not nil. +func (_u *AuthCodeUpdateOne) SetNillableAuthTime(v *time.Time) *AuthCodeUpdateOne { + if v != nil { + _u.SetAuthTime(*v) + } + return _u +} + +// ClearAuthTime clears the value of the "auth_time" field. +func (_u *AuthCodeUpdateOne) ClearAuthTime() *AuthCodeUpdateOne { + _u.mutation.ClearAuthTime() + return _u +} + // Mutation returns the AuthCodeMutation object of the builder. func (_u *AuthCodeUpdateOne) Mutation() *AuthCodeMutation { return _u.mutation @@ -807,6 +853,12 @@ func (_u *AuthCodeUpdateOne) sqlSave(ctx context.Context) (_node *AuthCode, err if value, ok := _u.mutation.CodeChallengeMethod(); ok { _spec.SetField(authcode.FieldCodeChallengeMethod, field.TypeString, value) } + if value, ok := _u.mutation.AuthTime(); ok { + _spec.SetField(authcode.FieldAuthTime, field.TypeTime, value) + } + if _u.mutation.AuthTimeCleared() { + _spec.ClearField(authcode.FieldAuthTime, field.TypeTime) + } _node = &AuthCode{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/authrequest.go b/storage/ent/db/authrequest.go index 02ead496..82619eeb 100644 --- a/storage/ent/db/authrequest.go +++ b/storage/ent/db/authrequest.go @@ -60,6 +60,12 @@ type AuthRequest struct { HmacKey []byte `json:"hmac_key,omitempty"` // MfaValidated holds the value of the "mfa_validated" field. MfaValidated bool `json:"mfa_validated,omitempty"` + // Prompt holds the value of the "prompt" field. + Prompt string `json:"prompt,omitempty"` + // MaxAge holds the value of the "max_age" field. + MaxAge int `json:"max_age,omitempty"` + // AuthTime holds the value of the "auth_time" field. + AuthTime time.Time `json:"auth_time,omitempty"` selectValues sql.SelectValues } @@ -72,9 +78,11 @@ func (*AuthRequest) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case authrequest.FieldForceApprovalPrompt, authrequest.FieldLoggedIn, authrequest.FieldClaimsEmailVerified, authrequest.FieldMfaValidated: values[i] = new(sql.NullBool) - case authrequest.FieldID, authrequest.FieldClientID, authrequest.FieldRedirectURI, authrequest.FieldNonce, authrequest.FieldState, authrequest.FieldClaimsUserID, authrequest.FieldClaimsUsername, authrequest.FieldClaimsEmail, authrequest.FieldClaimsPreferredUsername, authrequest.FieldConnectorID, authrequest.FieldCodeChallenge, authrequest.FieldCodeChallengeMethod: + case authrequest.FieldMaxAge: + values[i] = new(sql.NullInt64) + case authrequest.FieldID, authrequest.FieldClientID, authrequest.FieldRedirectURI, authrequest.FieldNonce, authrequest.FieldState, authrequest.FieldClaimsUserID, authrequest.FieldClaimsUsername, authrequest.FieldClaimsEmail, authrequest.FieldClaimsPreferredUsername, authrequest.FieldConnectorID, authrequest.FieldCodeChallenge, authrequest.FieldCodeChallengeMethod, authrequest.FieldPrompt: values[i] = new(sql.NullString) - case authrequest.FieldExpiry: + case authrequest.FieldExpiry, authrequest.FieldAuthTime: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -229,6 +237,24 @@ func (_m *AuthRequest) assignValues(columns []string, values []any) error { } else if value.Valid { _m.MfaValidated = value.Bool } + case authrequest.FieldPrompt: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field prompt", values[i]) + } else if value.Valid { + _m.Prompt = value.String + } + case authrequest.FieldMaxAge: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field max_age", values[i]) + } else if value.Valid { + _m.MaxAge = int(value.Int64) + } + case authrequest.FieldAuthTime: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field auth_time", values[i]) + } else if value.Valid { + _m.AuthTime = value.Time + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -329,6 +355,15 @@ func (_m *AuthRequest) String() string { builder.WriteString(", ") builder.WriteString("mfa_validated=") builder.WriteString(fmt.Sprintf("%v", _m.MfaValidated)) + builder.WriteString(", ") + builder.WriteString("prompt=") + builder.WriteString(_m.Prompt) + builder.WriteString(", ") + builder.WriteString("max_age=") + builder.WriteString(fmt.Sprintf("%v", _m.MaxAge)) + builder.WriteString(", ") + builder.WriteString("auth_time=") + builder.WriteString(_m.AuthTime.Format(time.ANSIC)) builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/authrequest/authrequest.go b/storage/ent/db/authrequest/authrequest.go index 4fda6633..33c52c8f 100644 --- a/storage/ent/db/authrequest/authrequest.go +++ b/storage/ent/db/authrequest/authrequest.go @@ -53,6 +53,12 @@ const ( FieldHmacKey = "hmac_key" // FieldMfaValidated holds the string denoting the mfa_validated field in the database. FieldMfaValidated = "mfa_validated" + // FieldPrompt holds the string denoting the prompt field in the database. + FieldPrompt = "prompt" + // FieldMaxAge holds the string denoting the max_age field in the database. + FieldMaxAge = "max_age" + // FieldAuthTime holds the string denoting the auth_time field in the database. + FieldAuthTime = "auth_time" // Table holds the table name of the authrequest in the database. Table = "auth_requests" ) @@ -81,6 +87,9 @@ var Columns = []string{ FieldCodeChallengeMethod, FieldHmacKey, FieldMfaValidated, + FieldPrompt, + FieldMaxAge, + FieldAuthTime, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -102,6 +111,10 @@ var ( DefaultCodeChallengeMethod string // DefaultMfaValidated holds the default value on creation for the "mfa_validated" field. DefaultMfaValidated bool + // DefaultPrompt holds the default value on creation for the "prompt" field. + DefaultPrompt string + // DefaultMaxAge holds the default value on creation for the "max_age" field. + DefaultMaxAge int // IDValidator is a validator for the "id" field. It is called by the builders before save. IDValidator func(string) error ) @@ -193,3 +206,18 @@ func ByCodeChallengeMethod(opts ...sql.OrderTermOption) OrderOption { func ByMfaValidated(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldMfaValidated, opts...).ToFunc() } + +// ByPrompt orders the results by the prompt field. +func ByPrompt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPrompt, opts...).ToFunc() +} + +// ByMaxAge orders the results by the max_age field. +func ByMaxAge(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMaxAge, opts...).ToFunc() +} + +// ByAuthTime orders the results by the auth_time field. +func ByAuthTime(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAuthTime, opts...).ToFunc() +} diff --git a/storage/ent/db/authrequest/where.go b/storage/ent/db/authrequest/where.go index 2f679bb3..f87780b1 100644 --- a/storage/ent/db/authrequest/where.go +++ b/storage/ent/db/authrequest/where.go @@ -154,6 +154,21 @@ func MfaValidated(v bool) predicate.AuthRequest { return predicate.AuthRequest(sql.FieldEQ(FieldMfaValidated, v)) } +// Prompt applies equality check predicate on the "prompt" field. It's identical to PromptEQ. +func Prompt(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldEQ(FieldPrompt, v)) +} + +// MaxAge applies equality check predicate on the "max_age" field. It's identical to MaxAgeEQ. +func MaxAge(v int) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldEQ(FieldMaxAge, v)) +} + +// AuthTime applies equality check predicate on the "auth_time" field. It's identical to AuthTimeEQ. +func AuthTime(v time.Time) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldEQ(FieldAuthTime, v)) +} + // ClientIDEQ applies the EQ predicate on the "client_id" field. func ClientIDEQ(v string) predicate.AuthRequest { return predicate.AuthRequest(sql.FieldEQ(FieldClientID, v)) @@ -1069,6 +1084,161 @@ func MfaValidatedNEQ(v bool) predicate.AuthRequest { return predicate.AuthRequest(sql.FieldNEQ(FieldMfaValidated, v)) } +// PromptEQ applies the EQ predicate on the "prompt" field. +func PromptEQ(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldEQ(FieldPrompt, v)) +} + +// PromptNEQ applies the NEQ predicate on the "prompt" field. +func PromptNEQ(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldNEQ(FieldPrompt, v)) +} + +// PromptIn applies the In predicate on the "prompt" field. +func PromptIn(vs ...string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldIn(FieldPrompt, vs...)) +} + +// PromptNotIn applies the NotIn predicate on the "prompt" field. +func PromptNotIn(vs ...string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldNotIn(FieldPrompt, vs...)) +} + +// PromptGT applies the GT predicate on the "prompt" field. +func PromptGT(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldGT(FieldPrompt, v)) +} + +// PromptGTE applies the GTE predicate on the "prompt" field. +func PromptGTE(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldGTE(FieldPrompt, v)) +} + +// PromptLT applies the LT predicate on the "prompt" field. +func PromptLT(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldLT(FieldPrompt, v)) +} + +// PromptLTE applies the LTE predicate on the "prompt" field. +func PromptLTE(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldLTE(FieldPrompt, v)) +} + +// PromptContains applies the Contains predicate on the "prompt" field. +func PromptContains(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldContains(FieldPrompt, v)) +} + +// PromptHasPrefix applies the HasPrefix predicate on the "prompt" field. +func PromptHasPrefix(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldHasPrefix(FieldPrompt, v)) +} + +// PromptHasSuffix applies the HasSuffix predicate on the "prompt" field. +func PromptHasSuffix(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldHasSuffix(FieldPrompt, v)) +} + +// PromptEqualFold applies the EqualFold predicate on the "prompt" field. +func PromptEqualFold(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldEqualFold(FieldPrompt, v)) +} + +// PromptContainsFold applies the ContainsFold predicate on the "prompt" field. +func PromptContainsFold(v string) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldContainsFold(FieldPrompt, v)) +} + +// MaxAgeEQ applies the EQ predicate on the "max_age" field. +func MaxAgeEQ(v int) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldEQ(FieldMaxAge, v)) +} + +// MaxAgeNEQ applies the NEQ predicate on the "max_age" field. +func MaxAgeNEQ(v int) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldNEQ(FieldMaxAge, v)) +} + +// MaxAgeIn applies the In predicate on the "max_age" field. +func MaxAgeIn(vs ...int) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldIn(FieldMaxAge, vs...)) +} + +// MaxAgeNotIn applies the NotIn predicate on the "max_age" field. +func MaxAgeNotIn(vs ...int) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldNotIn(FieldMaxAge, vs...)) +} + +// MaxAgeGT applies the GT predicate on the "max_age" field. +func MaxAgeGT(v int) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldGT(FieldMaxAge, v)) +} + +// MaxAgeGTE applies the GTE predicate on the "max_age" field. +func MaxAgeGTE(v int) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldGTE(FieldMaxAge, v)) +} + +// MaxAgeLT applies the LT predicate on the "max_age" field. +func MaxAgeLT(v int) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldLT(FieldMaxAge, v)) +} + +// MaxAgeLTE applies the LTE predicate on the "max_age" field. +func MaxAgeLTE(v int) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldLTE(FieldMaxAge, v)) +} + +// AuthTimeEQ applies the EQ predicate on the "auth_time" field. +func AuthTimeEQ(v time.Time) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldEQ(FieldAuthTime, v)) +} + +// AuthTimeNEQ applies the NEQ predicate on the "auth_time" field. +func AuthTimeNEQ(v time.Time) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldNEQ(FieldAuthTime, v)) +} + +// AuthTimeIn applies the In predicate on the "auth_time" field. +func AuthTimeIn(vs ...time.Time) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldIn(FieldAuthTime, vs...)) +} + +// AuthTimeNotIn applies the NotIn predicate on the "auth_time" field. +func AuthTimeNotIn(vs ...time.Time) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldNotIn(FieldAuthTime, vs...)) +} + +// AuthTimeGT applies the GT predicate on the "auth_time" field. +func AuthTimeGT(v time.Time) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldGT(FieldAuthTime, v)) +} + +// AuthTimeGTE applies the GTE predicate on the "auth_time" field. +func AuthTimeGTE(v time.Time) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldGTE(FieldAuthTime, v)) +} + +// AuthTimeLT applies the LT predicate on the "auth_time" field. +func AuthTimeLT(v time.Time) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldLT(FieldAuthTime, v)) +} + +// AuthTimeLTE applies the LTE predicate on the "auth_time" field. +func AuthTimeLTE(v time.Time) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldLTE(FieldAuthTime, v)) +} + +// AuthTimeIsNil applies the IsNil predicate on the "auth_time" field. +func AuthTimeIsNil() predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldIsNull(FieldAuthTime)) +} + +// AuthTimeNotNil applies the NotNil predicate on the "auth_time" field. +func AuthTimeNotNil() predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldNotNull(FieldAuthTime)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.AuthRequest) predicate.AuthRequest { return predicate.AuthRequest(sql.AndPredicates(predicates...)) diff --git a/storage/ent/db/authrequest_create.go b/storage/ent/db/authrequest_create.go index c36648a1..324c99d9 100644 --- a/storage/ent/db/authrequest_create.go +++ b/storage/ent/db/authrequest_create.go @@ -178,6 +178,48 @@ func (_c *AuthRequestCreate) SetNillableMfaValidated(v *bool) *AuthRequestCreate return _c } +// SetPrompt sets the "prompt" field. +func (_c *AuthRequestCreate) SetPrompt(v string) *AuthRequestCreate { + _c.mutation.SetPrompt(v) + return _c +} + +// SetNillablePrompt sets the "prompt" field if the given value is not nil. +func (_c *AuthRequestCreate) SetNillablePrompt(v *string) *AuthRequestCreate { + if v != nil { + _c.SetPrompt(*v) + } + return _c +} + +// SetMaxAge sets the "max_age" field. +func (_c *AuthRequestCreate) SetMaxAge(v int) *AuthRequestCreate { + _c.mutation.SetMaxAge(v) + return _c +} + +// SetNillableMaxAge sets the "max_age" field if the given value is not nil. +func (_c *AuthRequestCreate) SetNillableMaxAge(v *int) *AuthRequestCreate { + if v != nil { + _c.SetMaxAge(*v) + } + return _c +} + +// SetAuthTime sets the "auth_time" field. +func (_c *AuthRequestCreate) SetAuthTime(v time.Time) *AuthRequestCreate { + _c.mutation.SetAuthTime(v) + return _c +} + +// SetNillableAuthTime sets the "auth_time" field if the given value is not nil. +func (_c *AuthRequestCreate) SetNillableAuthTime(v *time.Time) *AuthRequestCreate { + if v != nil { + _c.SetAuthTime(*v) + } + return _c +} + // SetID sets the "id" field. func (_c *AuthRequestCreate) SetID(v string) *AuthRequestCreate { _c.mutation.SetID(v) @@ -235,6 +277,14 @@ func (_c *AuthRequestCreate) defaults() { v := authrequest.DefaultMfaValidated _c.mutation.SetMfaValidated(v) } + if _, ok := _c.mutation.Prompt(); !ok { + v := authrequest.DefaultPrompt + _c.mutation.SetPrompt(v) + } + if _, ok := _c.mutation.MaxAge(); !ok { + v := authrequest.DefaultMaxAge + _c.mutation.SetMaxAge(v) + } } // check runs all checks and user-defined validators on the builder. @@ -290,6 +340,12 @@ func (_c *AuthRequestCreate) check() error { if _, ok := _c.mutation.MfaValidated(); !ok { return &ValidationError{Name: "mfa_validated", err: errors.New(`db: missing required field "AuthRequest.mfa_validated"`)} } + if _, ok := _c.mutation.Prompt(); !ok { + return &ValidationError{Name: "prompt", err: errors.New(`db: missing required field "AuthRequest.prompt"`)} + } + if _, ok := _c.mutation.MaxAge(); !ok { + return &ValidationError{Name: "max_age", err: errors.New(`db: missing required field "AuthRequest.max_age"`)} + } if v, ok := _c.mutation.ID(); ok { if err := authrequest.IDValidator(v); err != nil { return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "AuthRequest.id": %w`, err)} @@ -414,6 +470,18 @@ func (_c *AuthRequestCreate) createSpec() (*AuthRequest, *sqlgraph.CreateSpec) { _spec.SetField(authrequest.FieldMfaValidated, field.TypeBool, value) _node.MfaValidated = value } + if value, ok := _c.mutation.Prompt(); ok { + _spec.SetField(authrequest.FieldPrompt, field.TypeString, value) + _node.Prompt = value + } + if value, ok := _c.mutation.MaxAge(); ok { + _spec.SetField(authrequest.FieldMaxAge, field.TypeInt, value) + _node.MaxAge = value + } + if value, ok := _c.mutation.AuthTime(); ok { + _spec.SetField(authrequest.FieldAuthTime, field.TypeTime, value) + _node.AuthTime = value + } return _node, _spec } diff --git a/storage/ent/db/authrequest_update.go b/storage/ent/db/authrequest_update.go index e512b2d9..7edece02 100644 --- a/storage/ent/db/authrequest_update.go +++ b/storage/ent/db/authrequest_update.go @@ -325,6 +325,61 @@ func (_u *AuthRequestUpdate) SetNillableMfaValidated(v *bool) *AuthRequestUpdate return _u } +// SetPrompt sets the "prompt" field. +func (_u *AuthRequestUpdate) SetPrompt(v string) *AuthRequestUpdate { + _u.mutation.SetPrompt(v) + return _u +} + +// SetNillablePrompt sets the "prompt" field if the given value is not nil. +func (_u *AuthRequestUpdate) SetNillablePrompt(v *string) *AuthRequestUpdate { + if v != nil { + _u.SetPrompt(*v) + } + return _u +} + +// SetMaxAge sets the "max_age" field. +func (_u *AuthRequestUpdate) SetMaxAge(v int) *AuthRequestUpdate { + _u.mutation.ResetMaxAge() + _u.mutation.SetMaxAge(v) + return _u +} + +// SetNillableMaxAge sets the "max_age" field if the given value is not nil. +func (_u *AuthRequestUpdate) SetNillableMaxAge(v *int) *AuthRequestUpdate { + if v != nil { + _u.SetMaxAge(*v) + } + return _u +} + +// AddMaxAge adds value to the "max_age" field. +func (_u *AuthRequestUpdate) AddMaxAge(v int) *AuthRequestUpdate { + _u.mutation.AddMaxAge(v) + return _u +} + +// SetAuthTime sets the "auth_time" field. +func (_u *AuthRequestUpdate) SetAuthTime(v time.Time) *AuthRequestUpdate { + _u.mutation.SetAuthTime(v) + return _u +} + +// SetNillableAuthTime sets the "auth_time" field if the given value is not nil. +func (_u *AuthRequestUpdate) SetNillableAuthTime(v *time.Time) *AuthRequestUpdate { + if v != nil { + _u.SetAuthTime(*v) + } + return _u +} + +// ClearAuthTime clears the value of the "auth_time" field. +func (_u *AuthRequestUpdate) ClearAuthTime() *AuthRequestUpdate { + _u.mutation.ClearAuthTime() + return _u +} + // Mutation returns the AuthRequestMutation object of the builder. func (_u *AuthRequestUpdate) Mutation() *AuthRequestMutation { return _u.mutation @@ -456,6 +511,21 @@ func (_u *AuthRequestUpdate) sqlSave(ctx context.Context) (_node int, err error) if value, ok := _u.mutation.MfaValidated(); ok { _spec.SetField(authrequest.FieldMfaValidated, field.TypeBool, value) } + if value, ok := _u.mutation.Prompt(); ok { + _spec.SetField(authrequest.FieldPrompt, field.TypeString, value) + } + if value, ok := _u.mutation.MaxAge(); ok { + _spec.SetField(authrequest.FieldMaxAge, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedMaxAge(); ok { + _spec.AddField(authrequest.FieldMaxAge, field.TypeInt, value) + } + if value, ok := _u.mutation.AuthTime(); ok { + _spec.SetField(authrequest.FieldAuthTime, field.TypeTime, value) + } + if _u.mutation.AuthTimeCleared() { + _spec.ClearField(authrequest.FieldAuthTime, field.TypeTime) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authrequest.Label} @@ -772,6 +842,61 @@ func (_u *AuthRequestUpdateOne) SetNillableMfaValidated(v *bool) *AuthRequestUpd return _u } +// SetPrompt sets the "prompt" field. +func (_u *AuthRequestUpdateOne) SetPrompt(v string) *AuthRequestUpdateOne { + _u.mutation.SetPrompt(v) + return _u +} + +// SetNillablePrompt sets the "prompt" field if the given value is not nil. +func (_u *AuthRequestUpdateOne) SetNillablePrompt(v *string) *AuthRequestUpdateOne { + if v != nil { + _u.SetPrompt(*v) + } + return _u +} + +// SetMaxAge sets the "max_age" field. +func (_u *AuthRequestUpdateOne) SetMaxAge(v int) *AuthRequestUpdateOne { + _u.mutation.ResetMaxAge() + _u.mutation.SetMaxAge(v) + return _u +} + +// SetNillableMaxAge sets the "max_age" field if the given value is not nil. +func (_u *AuthRequestUpdateOne) SetNillableMaxAge(v *int) *AuthRequestUpdateOne { + if v != nil { + _u.SetMaxAge(*v) + } + return _u +} + +// AddMaxAge adds value to the "max_age" field. +func (_u *AuthRequestUpdateOne) AddMaxAge(v int) *AuthRequestUpdateOne { + _u.mutation.AddMaxAge(v) + return _u +} + +// SetAuthTime sets the "auth_time" field. +func (_u *AuthRequestUpdateOne) SetAuthTime(v time.Time) *AuthRequestUpdateOne { + _u.mutation.SetAuthTime(v) + return _u +} + +// SetNillableAuthTime sets the "auth_time" field if the given value is not nil. +func (_u *AuthRequestUpdateOne) SetNillableAuthTime(v *time.Time) *AuthRequestUpdateOne { + if v != nil { + _u.SetAuthTime(*v) + } + return _u +} + +// ClearAuthTime clears the value of the "auth_time" field. +func (_u *AuthRequestUpdateOne) ClearAuthTime() *AuthRequestUpdateOne { + _u.mutation.ClearAuthTime() + return _u +} + // Mutation returns the AuthRequestMutation object of the builder. func (_u *AuthRequestUpdateOne) Mutation() *AuthRequestMutation { return _u.mutation @@ -933,6 +1058,21 @@ func (_u *AuthRequestUpdateOne) sqlSave(ctx context.Context) (_node *AuthRequest if value, ok := _u.mutation.MfaValidated(); ok { _spec.SetField(authrequest.FieldMfaValidated, field.TypeBool, value) } + if value, ok := _u.mutation.Prompt(); ok { + _spec.SetField(authrequest.FieldPrompt, field.TypeString, value) + } + if value, ok := _u.mutation.MaxAge(); ok { + _spec.SetField(authrequest.FieldMaxAge, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedMaxAge(); ok { + _spec.AddField(authrequest.FieldMaxAge, field.TypeInt, value) + } + if value, ok := _u.mutation.AuthTime(); ok { + _spec.SetField(authrequest.FieldAuthTime, field.TypeTime, value) + } + if _u.mutation.AuthTimeCleared() { + _spec.ClearField(authrequest.FieldAuthTime, field.TypeTime) + } _node = &AuthRequest{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index 36db86e3..3fc0f834 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -26,6 +26,7 @@ var ( {Name: "expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "auth_time", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, } // AuthCodesTable holds the schema information for the "auth_codes" table. AuthCodesTable = &schema.Table{ @@ -57,6 +58,9 @@ var ( {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "hmac_key", Type: field.TypeBytes}, {Name: "mfa_validated", Type: field.TypeBool, Default: false}, + {Name: "prompt", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "max_age", Type: field.TypeInt, Default: -1}, + {Name: "auth_time", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, } // AuthRequestsTable holds the schema information for the "auth_requests" table. AuthRequestsTable = &schema.Table{ diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index 151443f9..46f204eb 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -74,6 +74,7 @@ type AuthCodeMutation struct { expiry *time.Time code_challenge *string code_challenge_method *string + auth_time *time.Time clearedFields map[string]struct{} done bool oldValue func(context.Context) (*AuthCode, error) @@ -795,6 +796,55 @@ func (m *AuthCodeMutation) ResetCodeChallengeMethod() { m.code_challenge_method = nil } +// SetAuthTime sets the "auth_time" field. +func (m *AuthCodeMutation) SetAuthTime(t time.Time) { + m.auth_time = &t +} + +// AuthTime returns the value of the "auth_time" field in the mutation. +func (m *AuthCodeMutation) AuthTime() (r time.Time, exists bool) { + v := m.auth_time + if v == nil { + return + } + return *v, true +} + +// OldAuthTime returns the old "auth_time" field's value of the AuthCode entity. +// If the AuthCode object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthCodeMutation) OldAuthTime(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAuthTime is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAuthTime requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAuthTime: %w", err) + } + return oldValue.AuthTime, nil +} + +// ClearAuthTime clears the value of the "auth_time" field. +func (m *AuthCodeMutation) ClearAuthTime() { + m.auth_time = nil + m.clearedFields[authcode.FieldAuthTime] = struct{}{} +} + +// AuthTimeCleared returns if the "auth_time" field was cleared in this mutation. +func (m *AuthCodeMutation) AuthTimeCleared() bool { + _, ok := m.clearedFields[authcode.FieldAuthTime] + return ok +} + +// ResetAuthTime resets all changes to the "auth_time" field. +func (m *AuthCodeMutation) ResetAuthTime() { + m.auth_time = nil + delete(m.clearedFields, authcode.FieldAuthTime) +} + // Where appends a list predicates to the AuthCodeMutation builder. func (m *AuthCodeMutation) Where(ps ...predicate.AuthCode) { m.predicates = append(m.predicates, ps...) @@ -829,7 +879,7 @@ func (m *AuthCodeMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AuthCodeMutation) Fields() []string { - fields := make([]string, 0, 15) + fields := make([]string, 0, 16) if m.client_id != nil { fields = append(fields, authcode.FieldClientID) } @@ -875,6 +925,9 @@ func (m *AuthCodeMutation) Fields() []string { if m.code_challenge_method != nil { fields = append(fields, authcode.FieldCodeChallengeMethod) } + if m.auth_time != nil { + fields = append(fields, authcode.FieldAuthTime) + } return fields } @@ -913,6 +966,8 @@ func (m *AuthCodeMutation) Field(name string) (ent.Value, bool) { return m.CodeChallenge() case authcode.FieldCodeChallengeMethod: return m.CodeChallengeMethod() + case authcode.FieldAuthTime: + return m.AuthTime() } return nil, false } @@ -952,6 +1007,8 @@ func (m *AuthCodeMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldCodeChallenge(ctx) case authcode.FieldCodeChallengeMethod: return m.OldCodeChallengeMethod(ctx) + case authcode.FieldAuthTime: + return m.OldAuthTime(ctx) } return nil, fmt.Errorf("unknown AuthCode field %s", name) } @@ -1066,6 +1123,13 @@ func (m *AuthCodeMutation) SetField(name string, value ent.Value) error { } m.SetCodeChallengeMethod(v) return nil + case authcode.FieldAuthTime: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAuthTime(v) + return nil } return fmt.Errorf("unknown AuthCode field %s", name) } @@ -1105,6 +1169,9 @@ func (m *AuthCodeMutation) ClearedFields() []string { if m.FieldCleared(authcode.FieldConnectorData) { fields = append(fields, authcode.FieldConnectorData) } + if m.FieldCleared(authcode.FieldAuthTime) { + fields = append(fields, authcode.FieldAuthTime) + } return fields } @@ -1128,6 +1195,9 @@ func (m *AuthCodeMutation) ClearField(name string) error { case authcode.FieldConnectorData: m.ClearConnectorData() return nil + case authcode.FieldAuthTime: + m.ClearAuthTime() + return nil } return fmt.Errorf("unknown AuthCode nullable field %s", name) } @@ -1181,6 +1251,9 @@ func (m *AuthCodeMutation) ResetField(name string) error { case authcode.FieldCodeChallengeMethod: m.ResetCodeChallengeMethod() return nil + case authcode.FieldAuthTime: + m.ResetAuthTime() + return nil } return fmt.Errorf("unknown AuthCode field %s", name) } @@ -1263,6 +1336,10 @@ type AuthRequestMutation struct { code_challenge_method *string hmac_key *[]byte mfa_validated *bool + prompt *string + max_age *int + addmax_age *int + auth_time *time.Time clearedFields map[string]struct{} done bool oldValue func(context.Context) (*AuthRequest, error) @@ -2229,6 +2306,147 @@ func (m *AuthRequestMutation) ResetMfaValidated() { m.mfa_validated = nil } +// SetPrompt sets the "prompt" field. +func (m *AuthRequestMutation) SetPrompt(s string) { + m.prompt = &s +} + +// Prompt returns the value of the "prompt" field in the mutation. +func (m *AuthRequestMutation) Prompt() (r string, exists bool) { + v := m.prompt + if v == nil { + return + } + return *v, true +} + +// OldPrompt returns the old "prompt" field's value of the AuthRequest entity. +// If the AuthRequest object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthRequestMutation) OldPrompt(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPrompt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPrompt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPrompt: %w", err) + } + return oldValue.Prompt, nil +} + +// ResetPrompt resets all changes to the "prompt" field. +func (m *AuthRequestMutation) ResetPrompt() { + m.prompt = nil +} + +// SetMaxAge sets the "max_age" field. +func (m *AuthRequestMutation) SetMaxAge(i int) { + m.max_age = &i + m.addmax_age = nil +} + +// MaxAge returns the value of the "max_age" field in the mutation. +func (m *AuthRequestMutation) MaxAge() (r int, exists bool) { + v := m.max_age + if v == nil { + return + } + return *v, true +} + +// OldMaxAge returns the old "max_age" field's value of the AuthRequest entity. +// If the AuthRequest object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthRequestMutation) OldMaxAge(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMaxAge is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMaxAge requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMaxAge: %w", err) + } + return oldValue.MaxAge, nil +} + +// AddMaxAge adds i to the "max_age" field. +func (m *AuthRequestMutation) AddMaxAge(i int) { + if m.addmax_age != nil { + *m.addmax_age += i + } else { + m.addmax_age = &i + } +} + +// AddedMaxAge returns the value that was added to the "max_age" field in this mutation. +func (m *AuthRequestMutation) AddedMaxAge() (r int, exists bool) { + v := m.addmax_age + if v == nil { + return + } + return *v, true +} + +// ResetMaxAge resets all changes to the "max_age" field. +func (m *AuthRequestMutation) ResetMaxAge() { + m.max_age = nil + m.addmax_age = nil +} + +// SetAuthTime sets the "auth_time" field. +func (m *AuthRequestMutation) SetAuthTime(t time.Time) { + m.auth_time = &t +} + +// AuthTime returns the value of the "auth_time" field in the mutation. +func (m *AuthRequestMutation) AuthTime() (r time.Time, exists bool) { + v := m.auth_time + if v == nil { + return + } + return *v, true +} + +// OldAuthTime returns the old "auth_time" field's value of the AuthRequest entity. +// If the AuthRequest object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthRequestMutation) OldAuthTime(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAuthTime is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAuthTime requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAuthTime: %w", err) + } + return oldValue.AuthTime, nil +} + +// ClearAuthTime clears the value of the "auth_time" field. +func (m *AuthRequestMutation) ClearAuthTime() { + m.auth_time = nil + m.clearedFields[authrequest.FieldAuthTime] = struct{}{} +} + +// AuthTimeCleared returns if the "auth_time" field was cleared in this mutation. +func (m *AuthRequestMutation) AuthTimeCleared() bool { + _, ok := m.clearedFields[authrequest.FieldAuthTime] + return ok +} + +// ResetAuthTime resets all changes to the "auth_time" field. +func (m *AuthRequestMutation) ResetAuthTime() { + m.auth_time = nil + delete(m.clearedFields, authrequest.FieldAuthTime) +} + // Where appends a list predicates to the AuthRequestMutation builder. func (m *AuthRequestMutation) Where(ps ...predicate.AuthRequest) { m.predicates = append(m.predicates, ps...) @@ -2263,7 +2481,7 @@ func (m *AuthRequestMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AuthRequestMutation) Fields() []string { - fields := make([]string, 0, 21) + fields := make([]string, 0, 24) if m.client_id != nil { fields = append(fields, authrequest.FieldClientID) } @@ -2327,6 +2545,15 @@ func (m *AuthRequestMutation) Fields() []string { if m.mfa_validated != nil { fields = append(fields, authrequest.FieldMfaValidated) } + if m.prompt != nil { + fields = append(fields, authrequest.FieldPrompt) + } + if m.max_age != nil { + fields = append(fields, authrequest.FieldMaxAge) + } + if m.auth_time != nil { + fields = append(fields, authrequest.FieldAuthTime) + } return fields } @@ -2377,6 +2604,12 @@ func (m *AuthRequestMutation) Field(name string) (ent.Value, bool) { return m.HmacKey() case authrequest.FieldMfaValidated: return m.MfaValidated() + case authrequest.FieldPrompt: + return m.Prompt() + case authrequest.FieldMaxAge: + return m.MaxAge() + case authrequest.FieldAuthTime: + return m.AuthTime() } return nil, false } @@ -2428,6 +2661,12 @@ func (m *AuthRequestMutation) OldField(ctx context.Context, name string) (ent.Va return m.OldHmacKey(ctx) case authrequest.FieldMfaValidated: return m.OldMfaValidated(ctx) + case authrequest.FieldPrompt: + return m.OldPrompt(ctx) + case authrequest.FieldMaxAge: + return m.OldMaxAge(ctx) + case authrequest.FieldAuthTime: + return m.OldAuthTime(ctx) } return nil, fmt.Errorf("unknown AuthRequest field %s", name) } @@ -2584,6 +2823,27 @@ func (m *AuthRequestMutation) SetField(name string, value ent.Value) error { } m.SetMfaValidated(v) return nil + case authrequest.FieldPrompt: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPrompt(v) + return nil + case authrequest.FieldMaxAge: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMaxAge(v) + return nil + case authrequest.FieldAuthTime: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAuthTime(v) + return nil } return fmt.Errorf("unknown AuthRequest field %s", name) } @@ -2591,13 +2851,21 @@ func (m *AuthRequestMutation) SetField(name string, value ent.Value) error { // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. func (m *AuthRequestMutation) AddedFields() []string { - return nil + var fields []string + if m.addmax_age != nil { + fields = append(fields, authrequest.FieldMaxAge) + } + return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. func (m *AuthRequestMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case authrequest.FieldMaxAge: + return m.AddedMaxAge() + } return nil, false } @@ -2606,6 +2874,13 @@ func (m *AuthRequestMutation) AddedField(name string) (ent.Value, bool) { // type. func (m *AuthRequestMutation) AddField(name string, value ent.Value) error { switch name { + case authrequest.FieldMaxAge: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddMaxAge(v) + return nil } return fmt.Errorf("unknown AuthRequest numeric field %s", name) } @@ -2626,6 +2901,9 @@ func (m *AuthRequestMutation) ClearedFields() []string { if m.FieldCleared(authrequest.FieldConnectorData) { fields = append(fields, authrequest.FieldConnectorData) } + if m.FieldCleared(authrequest.FieldAuthTime) { + fields = append(fields, authrequest.FieldAuthTime) + } return fields } @@ -2652,6 +2930,9 @@ func (m *AuthRequestMutation) ClearField(name string) error { case authrequest.FieldConnectorData: m.ClearConnectorData() return nil + case authrequest.FieldAuthTime: + m.ClearAuthTime() + return nil } return fmt.Errorf("unknown AuthRequest nullable field %s", name) } @@ -2723,6 +3004,15 @@ func (m *AuthRequestMutation) ResetField(name string) error { case authrequest.FieldMfaValidated: m.ResetMfaValidated() return nil + case authrequest.FieldPrompt: + m.ResetPrompt() + return nil + case authrequest.FieldMaxAge: + m.ResetMaxAge() + return nil + case authrequest.FieldAuthTime: + m.ResetAuthTime() + return nil } return fmt.Errorf("unknown AuthRequest field %s", name) } diff --git a/storage/ent/db/runtime.go b/storage/ent/db/runtime.go index df011f55..2c1c5404 100644 --- a/storage/ent/db/runtime.go +++ b/storage/ent/db/runtime.go @@ -88,6 +88,14 @@ func init() { authrequestDescMfaValidated := authrequestFields[21].Descriptor() // authrequest.DefaultMfaValidated holds the default value on creation for the mfa_validated field. authrequest.DefaultMfaValidated = authrequestDescMfaValidated.Default.(bool) + // authrequestDescPrompt is the schema descriptor for prompt field. + authrequestDescPrompt := authrequestFields[22].Descriptor() + // authrequest.DefaultPrompt holds the default value on creation for the prompt field. + authrequest.DefaultPrompt = authrequestDescPrompt.Default.(string) + // authrequestDescMaxAge is the schema descriptor for max_age field. + authrequestDescMaxAge := authrequestFields[23].Descriptor() + // authrequest.DefaultMaxAge holds the default value on creation for the max_age field. + authrequest.DefaultMaxAge = authrequestDescMaxAge.Default.(int) // authrequestDescID is the schema descriptor for id field. authrequestDescID := authrequestFields[0].Descriptor() // authrequest.IDValidator is a validator for the "id" field. It is called by the builders before save. diff --git a/storage/ent/schema/authcode.go b/storage/ent/schema/authcode.go index 1574347b..b2669d08 100644 --- a/storage/ent/schema/authcode.go +++ b/storage/ent/schema/authcode.go @@ -81,6 +81,9 @@ func (AuthCode) Fields() []ent.Field { field.Text("code_challenge_method"). SchemaType(textSchema). Default(""), + field.Time("auth_time"). + SchemaType(timeSchema). + Optional(), } } diff --git a/storage/ent/schema/authrequest.go b/storage/ent/schema/authrequest.go index 905c73ab..be24c29f 100644 --- a/storage/ent/schema/authrequest.go +++ b/storage/ent/schema/authrequest.go @@ -90,6 +90,9 @@ func (AuthRequest) Fields() []ent.Field { field.Bytes("hmac_key"), field.Bool("mfa_validated"). Default(false), + field.Text("prompt").SchemaType(textSchema).Default(""), + field.Int("max_age").Default(-1), + field.Time("auth_time").SchemaType(timeSchema).Optional(), } } diff --git a/storage/etcd/types.go b/storage/etcd/types.go index b76727e2..0a73e9ea 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -24,6 +24,8 @@ type AuthCode struct { CodeChallenge string `json:"code_challenge,omitempty"` CodeChallengeMethod string `json:"code_challenge_method,omitempty"` + + AuthTime time.Time `json:"auth_time"` } func toStorageAuthCode(a AuthCode) storage.AuthCode { @@ -41,6 +43,7 @@ func toStorageAuthCode(a AuthCode) storage.AuthCode { CodeChallenge: a.CodeChallenge, CodeChallengeMethod: a.CodeChallengeMethod, }, + AuthTime: a.AuthTime, } } @@ -57,6 +60,7 @@ func fromStorageAuthCode(a storage.AuthCode) AuthCode { Expiry: a.Expiry, CodeChallenge: a.PKCE.CodeChallenge, CodeChallengeMethod: a.PKCE.CodeChallengeMethod, + AuthTime: a.AuthTime, } } @@ -88,6 +92,10 @@ type AuthRequest struct { HMACKey []byte `json:"hmac_key"` MFAValidated bool `json:"mfa_validated"` + + Prompt string `json:"prompt,omitempty"` + MaxAge int `json:"max_age"` + AuthTime time.Time `json:"auth_time"` } func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { @@ -109,6 +117,9 @@ func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { CodeChallengeMethod: a.PKCE.CodeChallengeMethod, HMACKey: a.HMACKey, MFAValidated: a.MFAValidated, + Prompt: a.Prompt, + MaxAge: a.MaxAge, + AuthTime: a.AuthTime, } } @@ -133,6 +144,9 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest { }, HMACKey: a.HMACKey, MFAValidated: a.MFAValidated, + Prompt: a.Prompt, + MaxAge: a.MaxAge, + AuthTime: a.AuthTime, } } diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index c0245dad..344ae75e 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -402,6 +402,10 @@ type AuthRequest struct { HMACKey []byte `json:"hmac_key"` MFAValidated bool `json:"mfa_validated"` + + Prompt string `json:"prompt,omitempty"` + MaxAge int `json:"maxAge"` + AuthTime time.Time `json:"authTime,omitempty"` } // AuthRequestList is a list of AuthRequests. @@ -432,6 +436,9 @@ func toStorageAuthRequest(req AuthRequest) storage.AuthRequest { }, HMACKey: req.HMACKey, MFAValidated: req.MFAValidated, + Prompt: req.Prompt, + MaxAge: req.MaxAge, + AuthTime: req.AuthTime, } return a } @@ -462,6 +469,9 @@ func (cli *client) fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { CodeChallengeMethod: a.PKCE.CodeChallengeMethod, HMACKey: a.HMACKey, MFAValidated: a.MFAValidated, + Prompt: a.Prompt, + MaxAge: a.MaxAge, + AuthTime: a.AuthTime, } return req } @@ -550,6 +560,8 @@ type AuthCode struct { CodeChallenge string `json:"code_challenge,omitempty"` CodeChallengeMethod string `json:"code_challenge_method,omitempty"` + + AuthTime time.Time `json:"authTime,omitempty"` } // AuthCodeList is a list of AuthCodes. @@ -579,6 +591,7 @@ func (cli *client) fromStorageAuthCode(a storage.AuthCode) AuthCode { Expiry: a.Expiry, CodeChallenge: a.PKCE.CodeChallenge, CodeChallengeMethod: a.PKCE.CodeChallengeMethod, + AuthTime: a.AuthTime, } } @@ -597,6 +610,7 @@ func toStorageAuthCode(a AuthCode) storage.AuthCode { CodeChallenge: a.CodeChallenge, CodeChallengeMethod: a.CodeChallengeMethod, }, + AuthTime: a.AuthTime, } } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index d77ce4e0..82b9bc9b 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -135,10 +135,11 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err expiry, code_challenge, code_challenge_method, hmac_key, - mfa_validated + mfa_validated, + prompt, max_age, auth_time ) values ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22 + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25 ); `, a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, @@ -150,6 +151,7 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, a.HMACKey, a.MFAValidated, + a.Prompt, a.MaxAge, a.AuthTime, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -183,8 +185,9 @@ func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a expiry = $17, code_challenge = $18, code_challenge_method = $19, hmac_key = $20, - mfa_validated = $21 - where id = $22; + mfa_validated = $21, + prompt = $22, max_age = $23, auth_time = $24 + where id = $25; `, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ForceApprovalPrompt, a.LoggedIn, @@ -195,6 +198,7 @@ func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a a.Expiry, a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, a.HMACKey, a.MFAValidated, + a.Prompt, a.MaxAge, a.AuthTime, r.ID, ) if err != nil { @@ -217,7 +221,8 @@ func getAuthRequest(ctx context.Context, q querier, id string) (a storage.AuthRe claims_email, claims_email_verified, claims_groups, connector_id, connector_data, expiry, code_challenge, code_challenge_method, hmac_key, - mfa_validated + mfa_validated, + prompt, max_age, auth_time from auth_request where id = $1; `, id).Scan( &a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, @@ -228,6 +233,7 @@ func getAuthRequest(ctx context.Context, q querier, id string) (a storage.AuthRe &a.ConnectorID, &a.ConnectorData, &a.Expiry, &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, &a.HMACKey, &a.MFAValidated, + &a.Prompt, &a.MaxAge, &a.AuthTime, ) if err != nil { if err == sql.ErrNoRows { @@ -246,14 +252,16 @@ func (c *conn) CreateAuthCode(ctx context.Context, a storage.AuthCode) error { claims_email, claims_email_verified, claims_groups, connector_id, connector_data, expiry, - code_challenge, code_challenge_method + code_challenge, code_challenge_method, + auth_time ) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16); + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17); `, a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI, a.Claims.UserID, a.Claims.Username, a.Claims.PreferredUsername, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, a.Expiry, a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, + a.AuthTime, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -272,13 +280,15 @@ func (c *conn) GetAuthCode(ctx context.Context, id string) (a storage.AuthCode, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, expiry, - code_challenge, code_challenge_method + code_challenge, code_challenge_method, + auth_time from auth_code where id = $1; `, id).Scan( &a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI, &a.Claims.UserID, &a.Claims.Username, &a.Claims.PreferredUsername, &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups), &a.ConnectorID, &a.ConnectorData, &a.Expiry, &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, + &a.AuthTime, ) if err != nil { if err == sql.ErrNoRows { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 0e44e52b..9131d284 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -436,6 +436,10 @@ var migrations = []migration{ ` alter table client add column mfa_chain bytea;`, + `alter table auth_request add column prompt text not null default '';`, + `alter table auth_request add column max_age integer not null default -1;`, + `alter table auth_request add column auth_time timestamptz not null default '1970-01-01 00:00:00';`, + `alter table auth_code add column auth_time timestamptz not null default '1970-01-01 00:00:00';`, }, }, } diff --git a/storage/storage.go b/storage/storage.go index 57c4823b..8161eb83 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -230,6 +230,20 @@ type AuthRequest struct { // attempts. ForceApprovalPrompt bool + // OIDC prompt parameter. Controls authentication and consent UI behavior. + // Values: "none", "login", "consent", "select_account". + Prompt string + + // MaxAge is the OIDC max_age parameter — maximum allowable elapsed time + // in seconds since the user last actively authenticated. + // -1 means not specified. + MaxAge int + + // AuthTime is when the user last actively authenticated (entered credentials). + // Set during finalizeLogin (= now) or trySessionLogin (= UserIdentity.LastLogin). + // Used in ID token as "auth_time" claim. + AuthTime time.Time + Expiry time.Time // Has the user proved their identity through a backing identity provider? @@ -290,6 +304,10 @@ type AuthCode struct { Expiry time.Time + // AuthTime is when the user last actively authenticated. + // Carried over from AuthRequest to include in ID tokens. + AuthTime time.Time + // PKCE CodeChallenge and CodeChallengeMethod PKCE PKCE }