diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 8efc3c38..018e717c 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -65,6 +65,19 @@ type Config struct { // querying the storage. Cannot be specified without enabling a passwords // database. StaticPasswords []password `json:"staticPasswords"` + + // MFA holds multi-factor authentication configuration. + MFA MFAConfig `json:"mfa"` +} + +// MFAConfig holds multi-factor authentication settings. +type MFAConfig struct { + // Authenticators defines MFA providers available for clients to reference. + Authenticators []MFAAuthenticator `json:"authenticators"` + + // DefaultMFAChain is the default ordered list of authenticator IDs applied + // to clients that don't specify their own mfaChain. Empty means no MFA by default. + DefaultMFAChain []string `json:"defaultMFAChain"` } // Validate the configuration @@ -103,6 +116,55 @@ func (c Config) Validate() error { if len(checkErrors) != 0 { return fmt.Errorf("invalid Config:\n\t-\t%s", strings.Join(checkErrors, "\n\t-\t")) } + + if err := c.validateMFA(); err != nil { + return err + } + + return nil +} + +func (c Config) validateMFA() error { + mfa := c.MFA + if len(mfa.Authenticators) == 0 && len(mfa.DefaultMFAChain) == 0 { + return nil + } + + if !featureflags.SessionsEnabled.Enabled() { + return fmt.Errorf("mfa requires sessions to be enabled (DEX_SESSIONS_ENABLED=true)") + } + + knownTypes := map[string]bool{"TOTP": true} + ids := make(map[string]bool, len(mfa.Authenticators)) + + for _, auth := range mfa.Authenticators { + if auth.ID == "" { + return fmt.Errorf("mfa.authenticators: authenticator must have an id") + } + if ids[auth.ID] { + return fmt.Errorf("mfa.authenticators: duplicate authenticator id %q", auth.ID) + } + ids[auth.ID] = true + + if !knownTypes[auth.Type] { + return fmt.Errorf("mfa.authenticators: unknown type %q for authenticator %q", auth.Type, auth.ID) + } + } + + for _, authID := range mfa.DefaultMFAChain { + if !ids[authID] { + return fmt.Errorf("mfa.defaultMFAChain: references unknown authenticator %q", authID) + } + } + + for _, client := range c.StaticClients { + for _, authID := range client.MFAChain { + if !ids[authID] { + return fmt.Errorf("staticClients: client %q references unknown MFA authenticator %q", client.ID, authID) + } + } + } + return nil } @@ -585,3 +647,20 @@ type RefreshToken struct { AbsoluteLifetime string `json:"absoluteLifetime"` ValidIfNotUsedFor string `json:"validIfNotUsedFor"` } + +// MFAAuthenticator defines a multi-factor authentication provider. +type MFAAuthenticator struct { + ID string `json:"id"` + Type string `json:"type"` + Config json.RawMessage `json:"config"` + + // ConnectorTypes limits this authenticator to specific connector types (e.g., "ldap", "oidc", "saml"). + // If empty, the authenticator applies to all connector types. + ConnectorTypes []string `json:"connectorTypes"` +} + +// TOTPConfig holds configuration for a TOTP authenticator. +type TOTPConfig struct { + // Issuer is the name of the service shown in the authenticator app. + Issuer string `json:"issuer"` +} diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 5cc0877a..3e650e97 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/json" "errors" "fmt" "log/slog" @@ -384,6 +385,8 @@ func runServe(options serveOptions) error { ContinueOnConnectorFailure: featureflags.ContinueOnConnectorFailure.Enabled(), Signer: signerInstance, IDTokensValidFor: idTokensValidFor, + MFAProviders: buildMFAProviders(c.MFA.Authenticators, logger), + DefaultMFAChain: c.MFA.DefaultMFAChain, } if c.Expiry.AuthRequests != "" { @@ -759,3 +762,24 @@ func loadTLSConfig(certFile, keyFile, caFile string, baseConfig *tls.Config) (*t func recordBuildInfo() { buildInfo.WithLabelValues(version, runtime.Version(), fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)).Set(1) } + +func buildMFAProviders(authenticators []MFAAuthenticator, logger *slog.Logger) map[string]server.MFAProvider { + if len(authenticators) == 0 { + return nil + } + + providers := make(map[string]server.MFAProvider, len(authenticators)) + for _, auth := range authenticators { + switch auth.Type { + case "TOTP": + var cfg TOTPConfig + if err := json.Unmarshal(auth.Config, &cfg); err != nil { + logger.Error("failed to parse TOTP config", "id", auth.ID, "err", err) + continue + } + providers[auth.ID] = server.NewTOTPProvider(cfg.Issuer, auth.ConnectorTypes) + logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type) + } + } + return providers +} diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 0e8bb575..cf36bdef 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -130,6 +130,25 @@ telemetry: # # Supported code challenge methods. Defaults to ["S256", "plain"]. # codeChallengeMethodsSupported: ["S256", "plain"] +# Multi-factor authentication configuration. +# Requires DEX_SESSIONS_ENABLED=true feature flag. +# mfa: +# authenticators: +# - id: totp-1 +# type: TOTP +# config: +# issuer: "dex" +# # Optional: limit this authenticator to specific connector types (e.g., ldap, oidc, saml). +# # If omitted or empty, applies to all connector types. +# # It is recommended to use this option to prevent MFA from being used for connectors +# # with their own MFA mechanisms, e.g., OIDC, Google, etc. (but technically, it is possible). +# connectorTypes: +# - mockCallback +# # Default MFA chain applied to clients that don't specify their own mfaChain. +# # If omitted or empty, no MFA is required by default. +# defaultMFAChain: +# - totp-1 + # Instead of reading from an external storage, use this list of clients. # # If this option isn't chosen clients may be added through the gRPC API. @@ -144,6 +163,11 @@ staticClients: # If omitted or empty, all connectors are allowed. # allowedConnectors: # - mock + # Optional: ordered list of MFA authenticator IDs the user must complete during login. + # References authenticator IDs from mfa.authenticators. + # If omitted, mfa.defaultMFAChain is used. + # mfaChain: + # - totp-1 # Example using environment variables # Set DEX_CLIENT_ID and DEX_SECURE_CLIENT_SECRET before starting Dex diff --git a/go.mod b/go.mod index a9f5c352..16509877 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/oklog/run v1.2.0 github.com/openbao/openbao/api/v2 v2.5.1 github.com/pkg/errors v0.9.1 + github.com/pquerna/otp v1.5.0 github.com/prometheus/client_golang v1.23.2 github.com/russellhaering/goxmldsig v1.5.0 github.com/spf13/cobra v1.10.2 @@ -58,6 +59,7 @@ require ( github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/coreos/go-semver v0.3.1 // indirect diff --git a/go.sum b/go.sum index b8f97985..5e0d8abc 100644 --- a/go.sum +++ b/go.sum @@ -42,6 +42,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0= github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -201,6 +203,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= diff --git a/server/handlers.go b/server/handlers.go index 20fd85bf..9eada05c 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -760,6 +760,36 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, userIdentity = &ui } + // an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original + // flow would be unable to poll for the result at the /approval endpoint + h := hmac.New(sha256.New, authReq.HMACKey) + h.Write([]byte(authReq.ID)) + mac := h.Sum(nil) + hmacParam := base64.RawURLEncoding.EncodeToString(mac) + + // Check if the client requires MFA. + mfaChain, err := s.mfaChainForClient(ctx, authReq.ClientID, authReq.ConnectorID) + if err != nil { + return "", false, fmt.Errorf("failed to get MFA chain for client: %v", err) + } + if len(mfaChain) > 0 { + // Redirect to MFA verification for the first authenticator in the chain. + v := url.Values{} + v.Set("req", authReq.ID) + v.Set("hmac", hmacParam) + v.Set("authenticator", mfaChain[0]) + returnURL := path.Join(s.issuerURL.Path, "/totp/verify") + "?" + v.Encode() + return returnURL, false, nil + } + + // No MFA required — mark as validated. + if err := s.storage.UpdateAuthRequest(ctx, authReq.ID, func(a storage.AuthRequest) (storage.AuthRequest, error) { + a.MFAValidated = true + return a, nil + }); err != nil { + return "", false, fmt.Errorf("failed to update auth request MFA status: %v", err) + } + // we can skip the redirect to /approval and go ahead and send code if it's not required if s.skipApproval && !authReq.ForceApprovalPrompt { return "", true, nil @@ -772,13 +802,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, } } - // an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original - // flow would be unable to poll for the result at the /approval endpoint - h := hmac.New(sha256.New, authReq.HMACKey) - h.Write([]byte(authReq.ID)) - mac := h.Sum(nil) - - returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID + "&hmac=" + base64.RawURLEncoding.EncodeToString(mac) + returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID + "&hmac=" + hmacParam return returnURL, false, nil } @@ -810,6 +834,28 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { s.renderError(r, w, http.StatusInternalServerError, "Login process not yet finalized.") return } + if !authReq.MFAValidated { + // Check if MFA is actually required — if so, redirect to TOTP instead of blocking. + // This handles the case where MFA was enabled after the auth flow started. + mfaChain, err := s.mfaChainForClient(ctx, authReq.ClientID, authReq.ConnectorID) + if err != nil { + s.logger.ErrorContext(ctx, "failed to get MFA chain", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") + return + } + if len(mfaChain) > 0 { + h := hmac.New(sha256.New, authReq.HMACKey) + h.Write([]byte(authReq.ID)) + v := url.Values{} + v.Set("req", authReq.ID) + v.Set("hmac", base64.RawURLEncoding.EncodeToString(h.Sum(nil))) + v.Set("authenticator", mfaChain[0]) + totpURL := path.Join(s.issuerURL.Path, "/totp/verify") + "?" + v.Encode() + http.Redirect(w, r, totpURL, http.StatusSeeOther) + return + } + // No MFA required but flag not set — allow through (backward compat). + } // build expected hmac with secret key h := hmac.New(sha256.New, authReq.HMACKey) diff --git a/server/handlers_approval_test.go b/server/handlers_approval_test.go index 5ab80fc5..24a72b83 100644 --- a/server/handlers_approval_test.go +++ b/server/handlers_approval_test.go @@ -84,6 +84,7 @@ func TestHandleApprovalDoubleSubmitPOST(t *testing.T) { RedirectURI: "https://client.example/callback", Expiry: time.Now().Add(time.Minute), LoggedIn: true, + MFAValidated: true, HMACKey: []byte("approval-double-submit-key"), } require.NoError(t, server.storage.CreateAuthRequest(ctx, authReq)) diff --git a/server/mfa.go b/server/mfa.go new file mode 100644 index 00000000..53d2aa67 --- /dev/null +++ b/server/mfa.go @@ -0,0 +1,295 @@ +package server + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "image/png" + "net/http" + "strings" + + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + + "github.com/dexidp/dex/storage" +) + +// MFAProvider is a pluggable multi-factor authentication method. +type MFAProvider interface { + // Type returns the authenticator type identifier (e.g., "TOTP"). + Type() string + // EnabledForConnectorType returns true if this provider applies to the given connector type. + // If no connector types are configured, the provider applies to all. + EnabledForConnectorType(connectorType string) bool +} + +// TOTPProvider implements TOTP-based multi-factor authentication. +type TOTPProvider struct { + issuer string + connectorTypes map[string]struct{} +} + +// NewTOTPProvider creates a new TOTP MFA provider. +func NewTOTPProvider(issuer string, connectorTypes []string) *TOTPProvider { + m := make(map[string]struct{}, len(connectorTypes)) + for _, t := range connectorTypes { + m[t] = struct{}{} + } + return &TOTPProvider{issuer: issuer, connectorTypes: m} +} + +func (p *TOTPProvider) EnabledForConnectorType(connectorType string) bool { + if len(p.connectorTypes) == 0 { + return true + } + _, ok := p.connectorTypes[connectorType] + return ok +} + +func (p *TOTPProvider) Type() string { return "TOTP" } + +func (p *TOTPProvider) generate(connID, email string) (*otp.Key, error) { + return totp.Generate(totp.GenerateOpts{ + Issuer: p.issuer, + AccountName: fmt.Sprintf("(%s) %s", connID, email), + }) +} + +func (s *Server) handleMFAVerify(w http.ResponseWriter, r *http.Request) { + macEncoded := r.FormValue("hmac") + if macEncoded == "" { + s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request.") + return + } + mac, err := base64.RawURLEncoding.DecodeString(macEncoded) + if err != nil { + s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request.") + return + } + + ctx := r.Context() + + authReq, err := s.storage.GetAuthRequest(ctx, r.FormValue("req")) + if err != nil { + s.logger.ErrorContext(ctx, "failed to get auth request", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Database error.") + return + } + if !authReq.LoggedIn { + s.logger.ErrorContext(ctx, "auth request does not have an identity for MFA verification") + s.renderError(r, w, http.StatusInternalServerError, "Login process not yet finalized.") + return + } + + // Verify HMAC + h := hmac.New(sha256.New, authReq.HMACKey) + h.Write([]byte(authReq.ID)) + if !hmac.Equal(mac, h.Sum(nil)) { + s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request.") + return + } + + authenticatorID := r.FormValue("authenticator") + provider, ok := s.mfaProviders[authenticatorID] + if !ok { + s.renderError(r, w, http.StatusBadRequest, "Unknown authenticator.") + return + } + + totpProvider, ok := provider.(*TOTPProvider) + if !ok { + s.renderError(r, w, http.StatusInternalServerError, "Unsupported authenticator type.") + return + } + + identity, err := s.storage.GetUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID) + if err != nil { + s.logger.ErrorContext(ctx, "failed to get user identity", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Database error.") + return + } + + returnURL := strings.Replace(r.URL.String(), "/totp/verify", "/approval", 1) + + if authReq.MFAValidated { + http.Redirect(w, r, returnURL, http.StatusSeeOther) + return + } + + secret := identity.MFASecrets[authenticatorID] + + switch r.Method { + case http.MethodGet: + if secret == nil { + // First-time enrollment: generate a new TOTP key. + generated, err := totpProvider.generate(authReq.ConnectorID, authReq.Claims.Email) + if err != nil { + s.logger.ErrorContext(ctx, "failed to generate TOTP key", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") + return + } + + secret = &storage.MFASecret{ + AuthenticatorID: authenticatorID, + Type: "TOTP", + Secret: generated.String(), + Confirmed: false, + CreatedAt: s.now(), + } + + if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { + if old.MFASecrets == nil { + old.MFASecrets = make(map[string]*storage.MFASecret) + } + old.MFASecrets[authenticatorID] = secret + return old, nil + }); err != nil { + s.logger.ErrorContext(ctx, "failed to store MFA secret", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") + return + } + } + + s.renderTOTPPage(secret, false, totpProvider.issuer, authReq.ConnectorID, w, r) + + case http.MethodPost: + if secret == nil || secret.Secret == "" { + s.renderError(r, w, http.StatusBadRequest, "MFA not enrolled.") + return + } + + code := r.FormValue("totp") + generated, err := otp.NewKeyFromURL(secret.Secret) + if err != nil { + s.logger.ErrorContext(ctx, "failed to load TOTP key", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") + return + } + + if !totp.Validate(code, generated.Secret()) { + s.renderTOTPPage(secret, true, totpProvider.issuer, authReq.ConnectorID, w, r) + return + } + + // Mark MFA secret as confirmed. + if !secret.Confirmed { + if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { + if s := old.MFASecrets[authenticatorID]; s != nil { + s.Confirmed = true + } + return old, nil + }); err != nil { + s.logger.ErrorContext(ctx, "failed to confirm MFA secret", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") + return + } + } + + // Mark auth request as MFA-validated. + if err := s.storage.UpdateAuthRequest(ctx, authReq.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) { + old.MFAValidated = true + return old, nil + }); err != nil { + s.logger.ErrorContext(ctx, "failed to update auth request", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") + return + } + + // Skip approval if configured. + if s.skipApproval && !authReq.ForceApprovalPrompt { + authReq, err = s.storage.GetAuthRequest(ctx, authReq.ID) + if err != nil { + s.logger.ErrorContext(ctx, "failed to get finalized auth request", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Login error.") + return + } + s.sendCodeResponse(w, r, authReq) + return + } + + http.Redirect(w, r, returnURL, http.StatusSeeOther) + + default: + s.renderError(r, w, http.StatusBadRequest, "Unsupported request method.") + } +} + +func (s *Server) renderTOTPPage(secret *storage.MFASecret, lastFail bool, issuer, connectorID string, w http.ResponseWriter, r *http.Request) { + var qrCode string + if !secret.Confirmed { + var err error + qrCode, err = generateTOTPQRCode(secret.Secret) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to generate QR code", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") + return + } + } + if err := s.templates.totpVerify(r, w, r.URL.String(), issuer, connectorID, qrCode, lastFail); err != nil { + s.logger.ErrorContext(r.Context(), "server template error", "err", err) + } +} + +func generateTOTPQRCode(keyURL string) (string, error) { + generated, err := otp.NewKeyFromURL(keyURL) + if err != nil { + return "", fmt.Errorf("failed to load TOTP key: %w", err) + } + + qrCodeImage, err := generated.Image(300, 300) + if err != nil { + return "", fmt.Errorf("failed to generate TOTP QR code: %w", err) + } + + var buf bytes.Buffer + if err := png.Encode(&buf, qrCodeImage); err != nil { + return "", fmt.Errorf("failed to encode TOTP QR code: %w", err) + } + + return base64.StdEncoding.EncodeToString(buf.Bytes()), nil +} + +// mfaChainForClient returns the MFA chain for a client filtered by connector type, +// falling back to the server's defaultMFAChain if the client has none. +// Returns nil if no MFA is configured/applicable. +func (s *Server) mfaChainForClient(ctx context.Context, clientID, connectorID string) ([]string, error) { + if len(s.mfaProviders) == 0 { + return nil, nil + } + + client, err := s.storage.GetClient(ctx, clientID) + if err != nil { + return nil, err + } + + // nil means "not set" — fall back to default. + // Explicit empty slice ([]string{}) means "no MFA" — don't fall back. + source := client.MFAChain + if source == nil { + source = s.defaultMFAChain + } + + // Resolve connector type from connector ID. + connectorType := s.getConnectorType(connectorID) + + var chain []string + for _, authID := range source { + provider, ok := s.mfaProviders[authID] + if ok && provider.EnabledForConnectorType(connectorType) { + chain = append(chain, authID) + } + } + return chain, nil +} + +// getConnectorType returns the type of the connector with the given ID. +func (s *Server) getConnectorType(connectorID string) string { + conn, err := s.storage.GetConnector(context.Background(), connectorID) + if err != nil { + return "" + } + return conn.Type +} diff --git a/server/server.go b/server/server.go index e63cb278..83548e41 100644 --- a/server/server.go +++ b/server/server.go @@ -137,6 +137,12 @@ type Config struct { // If enabled, the server will continue starting even if some connectors fail to initialize. // This allows the server to operate with a subset of connectors if some are misconfigured. ContinueOnConnectorFailure bool + + // MFAProviders maps authenticator IDs to their provider implementations. + MFAProviders map[string]MFAProvider + + // DefaultMFAChain is applied to clients that don't specify their own mfaChain. + DefaultMFAChain []string } // WebConfig holds the server's frontend templates and asset configuration. @@ -226,6 +232,9 @@ type Server struct { logger *slog.Logger signer signer.Signer + + mfaProviders map[string]MFAProvider + defaultMFAChain []string } // NewServer constructs a server from the provided config. @@ -349,6 +358,8 @@ func newServer(ctx context.Context, c Config) (*Server, error) { passwordConnector: c.PasswordConnector, logger: c.Logger, signer: c.Signer, + mfaProviders: c.MFAProviders, + defaultMFAChain: c.DefaultMFAChain, } // Retrieves connector objects in backend storage. This list includes the static connectors @@ -537,6 +548,7 @@ func newServer(ctx context.Context, c Config) (*Server, error) { // "authproxy" connector. handleFunc("/callback/{connector}", s.handleConnectorCallback) handleFunc("/approval", s.handleApproval) + handleFunc("/totp/verify", s.handleMFAVerify) handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !c.HealthChecker.IsHealthy() { s.renderError(r, w, http.StatusInternalServerError, "Health check failed.") diff --git a/server/templates.go b/server/templates.go index b77663e1..ca831e5f 100644 --- a/server/templates.go +++ b/server/templates.go @@ -22,6 +22,7 @@ const ( tmplError = "error.html" tmplDevice = "device.html" tmplDeviceSuccess = "device_success.html" + tmplTOTPVerify = "totp_verify.html" ) var requiredTmpls = []string{ @@ -42,6 +43,7 @@ type templates struct { errorTmpl *template.Template deviceTmpl *template.Template deviceSuccessTmpl *template.Template + totpVerifyTmpl *template.Template } type webConfig struct { @@ -169,6 +171,7 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) { errorTmpl: tmpls.Lookup(tmplError), deviceTmpl: tmpls.Lookup(tmplDevice), deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess), + totpVerifyTmpl: tmpls.Lookup(tmplTOTPVerify), }, nil } @@ -325,6 +328,21 @@ func (t *templates) approval(r *http.Request, w http.ResponseWriter, authReqID, return renderTemplate(w, t.approvalTmpl, data) } +func (t *templates) totpVerify(r *http.Request, w http.ResponseWriter, postURL, issuer, connector, qrCode string, lastWasInvalid bool) error { + if lastWasInvalid { + w.WriteHeader(http.StatusUnauthorized) + } + data := struct { + PostURL string + Invalid bool + Issuer string + Connector string + QRCode string + ReqPath string + }{postURL, lastWasInvalid, issuer, connector, qrCode, r.URL.Path} + return renderTemplate(w, t.totpVerifyTmpl, data) +} + func (t *templates) oob(r *http.Request, w http.ResponseWriter, code string) error { data := struct { Code string diff --git a/storage/ent/client/authrequest.go b/storage/ent/client/authrequest.go index 25d3e415..86c71056 100644 --- a/storage/ent/client/authrequest.go +++ b/storage/ent/client/authrequest.go @@ -32,6 +32,7 @@ func (d *Database) CreateAuthRequest(ctx context.Context, authRequest storage.Au SetConnectorID(authRequest.ConnectorID). SetConnectorData(authRequest.ConnectorData). SetHmacKey(authRequest.HMACKey). + SetMfaValidated(authRequest.MFAValidated). Save(ctx) if err != nil { return convertDBError("create auth request: %w", err) @@ -96,6 +97,7 @@ func (d *Database) UpdateAuthRequest(ctx context.Context, id string, updater fun SetConnectorID(newAuthRequest.ConnectorID). SetConnectorData(newAuthRequest.ConnectorData). SetHmacKey(newAuthRequest.HMACKey). + SetMfaValidated(newAuthRequest.MFAValidated). Save(context.TODO()) if err != nil { return rollback(tx, "update auth request uploading: %w", err) diff --git a/storage/ent/client/client.go b/storage/ent/client/client.go index a4f0d942..4c18939d 100644 --- a/storage/ent/client/client.go +++ b/storage/ent/client/client.go @@ -17,6 +17,7 @@ func (d *Database) CreateClient(ctx context.Context, client storage.Client) erro SetRedirectUris(client.RedirectURIs). SetTrustedPeers(client.TrustedPeers). SetAllowedConnectors(client.AllowedConnectors). + SetMfaChain(client.MFAChain). Save(ctx) if err != nil { return convertDBError("create oauth2 client: %w", err) @@ -81,6 +82,7 @@ func (d *Database) UpdateClient(ctx context.Context, id string, updater func(old SetRedirectUris(newClient.RedirectURIs). SetTrustedPeers(newClient.TrustedPeers). SetAllowedConnectors(newClient.AllowedConnectors). + SetMfaChain(newClient.MFAChain). Save(ctx) if err != nil { return rollback(tx, "update client uploading: %w", err) diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index f8e99c4a..2b3fda73 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -45,7 +45,8 @@ func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest { CodeChallenge: a.CodeChallenge, CodeChallengeMethod: a.CodeChallengeMethod, }, - HMACKey: a.HmacKey, + HMACKey: a.HmacKey, + MFAValidated: a.MfaValidated, } } @@ -84,6 +85,7 @@ func toStorageClient(c *db.OAuth2Client) storage.Client { Name: c.Name, LogoURL: c.LogoURL, AllowedConnectors: c.AllowedConnectors, + MFAChain: c.MfaChain, } } @@ -193,6 +195,18 @@ func toStorageUserIdentity(u *db.UserIdentity) storage.UserIdentity { // Server code assumes this will be non-nil. s.Consents = make(map[string][]string) } + + if u.MfaSecrets != nil { + if err := json.Unmarshal(*u.MfaSecrets, &s.MFASecrets); err != nil { + // Correctness of json structure is guaranteed on uploading + panic(err) + } + if s.MFASecrets == nil { + s.MFASecrets = make(map[string]*storage.MFASecret) + } + } else { + s.MFASecrets = make(map[string]*storage.MFASecret) + } return s } diff --git a/storage/ent/client/useridentity.go b/storage/ent/client/useridentity.go index 1cf87919..31e9d049 100644 --- a/storage/ent/client/useridentity.go +++ b/storage/ent/client/useridentity.go @@ -18,6 +18,14 @@ func (d *Database) CreateUserIdentity(ctx context.Context, identity storage.User return fmt.Errorf("encode consents user identity: %w", err) } + if identity.MFASecrets == nil { + identity.MFASecrets = make(map[string]*storage.MFASecret) + } + encodedMFASecrets, err := json.Marshal(identity.MFASecrets) + if err != nil { + return fmt.Errorf("encode mfa secrets user identity: %w", err) + } + id := compositeKeyID(identity.UserID, identity.ConnectorID, d.hasher) _, err = d.client.UserIdentity.Create(). SetID(id). @@ -30,6 +38,7 @@ func (d *Database) CreateUserIdentity(ctx context.Context, identity storage.User SetClaimsEmailVerified(identity.Claims.EmailVerified). SetClaimsGroups(identity.Claims.Groups). SetConsents(encodedConsents). + SetMfaSecrets(encodedMFASecrets). SetCreatedAt(identity.CreatedAt). SetLastLogin(identity.LastLogin). SetBlockedUntil(identity.BlockedUntil). @@ -90,6 +99,15 @@ func (d *Database) UpdateUserIdentity(ctx context.Context, userID string, connec return rollback(tx, "encode consents user identity: %w", err) } + if newUserIdentity.MFASecrets == nil { + newUserIdentity.MFASecrets = make(map[string]*storage.MFASecret) + } + + encodedMFASecrets, err := json.Marshal(newUserIdentity.MFASecrets) + if err != nil { + return rollback(tx, "encode mfa secrets user identity: %w", err) + } + _, err = tx.UserIdentity.UpdateOneID(id). SetUserID(newUserIdentity.UserID). SetConnectorID(newUserIdentity.ConnectorID). @@ -100,6 +118,7 @@ func (d *Database) UpdateUserIdentity(ctx context.Context, userID string, connec SetClaimsEmailVerified(newUserIdentity.Claims.EmailVerified). SetClaimsGroups(newUserIdentity.Claims.Groups). SetConsents(encodedConsents). + SetMfaSecrets(encodedMFASecrets). SetCreatedAt(newUserIdentity.CreatedAt). SetLastLogin(newUserIdentity.LastLogin). SetBlockedUntil(newUserIdentity.BlockedUntil). diff --git a/storage/ent/db/authrequest.go b/storage/ent/db/authrequest.go index ac5b550a..02ead496 100644 --- a/storage/ent/db/authrequest.go +++ b/storage/ent/db/authrequest.go @@ -57,7 +57,9 @@ type AuthRequest struct { // CodeChallengeMethod holds the value of the "code_challenge_method" field. CodeChallengeMethod string `json:"code_challenge_method,omitempty"` // HmacKey holds the value of the "hmac_key" field. - HmacKey []byte `json:"hmac_key,omitempty"` + HmacKey []byte `json:"hmac_key,omitempty"` + // MfaValidated holds the value of the "mfa_validated" field. + MfaValidated bool `json:"mfa_validated,omitempty"` selectValues sql.SelectValues } @@ -68,7 +70,7 @@ func (*AuthRequest) scanValues(columns []string) ([]any, error) { switch columns[i] { case authrequest.FieldScopes, authrequest.FieldResponseTypes, authrequest.FieldClaimsGroups, authrequest.FieldConnectorData, authrequest.FieldHmacKey: values[i] = new([]byte) - case authrequest.FieldForceApprovalPrompt, authrequest.FieldLoggedIn, authrequest.FieldClaimsEmailVerified: + case authrequest.FieldForceApprovalPrompt, authrequest.FieldLoggedIn, authrequest.FieldClaimsEmailVerified, authrequest.FieldMfaValidated: values[i] = new(sql.NullBool) case authrequest.FieldID, authrequest.FieldClientID, authrequest.FieldRedirectURI, authrequest.FieldNonce, authrequest.FieldState, authrequest.FieldClaimsUserID, authrequest.FieldClaimsUsername, authrequest.FieldClaimsEmail, authrequest.FieldClaimsPreferredUsername, authrequest.FieldConnectorID, authrequest.FieldCodeChallenge, authrequest.FieldCodeChallengeMethod: values[i] = new(sql.NullString) @@ -221,6 +223,12 @@ func (_m *AuthRequest) assignValues(columns []string, values []any) error { } else if value != nil { _m.HmacKey = *value } + case authrequest.FieldMfaValidated: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field mfa_validated", values[i]) + } else if value.Valid { + _m.MfaValidated = value.Bool + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -318,6 +326,9 @@ func (_m *AuthRequest) String() string { builder.WriteString(", ") builder.WriteString("hmac_key=") builder.WriteString(fmt.Sprintf("%v", _m.HmacKey)) + builder.WriteString(", ") + builder.WriteString("mfa_validated=") + builder.WriteString(fmt.Sprintf("%v", _m.MfaValidated)) builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/authrequest/authrequest.go b/storage/ent/db/authrequest/authrequest.go index 0998c799..4fda6633 100644 --- a/storage/ent/db/authrequest/authrequest.go +++ b/storage/ent/db/authrequest/authrequest.go @@ -51,6 +51,8 @@ const ( FieldCodeChallengeMethod = "code_challenge_method" // FieldHmacKey holds the string denoting the hmac_key field in the database. FieldHmacKey = "hmac_key" + // FieldMfaValidated holds the string denoting the mfa_validated field in the database. + FieldMfaValidated = "mfa_validated" // Table holds the table name of the authrequest in the database. Table = "auth_requests" ) @@ -78,6 +80,7 @@ var Columns = []string{ FieldCodeChallenge, FieldCodeChallengeMethod, FieldHmacKey, + FieldMfaValidated, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -97,6 +100,8 @@ var ( DefaultCodeChallenge string // DefaultCodeChallengeMethod holds the default value on creation for the "code_challenge_method" field. DefaultCodeChallengeMethod string + // DefaultMfaValidated holds the default value on creation for the "mfa_validated" field. + DefaultMfaValidated bool // IDValidator is a validator for the "id" field. It is called by the builders before save. IDValidator func(string) error ) @@ -183,3 +188,8 @@ func ByCodeChallenge(opts ...sql.OrderTermOption) OrderOption { func ByCodeChallengeMethod(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCodeChallengeMethod, opts...).ToFunc() } + +// ByMfaValidated orders the results by the mfa_validated field. +func ByMfaValidated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMfaValidated, opts...).ToFunc() +} diff --git a/storage/ent/db/authrequest/where.go b/storage/ent/db/authrequest/where.go index 4d3a39be..2f679bb3 100644 --- a/storage/ent/db/authrequest/where.go +++ b/storage/ent/db/authrequest/where.go @@ -149,6 +149,11 @@ func HmacKey(v []byte) predicate.AuthRequest { return predicate.AuthRequest(sql.FieldEQ(FieldHmacKey, v)) } +// MfaValidated applies equality check predicate on the "mfa_validated" field. It's identical to MfaValidatedEQ. +func MfaValidated(v bool) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldEQ(FieldMfaValidated, v)) +} + // ClientIDEQ applies the EQ predicate on the "client_id" field. func ClientIDEQ(v string) predicate.AuthRequest { return predicate.AuthRequest(sql.FieldEQ(FieldClientID, v)) @@ -1054,6 +1059,16 @@ func HmacKeyLTE(v []byte) predicate.AuthRequest { return predicate.AuthRequest(sql.FieldLTE(FieldHmacKey, v)) } +// MfaValidatedEQ applies the EQ predicate on the "mfa_validated" field. +func MfaValidatedEQ(v bool) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldEQ(FieldMfaValidated, v)) +} + +// MfaValidatedNEQ applies the NEQ predicate on the "mfa_validated" field. +func MfaValidatedNEQ(v bool) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldNEQ(FieldMfaValidated, v)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.AuthRequest) predicate.AuthRequest { return predicate.AuthRequest(sql.AndPredicates(predicates...)) diff --git a/storage/ent/db/authrequest_create.go b/storage/ent/db/authrequest_create.go index 6224ef8e..c36648a1 100644 --- a/storage/ent/db/authrequest_create.go +++ b/storage/ent/db/authrequest_create.go @@ -164,6 +164,20 @@ func (_c *AuthRequestCreate) SetHmacKey(v []byte) *AuthRequestCreate { return _c } +// SetMfaValidated sets the "mfa_validated" field. +func (_c *AuthRequestCreate) SetMfaValidated(v bool) *AuthRequestCreate { + _c.mutation.SetMfaValidated(v) + return _c +} + +// SetNillableMfaValidated sets the "mfa_validated" field if the given value is not nil. +func (_c *AuthRequestCreate) SetNillableMfaValidated(v *bool) *AuthRequestCreate { + if v != nil { + _c.SetMfaValidated(*v) + } + return _c +} + // SetID sets the "id" field. func (_c *AuthRequestCreate) SetID(v string) *AuthRequestCreate { _c.mutation.SetID(v) @@ -217,6 +231,10 @@ func (_c *AuthRequestCreate) defaults() { v := authrequest.DefaultCodeChallengeMethod _c.mutation.SetCodeChallengeMethod(v) } + if _, ok := _c.mutation.MfaValidated(); !ok { + v := authrequest.DefaultMfaValidated + _c.mutation.SetMfaValidated(v) + } } // check runs all checks and user-defined validators on the builder. @@ -269,6 +287,9 @@ func (_c *AuthRequestCreate) check() error { if _, ok := _c.mutation.HmacKey(); !ok { return &ValidationError{Name: "hmac_key", err: errors.New(`db: missing required field "AuthRequest.hmac_key"`)} } + if _, ok := _c.mutation.MfaValidated(); !ok { + return &ValidationError{Name: "mfa_validated", err: errors.New(`db: missing required field "AuthRequest.mfa_validated"`)} + } if v, ok := _c.mutation.ID(); ok { if err := authrequest.IDValidator(v); err != nil { return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "AuthRequest.id": %w`, err)} @@ -389,6 +410,10 @@ func (_c *AuthRequestCreate) createSpec() (*AuthRequest, *sqlgraph.CreateSpec) { _spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value) _node.HmacKey = value } + if value, ok := _c.mutation.MfaValidated(); ok { + _spec.SetField(authrequest.FieldMfaValidated, field.TypeBool, value) + _node.MfaValidated = value + } return _node, _spec } diff --git a/storage/ent/db/authrequest_update.go b/storage/ent/db/authrequest_update.go index e1fa678a..e512b2d9 100644 --- a/storage/ent/db/authrequest_update.go +++ b/storage/ent/db/authrequest_update.go @@ -311,6 +311,20 @@ func (_u *AuthRequestUpdate) SetHmacKey(v []byte) *AuthRequestUpdate { return _u } +// SetMfaValidated sets the "mfa_validated" field. +func (_u *AuthRequestUpdate) SetMfaValidated(v bool) *AuthRequestUpdate { + _u.mutation.SetMfaValidated(v) + return _u +} + +// SetNillableMfaValidated sets the "mfa_validated" field if the given value is not nil. +func (_u *AuthRequestUpdate) SetNillableMfaValidated(v *bool) *AuthRequestUpdate { + if v != nil { + _u.SetMfaValidated(*v) + } + return _u +} + // Mutation returns the AuthRequestMutation object of the builder. func (_u *AuthRequestUpdate) Mutation() *AuthRequestMutation { return _u.mutation @@ -439,6 +453,9 @@ func (_u *AuthRequestUpdate) sqlSave(ctx context.Context) (_node int, err error) if value, ok := _u.mutation.HmacKey(); ok { _spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value) } + if value, ok := _u.mutation.MfaValidated(); ok { + _spec.SetField(authrequest.FieldMfaValidated, field.TypeBool, value) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authrequest.Label} @@ -741,6 +758,20 @@ func (_u *AuthRequestUpdateOne) SetHmacKey(v []byte) *AuthRequestUpdateOne { return _u } +// SetMfaValidated sets the "mfa_validated" field. +func (_u *AuthRequestUpdateOne) SetMfaValidated(v bool) *AuthRequestUpdateOne { + _u.mutation.SetMfaValidated(v) + return _u +} + +// SetNillableMfaValidated sets the "mfa_validated" field if the given value is not nil. +func (_u *AuthRequestUpdateOne) SetNillableMfaValidated(v *bool) *AuthRequestUpdateOne { + if v != nil { + _u.SetMfaValidated(*v) + } + return _u +} + // Mutation returns the AuthRequestMutation object of the builder. func (_u *AuthRequestUpdateOne) Mutation() *AuthRequestMutation { return _u.mutation @@ -899,6 +930,9 @@ func (_u *AuthRequestUpdateOne) sqlSave(ctx context.Context) (_node *AuthRequest if value, ok := _u.mutation.HmacKey(); ok { _spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value) } + if value, ok := _u.mutation.MfaValidated(); ok { + _spec.SetField(authrequest.FieldMfaValidated, field.TypeBool, value) + } _node = &AuthRequest{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index 786598c0..b9ee6229 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -56,6 +56,7 @@ var ( {Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "hmac_key", Type: field.TypeBytes}, + {Name: "mfa_validated", Type: field.TypeBool, Default: false}, } // AuthRequestsTable holds the schema information for the "auth_requests" table. AuthRequestsTable = &schema.Table{ @@ -151,6 +152,7 @@ var ( {Name: "name", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "logo_url", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "allowed_connectors", Type: field.TypeJSON, Nullable: true}, + {Name: "mfa_chain", Type: field.TypeJSON, Nullable: true}, } // Oauth2clientsTable holds the schema information for the "oauth2clients" table. Oauth2clientsTable = &schema.Table{ @@ -227,6 +229,7 @@ var ( {Name: "claims_email_verified", Type: field.TypeBool, Default: false}, {Name: "claims_groups", Type: field.TypeJSON, Nullable: true}, {Name: "consents", Type: field.TypeBytes}, + {Name: "mfa_secrets", Type: field.TypeBytes, Nullable: true}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "last_login", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "blocked_until", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index 748022c9..882e6fd5 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -1262,6 +1262,7 @@ type AuthRequestMutation struct { code_challenge *string code_challenge_method *string hmac_key *[]byte + mfa_validated *bool clearedFields map[string]struct{} done bool oldValue func(context.Context) (*AuthRequest, error) @@ -2192,6 +2193,42 @@ func (m *AuthRequestMutation) ResetHmacKey() { m.hmac_key = nil } +// SetMfaValidated sets the "mfa_validated" field. +func (m *AuthRequestMutation) SetMfaValidated(b bool) { + m.mfa_validated = &b +} + +// MfaValidated returns the value of the "mfa_validated" field in the mutation. +func (m *AuthRequestMutation) MfaValidated() (r bool, exists bool) { + v := m.mfa_validated + if v == nil { + return + } + return *v, true +} + +// OldMfaValidated returns the old "mfa_validated" field's value of the AuthRequest entity. +// If the AuthRequest object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthRequestMutation) OldMfaValidated(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMfaValidated is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMfaValidated requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMfaValidated: %w", err) + } + return oldValue.MfaValidated, nil +} + +// ResetMfaValidated resets all changes to the "mfa_validated" field. +func (m *AuthRequestMutation) ResetMfaValidated() { + m.mfa_validated = nil +} + // Where appends a list predicates to the AuthRequestMutation builder. func (m *AuthRequestMutation) Where(ps ...predicate.AuthRequest) { m.predicates = append(m.predicates, ps...) @@ -2226,7 +2263,7 @@ func (m *AuthRequestMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AuthRequestMutation) Fields() []string { - fields := make([]string, 0, 20) + fields := make([]string, 0, 21) if m.client_id != nil { fields = append(fields, authrequest.FieldClientID) } @@ -2287,6 +2324,9 @@ func (m *AuthRequestMutation) Fields() []string { if m.hmac_key != nil { fields = append(fields, authrequest.FieldHmacKey) } + if m.mfa_validated != nil { + fields = append(fields, authrequest.FieldMfaValidated) + } return fields } @@ -2335,6 +2375,8 @@ func (m *AuthRequestMutation) Field(name string) (ent.Value, bool) { return m.CodeChallengeMethod() case authrequest.FieldHmacKey: return m.HmacKey() + case authrequest.FieldMfaValidated: + return m.MfaValidated() } return nil, false } @@ -2384,6 +2426,8 @@ func (m *AuthRequestMutation) OldField(ctx context.Context, name string) (ent.Va return m.OldCodeChallengeMethod(ctx) case authrequest.FieldHmacKey: return m.OldHmacKey(ctx) + case authrequest.FieldMfaValidated: + return m.OldMfaValidated(ctx) } return nil, fmt.Errorf("unknown AuthRequest field %s", name) } @@ -2533,6 +2577,13 @@ func (m *AuthRequestMutation) SetField(name string, value ent.Value) error { } m.SetHmacKey(v) return nil + case authrequest.FieldMfaValidated: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMfaValidated(v) + return nil } return fmt.Errorf("unknown AuthRequest field %s", name) } @@ -2669,6 +2720,9 @@ func (m *AuthRequestMutation) ResetField(name string) error { case authrequest.FieldHmacKey: m.ResetHmacKey() return nil + case authrequest.FieldMfaValidated: + m.ResetMfaValidated() + return nil } return fmt.Errorf("unknown AuthRequest field %s", name) } @@ -5779,6 +5833,8 @@ type OAuth2ClientMutation struct { logo_url *string allowed_connectors *[]string appendallowed_connectors []string + mfa_chain *[]string + appendmfa_chain []string clearedFields map[string]struct{} done bool oldValue func(context.Context) (*OAuth2Client, error) @@ -6228,6 +6284,71 @@ func (m *OAuth2ClientMutation) ResetAllowedConnectors() { delete(m.clearedFields, oauth2client.FieldAllowedConnectors) } +// SetMfaChain sets the "mfa_chain" field. +func (m *OAuth2ClientMutation) SetMfaChain(s []string) { + m.mfa_chain = &s + m.appendmfa_chain = nil +} + +// MfaChain returns the value of the "mfa_chain" field in the mutation. +func (m *OAuth2ClientMutation) MfaChain() (r []string, exists bool) { + v := m.mfa_chain + if v == nil { + return + } + return *v, true +} + +// OldMfaChain returns the old "mfa_chain" field's value of the OAuth2Client entity. +// If the OAuth2Client object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *OAuth2ClientMutation) OldMfaChain(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMfaChain is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMfaChain requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMfaChain: %w", err) + } + return oldValue.MfaChain, nil +} + +// AppendMfaChain adds s to the "mfa_chain" field. +func (m *OAuth2ClientMutation) AppendMfaChain(s []string) { + m.appendmfa_chain = append(m.appendmfa_chain, s...) +} + +// AppendedMfaChain returns the list of values that were appended to the "mfa_chain" field in this mutation. +func (m *OAuth2ClientMutation) AppendedMfaChain() ([]string, bool) { + if len(m.appendmfa_chain) == 0 { + return nil, false + } + return m.appendmfa_chain, true +} + +// ClearMfaChain clears the value of the "mfa_chain" field. +func (m *OAuth2ClientMutation) ClearMfaChain() { + m.mfa_chain = nil + m.appendmfa_chain = nil + m.clearedFields[oauth2client.FieldMfaChain] = struct{}{} +} + +// MfaChainCleared returns if the "mfa_chain" field was cleared in this mutation. +func (m *OAuth2ClientMutation) MfaChainCleared() bool { + _, ok := m.clearedFields[oauth2client.FieldMfaChain] + return ok +} + +// ResetMfaChain resets all changes to the "mfa_chain" field. +func (m *OAuth2ClientMutation) ResetMfaChain() { + m.mfa_chain = nil + m.appendmfa_chain = nil + delete(m.clearedFields, oauth2client.FieldMfaChain) +} + // Where appends a list predicates to the OAuth2ClientMutation builder. func (m *OAuth2ClientMutation) Where(ps ...predicate.OAuth2Client) { m.predicates = append(m.predicates, ps...) @@ -6262,7 +6383,7 @@ func (m *OAuth2ClientMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *OAuth2ClientMutation) Fields() []string { - fields := make([]string, 0, 7) + fields := make([]string, 0, 8) if m.secret != nil { fields = append(fields, oauth2client.FieldSecret) } @@ -6284,6 +6405,9 @@ func (m *OAuth2ClientMutation) Fields() []string { if m.allowed_connectors != nil { fields = append(fields, oauth2client.FieldAllowedConnectors) } + if m.mfa_chain != nil { + fields = append(fields, oauth2client.FieldMfaChain) + } return fields } @@ -6306,6 +6430,8 @@ func (m *OAuth2ClientMutation) Field(name string) (ent.Value, bool) { return m.LogoURL() case oauth2client.FieldAllowedConnectors: return m.AllowedConnectors() + case oauth2client.FieldMfaChain: + return m.MfaChain() } return nil, false } @@ -6329,6 +6455,8 @@ func (m *OAuth2ClientMutation) OldField(ctx context.Context, name string) (ent.V return m.OldLogoURL(ctx) case oauth2client.FieldAllowedConnectors: return m.OldAllowedConnectors(ctx) + case oauth2client.FieldMfaChain: + return m.OldMfaChain(ctx) } return nil, fmt.Errorf("unknown OAuth2Client field %s", name) } @@ -6387,6 +6515,13 @@ func (m *OAuth2ClientMutation) SetField(name string, value ent.Value) error { } m.SetAllowedConnectors(v) return nil + case oauth2client.FieldMfaChain: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMfaChain(v) + return nil } return fmt.Errorf("unknown OAuth2Client field %s", name) } @@ -6426,6 +6561,9 @@ func (m *OAuth2ClientMutation) ClearedFields() []string { if m.FieldCleared(oauth2client.FieldAllowedConnectors) { fields = append(fields, oauth2client.FieldAllowedConnectors) } + if m.FieldCleared(oauth2client.FieldMfaChain) { + fields = append(fields, oauth2client.FieldMfaChain) + } return fields } @@ -6449,6 +6587,9 @@ func (m *OAuth2ClientMutation) ClearField(name string) error { case oauth2client.FieldAllowedConnectors: m.ClearAllowedConnectors() return nil + case oauth2client.FieldMfaChain: + m.ClearMfaChain() + return nil } return fmt.Errorf("unknown OAuth2Client nullable field %s", name) } @@ -6478,6 +6619,9 @@ func (m *OAuth2ClientMutation) ResetField(name string) error { case oauth2client.FieldAllowedConnectors: m.ResetAllowedConnectors() return nil + case oauth2client.FieldMfaChain: + m.ResetMfaChain() + return nil } return fmt.Errorf("unknown OAuth2Client field %s", name) } @@ -9006,6 +9150,7 @@ type UserIdentityMutation struct { claims_groups *[]string appendclaims_groups []string consents *[]byte + mfa_secrets *[]byte created_at *time.Time last_login *time.Time blocked_until *time.Time @@ -9472,6 +9617,55 @@ func (m *UserIdentityMutation) ResetConsents() { m.consents = nil } +// SetMfaSecrets sets the "mfa_secrets" field. +func (m *UserIdentityMutation) SetMfaSecrets(b []byte) { + m.mfa_secrets = &b +} + +// MfaSecrets returns the value of the "mfa_secrets" field in the mutation. +func (m *UserIdentityMutation) MfaSecrets() (r []byte, exists bool) { + v := m.mfa_secrets + if v == nil { + return + } + return *v, true +} + +// OldMfaSecrets returns the old "mfa_secrets" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldMfaSecrets(ctx context.Context) (v *[]byte, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMfaSecrets is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMfaSecrets requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMfaSecrets: %w", err) + } + return oldValue.MfaSecrets, nil +} + +// ClearMfaSecrets clears the value of the "mfa_secrets" field. +func (m *UserIdentityMutation) ClearMfaSecrets() { + m.mfa_secrets = nil + m.clearedFields[useridentity.FieldMfaSecrets] = struct{}{} +} + +// MfaSecretsCleared returns if the "mfa_secrets" field was cleared in this mutation. +func (m *UserIdentityMutation) MfaSecretsCleared() bool { + _, ok := m.clearedFields[useridentity.FieldMfaSecrets] + return ok +} + +// ResetMfaSecrets resets all changes to the "mfa_secrets" field. +func (m *UserIdentityMutation) ResetMfaSecrets() { + m.mfa_secrets = nil + delete(m.clearedFields, useridentity.FieldMfaSecrets) +} + // SetCreatedAt sets the "created_at" field. func (m *UserIdentityMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -9614,7 +9808,7 @@ func (m *UserIdentityMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserIdentityMutation) Fields() []string { - fields := make([]string, 0, 12) + fields := make([]string, 0, 13) if m.user_id != nil { fields = append(fields, useridentity.FieldUserID) } @@ -9642,6 +9836,9 @@ func (m *UserIdentityMutation) Fields() []string { if m.consents != nil { fields = append(fields, useridentity.FieldConsents) } + if m.mfa_secrets != nil { + fields = append(fields, useridentity.FieldMfaSecrets) + } if m.created_at != nil { fields = append(fields, useridentity.FieldCreatedAt) } @@ -9677,6 +9874,8 @@ func (m *UserIdentityMutation) Field(name string) (ent.Value, bool) { return m.ClaimsGroups() case useridentity.FieldConsents: return m.Consents() + case useridentity.FieldMfaSecrets: + return m.MfaSecrets() case useridentity.FieldCreatedAt: return m.CreatedAt() case useridentity.FieldLastLogin: @@ -9710,6 +9909,8 @@ func (m *UserIdentityMutation) OldField(ctx context.Context, name string) (ent.V return m.OldClaimsGroups(ctx) case useridentity.FieldConsents: return m.OldConsents(ctx) + case useridentity.FieldMfaSecrets: + return m.OldMfaSecrets(ctx) case useridentity.FieldCreatedAt: return m.OldCreatedAt(ctx) case useridentity.FieldLastLogin: @@ -9788,6 +9989,13 @@ func (m *UserIdentityMutation) SetField(name string, value ent.Value) error { } m.SetConsents(v) return nil + case useridentity.FieldMfaSecrets: + v, ok := value.([]byte) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMfaSecrets(v) + return nil case useridentity.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -9842,6 +10050,9 @@ func (m *UserIdentityMutation) ClearedFields() []string { if m.FieldCleared(useridentity.FieldClaimsGroups) { fields = append(fields, useridentity.FieldClaimsGroups) } + if m.FieldCleared(useridentity.FieldMfaSecrets) { + fields = append(fields, useridentity.FieldMfaSecrets) + } return fields } @@ -9859,6 +10070,9 @@ func (m *UserIdentityMutation) ClearField(name string) error { case useridentity.FieldClaimsGroups: m.ClearClaimsGroups() return nil + case useridentity.FieldMfaSecrets: + m.ClearMfaSecrets() + return nil } return fmt.Errorf("unknown UserIdentity nullable field %s", name) } @@ -9894,6 +10108,9 @@ func (m *UserIdentityMutation) ResetField(name string) error { case useridentity.FieldConsents: m.ResetConsents() return nil + case useridentity.FieldMfaSecrets: + m.ResetMfaSecrets() + return nil case useridentity.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/storage/ent/db/oauth2client.go b/storage/ent/db/oauth2client.go index c6671aba..81a87afc 100644 --- a/storage/ent/db/oauth2client.go +++ b/storage/ent/db/oauth2client.go @@ -31,7 +31,9 @@ type OAuth2Client struct { LogoURL string `json:"logo_url,omitempty"` // AllowedConnectors holds the value of the "allowed_connectors" field. AllowedConnectors []string `json:"allowed_connectors,omitempty"` - selectValues sql.SelectValues + // MfaChain holds the value of the "mfa_chain" field. + MfaChain []string `json:"mfa_chain,omitempty"` + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -39,7 +41,7 @@ func (*OAuth2Client) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case oauth2client.FieldRedirectUris, oauth2client.FieldTrustedPeers, oauth2client.FieldAllowedConnectors: + case oauth2client.FieldRedirectUris, oauth2client.FieldTrustedPeers, oauth2client.FieldAllowedConnectors, oauth2client.FieldMfaChain: values[i] = new([]byte) case oauth2client.FieldPublic: values[i] = new(sql.NullBool) @@ -114,6 +116,14 @@ func (_m *OAuth2Client) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field allowed_connectors: %w", err) } } + case oauth2client.FieldMfaChain: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field mfa_chain", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.MfaChain); err != nil { + return fmt.Errorf("unmarshal field mfa_chain: %w", err) + } + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -170,6 +180,9 @@ func (_m *OAuth2Client) String() string { builder.WriteString(", ") builder.WriteString("allowed_connectors=") builder.WriteString(fmt.Sprintf("%v", _m.AllowedConnectors)) + builder.WriteString(", ") + builder.WriteString("mfa_chain=") + builder.WriteString(fmt.Sprintf("%v", _m.MfaChain)) builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/oauth2client/oauth2client.go b/storage/ent/db/oauth2client/oauth2client.go index 529f4c8d..0143eb6e 100644 --- a/storage/ent/db/oauth2client/oauth2client.go +++ b/storage/ent/db/oauth2client/oauth2client.go @@ -25,6 +25,8 @@ const ( FieldLogoURL = "logo_url" // FieldAllowedConnectors holds the string denoting the allowed_connectors field in the database. FieldAllowedConnectors = "allowed_connectors" + // FieldMfaChain holds the string denoting the mfa_chain field in the database. + FieldMfaChain = "mfa_chain" // Table holds the table name of the oauth2client in the database. Table = "oauth2clients" ) @@ -39,6 +41,7 @@ var Columns = []string{ FieldName, FieldLogoURL, FieldAllowedConnectors, + FieldMfaChain, } // ValidColumn reports if the column name is valid (part of the table columns). diff --git a/storage/ent/db/oauth2client/where.go b/storage/ent/db/oauth2client/where.go index 1425bf7e..d8ce00d6 100644 --- a/storage/ent/db/oauth2client/where.go +++ b/storage/ent/db/oauth2client/where.go @@ -317,6 +317,16 @@ func AllowedConnectorsNotNil() predicate.OAuth2Client { return predicate.OAuth2Client(sql.FieldNotNull(FieldAllowedConnectors)) } +// MfaChainIsNil applies the IsNil predicate on the "mfa_chain" field. +func MfaChainIsNil() predicate.OAuth2Client { + return predicate.OAuth2Client(sql.FieldIsNull(FieldMfaChain)) +} + +// MfaChainNotNil applies the NotNil predicate on the "mfa_chain" field. +func MfaChainNotNil() predicate.OAuth2Client { + return predicate.OAuth2Client(sql.FieldNotNull(FieldMfaChain)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.OAuth2Client) predicate.OAuth2Client { return predicate.OAuth2Client(sql.AndPredicates(predicates...)) diff --git a/storage/ent/db/oauth2client_create.go b/storage/ent/db/oauth2client_create.go index fe29ab2a..1e8d2a54 100644 --- a/storage/ent/db/oauth2client_create.go +++ b/storage/ent/db/oauth2client_create.go @@ -61,6 +61,12 @@ func (_c *OAuth2ClientCreate) SetAllowedConnectors(v []string) *OAuth2ClientCrea return _c } +// SetMfaChain sets the "mfa_chain" field. +func (_c *OAuth2ClientCreate) SetMfaChain(v []string) *OAuth2ClientCreate { + _c.mutation.SetMfaChain(v) + return _c +} + // SetID sets the "id" field. func (_c *OAuth2ClientCreate) SetID(v string) *OAuth2ClientCreate { _c.mutation.SetID(v) @@ -196,6 +202,10 @@ func (_c *OAuth2ClientCreate) createSpec() (*OAuth2Client, *sqlgraph.CreateSpec) _spec.SetField(oauth2client.FieldAllowedConnectors, field.TypeJSON, value) _node.AllowedConnectors = value } + if value, ok := _c.mutation.MfaChain(); ok { + _spec.SetField(oauth2client.FieldMfaChain, field.TypeJSON, value) + _node.MfaChain = value + } return _node, _spec } diff --git a/storage/ent/db/oauth2client_update.go b/storage/ent/db/oauth2client_update.go index 3fdbdbdc..7e9f753c 100644 --- a/storage/ent/db/oauth2client_update.go +++ b/storage/ent/db/oauth2client_update.go @@ -138,6 +138,24 @@ func (_u *OAuth2ClientUpdate) ClearAllowedConnectors() *OAuth2ClientUpdate { return _u } +// SetMfaChain sets the "mfa_chain" field. +func (_u *OAuth2ClientUpdate) SetMfaChain(v []string) *OAuth2ClientUpdate { + _u.mutation.SetMfaChain(v) + return _u +} + +// AppendMfaChain appends value to the "mfa_chain" field. +func (_u *OAuth2ClientUpdate) AppendMfaChain(v []string) *OAuth2ClientUpdate { + _u.mutation.AppendMfaChain(v) + return _u +} + +// ClearMfaChain clears the value of the "mfa_chain" field. +func (_u *OAuth2ClientUpdate) ClearMfaChain() *OAuth2ClientUpdate { + _u.mutation.ClearMfaChain() + return _u +} + // Mutation returns the OAuth2ClientMutation object of the builder. func (_u *OAuth2ClientUpdate) Mutation() *OAuth2ClientMutation { return _u.mutation @@ -247,6 +265,17 @@ func (_u *OAuth2ClientUpdate) sqlSave(ctx context.Context) (_node int, err error if _u.mutation.AllowedConnectorsCleared() { _spec.ClearField(oauth2client.FieldAllowedConnectors, field.TypeJSON) } + if value, ok := _u.mutation.MfaChain(); ok { + _spec.SetField(oauth2client.FieldMfaChain, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedMfaChain(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, oauth2client.FieldMfaChain, value) + }) + } + if _u.mutation.MfaChainCleared() { + _spec.ClearField(oauth2client.FieldMfaChain, field.TypeJSON) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{oauth2client.Label} @@ -377,6 +406,24 @@ func (_u *OAuth2ClientUpdateOne) ClearAllowedConnectors() *OAuth2ClientUpdateOne return _u } +// SetMfaChain sets the "mfa_chain" field. +func (_u *OAuth2ClientUpdateOne) SetMfaChain(v []string) *OAuth2ClientUpdateOne { + _u.mutation.SetMfaChain(v) + return _u +} + +// AppendMfaChain appends value to the "mfa_chain" field. +func (_u *OAuth2ClientUpdateOne) AppendMfaChain(v []string) *OAuth2ClientUpdateOne { + _u.mutation.AppendMfaChain(v) + return _u +} + +// ClearMfaChain clears the value of the "mfa_chain" field. +func (_u *OAuth2ClientUpdateOne) ClearMfaChain() *OAuth2ClientUpdateOne { + _u.mutation.ClearMfaChain() + return _u +} + // Mutation returns the OAuth2ClientMutation object of the builder. func (_u *OAuth2ClientUpdateOne) Mutation() *OAuth2ClientMutation { return _u.mutation @@ -516,6 +563,17 @@ func (_u *OAuth2ClientUpdateOne) sqlSave(ctx context.Context) (_node *OAuth2Clie if _u.mutation.AllowedConnectorsCleared() { _spec.ClearField(oauth2client.FieldAllowedConnectors, field.TypeJSON) } + if value, ok := _u.mutation.MfaChain(); ok { + _spec.SetField(oauth2client.FieldMfaChain, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedMfaChain(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, oauth2client.FieldMfaChain, value) + }) + } + if _u.mutation.MfaChainCleared() { + _spec.ClearField(oauth2client.FieldMfaChain, field.TypeJSON) + } _node = &OAuth2Client{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/runtime.go b/storage/ent/db/runtime.go index 98c12ecc..98c19a58 100644 --- a/storage/ent/db/runtime.go +++ b/storage/ent/db/runtime.go @@ -84,6 +84,10 @@ func init() { authrequestDescCodeChallengeMethod := authrequestFields[19].Descriptor() // authrequest.DefaultCodeChallengeMethod holds the default value on creation for the code_challenge_method field. authrequest.DefaultCodeChallengeMethod = authrequestDescCodeChallengeMethod.Default.(string) + // authrequestDescMfaValidated is the schema descriptor for mfa_validated field. + authrequestDescMfaValidated := authrequestFields[21].Descriptor() + // authrequest.DefaultMfaValidated holds the default value on creation for the mfa_validated field. + authrequest.DefaultMfaValidated = authrequestDescMfaValidated.Default.(bool) // authrequestDescID is the schema descriptor for id field. authrequestDescID := authrequestFields[0].Descriptor() // authrequest.IDValidator is a validator for the "id" field. It is called by the builders before save. diff --git a/storage/ent/db/useridentity.go b/storage/ent/db/useridentity.go index 7127299b..91f291de 100644 --- a/storage/ent/db/useridentity.go +++ b/storage/ent/db/useridentity.go @@ -36,6 +36,8 @@ type UserIdentity struct { ClaimsGroups []string `json:"claims_groups,omitempty"` // Consents holds the value of the "consents" field. Consents []byte `json:"consents,omitempty"` + // MfaSecrets holds the value of the "mfa_secrets" field. + MfaSecrets *[]byte `json:"mfa_secrets,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // LastLogin holds the value of the "last_login" field. @@ -50,7 +52,7 @@ func (*UserIdentity) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case useridentity.FieldClaimsGroups, useridentity.FieldConsents: + case useridentity.FieldClaimsGroups, useridentity.FieldConsents, useridentity.FieldMfaSecrets: values[i] = new([]byte) case useridentity.FieldClaimsEmailVerified: values[i] = new(sql.NullBool) @@ -135,6 +137,12 @@ func (_m *UserIdentity) assignValues(columns []string, values []any) error { } else if value != nil { _m.Consents = *value } + case useridentity.FieldMfaSecrets: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field mfa_secrets", values[i]) + } else if value != nil { + _m.MfaSecrets = value + } case useridentity.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -216,6 +224,11 @@ func (_m *UserIdentity) String() string { builder.WriteString("consents=") builder.WriteString(fmt.Sprintf("%v", _m.Consents)) builder.WriteString(", ") + if v := _m.MfaSecrets; v != nil { + builder.WriteString("mfa_secrets=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") diff --git a/storage/ent/db/useridentity/useridentity.go b/storage/ent/db/useridentity/useridentity.go index f08d74ec..9fae1444 100644 --- a/storage/ent/db/useridentity/useridentity.go +++ b/storage/ent/db/useridentity/useridentity.go @@ -29,6 +29,8 @@ const ( FieldClaimsGroups = "claims_groups" // FieldConsents holds the string denoting the consents field in the database. FieldConsents = "consents" + // FieldMfaSecrets holds the string denoting the mfa_secrets field in the database. + FieldMfaSecrets = "mfa_secrets" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldLastLogin holds the string denoting the last_login field in the database. @@ -51,6 +53,7 @@ var Columns = []string{ FieldClaimsEmailVerified, FieldClaimsGroups, FieldConsents, + FieldMfaSecrets, FieldCreatedAt, FieldLastLogin, FieldBlockedUntil, diff --git a/storage/ent/db/useridentity/where.go b/storage/ent/db/useridentity/where.go index 201d340f..c3e3d911 100644 --- a/storage/ent/db/useridentity/where.go +++ b/storage/ent/db/useridentity/where.go @@ -104,6 +104,11 @@ func Consents(v []byte) predicate.UserIdentity { return predicate.UserIdentity(sql.FieldEQ(FieldConsents, v)) } +// MfaSecrets applies equality check predicate on the "mfa_secrets" field. It's identical to MfaSecretsEQ. +func MfaSecrets(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldMfaSecrets, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.UserIdentity { return predicate.UserIdentity(sql.FieldEQ(FieldCreatedAt, v)) @@ -569,6 +574,56 @@ func ConsentsLTE(v []byte) predicate.UserIdentity { return predicate.UserIdentity(sql.FieldLTE(FieldConsents, v)) } +// MfaSecretsEQ applies the EQ predicate on the "mfa_secrets" field. +func MfaSecretsEQ(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldMfaSecrets, v)) +} + +// MfaSecretsNEQ applies the NEQ predicate on the "mfa_secrets" field. +func MfaSecretsNEQ(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldMfaSecrets, v)) +} + +// MfaSecretsIn applies the In predicate on the "mfa_secrets" field. +func MfaSecretsIn(vs ...[]byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldMfaSecrets, vs...)) +} + +// MfaSecretsNotIn applies the NotIn predicate on the "mfa_secrets" field. +func MfaSecretsNotIn(vs ...[]byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldMfaSecrets, vs...)) +} + +// MfaSecretsGT applies the GT predicate on the "mfa_secrets" field. +func MfaSecretsGT(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldMfaSecrets, v)) +} + +// MfaSecretsGTE applies the GTE predicate on the "mfa_secrets" field. +func MfaSecretsGTE(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldMfaSecrets, v)) +} + +// MfaSecretsLT applies the LT predicate on the "mfa_secrets" field. +func MfaSecretsLT(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldMfaSecrets, v)) +} + +// MfaSecretsLTE applies the LTE predicate on the "mfa_secrets" field. +func MfaSecretsLTE(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldMfaSecrets, v)) +} + +// MfaSecretsIsNil applies the IsNil predicate on the "mfa_secrets" field. +func MfaSecretsIsNil() predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIsNull(FieldMfaSecrets)) +} + +// MfaSecretsNotNil applies the NotNil predicate on the "mfa_secrets" field. +func MfaSecretsNotNil() predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotNull(FieldMfaSecrets)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.UserIdentity { return predicate.UserIdentity(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/storage/ent/db/useridentity_create.go b/storage/ent/db/useridentity_create.go index 336d5c30..0f4b355a 100644 --- a/storage/ent/db/useridentity_create.go +++ b/storage/ent/db/useridentity_create.go @@ -114,6 +114,12 @@ func (_c *UserIdentityCreate) SetConsents(v []byte) *UserIdentityCreate { return _c } +// SetMfaSecrets sets the "mfa_secrets" field. +func (_c *UserIdentityCreate) SetMfaSecrets(v []byte) *UserIdentityCreate { + _c.mutation.SetMfaSecrets(v) + return _c +} + // SetCreatedAt sets the "created_at" field. func (_c *UserIdentityCreate) SetCreatedAt(v time.Time) *UserIdentityCreate { _c.mutation.SetCreatedAt(v) @@ -316,6 +322,10 @@ func (_c *UserIdentityCreate) createSpec() (*UserIdentity, *sqlgraph.CreateSpec) _spec.SetField(useridentity.FieldConsents, field.TypeBytes, value) _node.Consents = value } + if value, ok := _c.mutation.MfaSecrets(); ok { + _spec.SetField(useridentity.FieldMfaSecrets, field.TypeBytes, value) + _node.MfaSecrets = &value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value diff --git a/storage/ent/db/useridentity_update.go b/storage/ent/db/useridentity_update.go index 27ee0d3a..9de8e51f 100644 --- a/storage/ent/db/useridentity_update.go +++ b/storage/ent/db/useridentity_update.go @@ -151,6 +151,18 @@ func (_u *UserIdentityUpdate) SetConsents(v []byte) *UserIdentityUpdate { return _u } +// SetMfaSecrets sets the "mfa_secrets" field. +func (_u *UserIdentityUpdate) SetMfaSecrets(v []byte) *UserIdentityUpdate { + _u.mutation.SetMfaSecrets(v) + return _u +} + +// ClearMfaSecrets clears the value of the "mfa_secrets" field. +func (_u *UserIdentityUpdate) ClearMfaSecrets() *UserIdentityUpdate { + _u.mutation.ClearMfaSecrets() + return _u +} + // SetCreatedAt sets the "created_at" field. func (_u *UserIdentityUpdate) SetCreatedAt(v time.Time) *UserIdentityUpdate { _u.mutation.SetCreatedAt(v) @@ -287,6 +299,12 @@ func (_u *UserIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error if value, ok := _u.mutation.Consents(); ok { _spec.SetField(useridentity.FieldConsents, field.TypeBytes, value) } + if value, ok := _u.mutation.MfaSecrets(); ok { + _spec.SetField(useridentity.FieldMfaSecrets, field.TypeBytes, value) + } + if _u.mutation.MfaSecretsCleared() { + _spec.ClearField(useridentity.FieldMfaSecrets, field.TypeBytes) + } if value, ok := _u.mutation.CreatedAt(); ok { _spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value) } @@ -438,6 +456,18 @@ func (_u *UserIdentityUpdateOne) SetConsents(v []byte) *UserIdentityUpdateOne { return _u } +// SetMfaSecrets sets the "mfa_secrets" field. +func (_u *UserIdentityUpdateOne) SetMfaSecrets(v []byte) *UserIdentityUpdateOne { + _u.mutation.SetMfaSecrets(v) + return _u +} + +// ClearMfaSecrets clears the value of the "mfa_secrets" field. +func (_u *UserIdentityUpdateOne) ClearMfaSecrets() *UserIdentityUpdateOne { + _u.mutation.ClearMfaSecrets() + return _u +} + // SetCreatedAt sets the "created_at" field. func (_u *UserIdentityUpdateOne) SetCreatedAt(v time.Time) *UserIdentityUpdateOne { _u.mutation.SetCreatedAt(v) @@ -604,6 +634,12 @@ func (_u *UserIdentityUpdateOne) sqlSave(ctx context.Context) (_node *UserIdenti if value, ok := _u.mutation.Consents(); ok { _spec.SetField(useridentity.FieldConsents, field.TypeBytes, value) } + if value, ok := _u.mutation.MfaSecrets(); ok { + _spec.SetField(useridentity.FieldMfaSecrets, field.TypeBytes, value) + } + if _u.mutation.MfaSecretsCleared() { + _spec.ClearField(useridentity.FieldMfaSecrets, field.TypeBytes) + } if value, ok := _u.mutation.CreatedAt(); ok { _spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value) } diff --git a/storage/ent/schema/authrequest.go b/storage/ent/schema/authrequest.go index 2b75927b..905c73ab 100644 --- a/storage/ent/schema/authrequest.go +++ b/storage/ent/schema/authrequest.go @@ -88,6 +88,8 @@ func (AuthRequest) Fields() []ent.Field { SchemaType(textSchema). Default(""), field.Bytes("hmac_key"), + field.Bool("mfa_validated"). + Default(false), } } diff --git a/storage/ent/schema/client.go b/storage/ent/schema/client.go index f0e10606..c7737d70 100644 --- a/storage/ent/schema/client.go +++ b/storage/ent/schema/client.go @@ -47,6 +47,8 @@ func (OAuth2Client) Fields() []ent.Field { NotEmpty(), field.JSON("allowed_connectors", []string{}). Optional(), + field.JSON("mfa_chain", []string{}). + Optional(), } } diff --git a/storage/ent/schema/useridentity.go b/storage/ent/schema/useridentity.go index a4928240..f8a4f2b8 100644 --- a/storage/ent/schema/useridentity.go +++ b/storage/ent/schema/useridentity.go @@ -41,6 +41,9 @@ func (UserIdentity) Fields() []ent.Field { field.JSON("claims_groups", []string{}). Optional(), field.Bytes("consents"), + field.Bytes("mfa_secrets"). + Nillable(). + Optional(), field.Time("created_at"). SchemaType(timeSchema), field.Time("last_login"). diff --git a/storage/etcd/types.go b/storage/etcd/types.go index 3624de32..117e7705 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -86,6 +86,8 @@ type AuthRequest struct { CodeChallengeMethod string `json:"code_challenge_method,omitempty"` HMACKey []byte `json:"hmac_key"` + + MFAValidated bool `json:"mfa_validated"` } func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { @@ -106,6 +108,7 @@ func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { CodeChallenge: a.PKCE.CodeChallenge, CodeChallengeMethod: a.PKCE.CodeChallengeMethod, HMACKey: a.HMACKey, + MFAValidated: a.MFAValidated, } } @@ -128,7 +131,8 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest { CodeChallenge: a.CodeChallenge, CodeChallengeMethod: a.CodeChallengeMethod, }, - HMACKey: a.HMACKey, + HMACKey: a.HMACKey, + MFAValidated: a.MFAValidated, } } @@ -258,13 +262,14 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { // UserIdentity is a mirrored struct from storage with JSON struct tags type UserIdentity struct { - UserID string `json:"user_id,omitempty"` - ConnectorID string `json:"connector_id,omitempty"` - Claims Claims `json:"claims,omitempty"` - Consents map[string][]string `json:"consents,omitempty"` - CreatedAt time.Time `json:"created_at"` - LastLogin time.Time `json:"last_login"` - BlockedUntil time.Time `json:"blocked_until"` + UserID string `json:"user_id,omitempty"` + ConnectorID string `json:"connector_id,omitempty"` + Claims Claims `json:"claims,omitempty"` + Consents map[string][]string `json:"consents,omitempty"` + MFASecrets map[string]*storage.MFASecret `json:"mfa_secrets,omitempty"` + CreatedAt time.Time `json:"created_at"` + LastLogin time.Time `json:"last_login"` + BlockedUntil time.Time `json:"blocked_until"` } func fromStorageUserIdentity(u storage.UserIdentity) UserIdentity { @@ -273,6 +278,7 @@ func fromStorageUserIdentity(u storage.UserIdentity) UserIdentity { ConnectorID: u.ConnectorID, Claims: fromStorageClaims(u.Claims), Consents: u.Consents, + MFASecrets: u.MFASecrets, CreatedAt: u.CreatedAt, LastLogin: u.LastLogin, BlockedUntil: u.BlockedUntil, @@ -285,6 +291,7 @@ func toStorageUserIdentity(u UserIdentity) storage.UserIdentity { ConnectorID: u.ConnectorID, Claims: toStorageClaims(u.Claims), Consents: u.Consents, + MFASecrets: u.MFASecrets, CreatedAt: u.CreatedAt, LastLogin: u.LastLogin, BlockedUntil: u.BlockedUntil, diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 473f59cc..c1de9e56 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -287,6 +287,8 @@ type Client struct { LogoURL string `json:"logoURL,omitempty"` AllowedConnectors []string `json:"allowedConnectors,omitempty"` + + MFAChain []string `json:"mfaChain,omitempty"` } // ClientList is a list of Clients. @@ -314,6 +316,7 @@ func (cli *client) fromStorageClient(c storage.Client) Client { Name: c.Name, LogoURL: c.LogoURL, AllowedConnectors: c.AllowedConnectors, + MFAChain: c.MFAChain, } } @@ -327,6 +330,7 @@ func toStorageClient(c Client) storage.Client { Name: c.Name, LogoURL: c.LogoURL, AllowedConnectors: c.AllowedConnectors, + MFAChain: c.MFAChain, } } @@ -396,6 +400,8 @@ type AuthRequest struct { CodeChallengeMethod string `json:"code_challenge_method,omitempty"` HMACKey []byte `json:"hmac_key"` + + MFAValidated bool `json:"mfa_validated"` } // AuthRequestList is a list of AuthRequests. @@ -424,7 +430,8 @@ func toStorageAuthRequest(req AuthRequest) storage.AuthRequest { CodeChallenge: req.CodeChallenge, CodeChallengeMethod: req.CodeChallengeMethod, }, - HMACKey: req.HMACKey, + HMACKey: req.HMACKey, + MFAValidated: req.MFAValidated, } return a } @@ -454,6 +461,7 @@ func (cli *client) fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { CodeChallenge: a.PKCE.CodeChallenge, CodeChallengeMethod: a.PKCE.CodeChallengeMethod, HMACKey: a.HMACKey, + MFAValidated: a.MFAValidated, } return req } @@ -913,13 +921,14 @@ type UserIdentity struct { k8sapi.TypeMeta `json:",inline"` k8sapi.ObjectMeta `json:"metadata,omitempty"` - UserID string `json:"userID,omitempty"` - ConnectorID string `json:"connectorID,omitempty"` - Claims Claims `json:"claims,omitempty"` - Consents map[string][]string `json:"consents,omitempty"` - CreatedAt time.Time `json:"createdAt,omitempty"` - LastLogin time.Time `json:"lastLogin,omitempty"` - BlockedUntil time.Time `json:"blockedUntil,omitempty"` + UserID string `json:"userID,omitempty"` + ConnectorID string `json:"connectorID,omitempty"` + Claims Claims `json:"claims,omitempty"` + Consents map[string][]string `json:"consents,omitempty"` + MFASecrets map[string]*storage.MFASecret `json:"mfaSecrets,omitempty"` + CreatedAt time.Time `json:"createdAt,omitempty"` + LastLogin time.Time `json:"lastLogin,omitempty"` + BlockedUntil time.Time `json:"blockedUntil,omitempty"` } // UserIdentityList is a list of UserIdentities. @@ -943,6 +952,7 @@ func (cli *client) fromStorageUserIdentity(u storage.UserIdentity) UserIdentity ConnectorID: u.ConnectorID, Claims: fromStorageClaims(u.Claims), Consents: u.Consents, + MFASecrets: u.MFASecrets, CreatedAt: u.CreatedAt, LastLogin: u.LastLogin, BlockedUntil: u.BlockedUntil, @@ -955,6 +965,7 @@ func toStorageUserIdentity(u UserIdentity) storage.UserIdentity { ConnectorID: u.ConnectorID, Claims: toStorageClaims(u.Claims), Consents: u.Consents, + MFASecrets: u.MFASecrets, CreatedAt: u.CreatedAt, LastLogin: u.LastLogin, BlockedUntil: u.BlockedUntil, diff --git a/storage/sql/crud.go b/storage/sql/crud.go index ab11713a..15435722 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -134,10 +134,11 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err connector_id, connector_data, expiry, code_challenge, code_challenge_method, - hmac_key + hmac_key, + mfa_validated ) values ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21 + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22 ); `, a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, @@ -148,6 +149,7 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err a.Expiry, a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, a.HMACKey, + a.MFAValidated, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -180,8 +182,9 @@ func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a connector_id = $15, connector_data = $16, expiry = $17, code_challenge = $18, code_challenge_method = $19, - hmac_key = $20 - where id = $21; + hmac_key = $20, + mfa_validated = $21 + where id = $22; `, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ForceApprovalPrompt, a.LoggedIn, @@ -191,6 +194,7 @@ func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a a.ConnectorID, a.ConnectorData, a.Expiry, a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, a.HMACKey, + a.MFAValidated, r.ID, ) if err != nil { @@ -212,7 +216,8 @@ func getAuthRequest(ctx context.Context, q querier, id string) (a storage.AuthRe claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, expiry, - code_challenge, code_challenge_method, hmac_key + code_challenge, code_challenge_method, hmac_key, + mfa_validated from auth_request where id = $1; `, id).Scan( &a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, @@ -222,6 +227,7 @@ func getAuthRequest(ctx context.Context, q querier, id string) (a storage.AuthRe decoder(&a.Claims.Groups), &a.ConnectorID, &a.ConnectorData, &a.Expiry, &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, &a.HMACKey, + &a.MFAValidated, ) if err != nil { if err == sql.ErrNoRows { @@ -514,9 +520,10 @@ func (c *conn) UpdateClient(ctx context.Context, id string, updater func(old sto public = $4, name = $5, logo_url = $6, - allowed_connectors = $7 - where id = $8; - `, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, encoder(nc.AllowedConnectors), id, + allowed_connectors = $7, + mfa_chain = $8 + where id = $9; + `, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, encoder(nc.AllowedConnectors), encoder(nc.MFAChain), id, ) if err != nil { return fmt.Errorf("update client: %v", err) @@ -528,12 +535,12 @@ func (c *conn) UpdateClient(ctx context.Context, id string, updater func(old sto func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error { _, err := c.Exec(` insert into client ( - id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors + id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors, mfa_chain ) - values ($1, $2, $3, $4, $5, $6, $7, $8); + values ($1, $2, $3, $4, $5, $6, $7, $8, $9); `, cli.ID, cli.Secret, encoder(cli.RedirectURIs), encoder(cli.TrustedPeers), - cli.Public, cli.Name, cli.LogoURL, encoder(cli.AllowedConnectors), + cli.Public, cli.Name, cli.LogoURL, encoder(cli.AllowedConnectors), encoder(cli.MFAChain), ) if err != nil { if c.alreadyExistsCheck(err) { @@ -547,7 +554,7 @@ func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error { func getClient(ctx context.Context, q querier, id string) (storage.Client, error) { return scanClient(q.QueryRow(` select - id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors + id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors, mfa_chain from client where id = $1; `, id)) } @@ -559,7 +566,7 @@ func (c *conn) GetClient(ctx context.Context, id string) (storage.Client, error) func (c *conn) ListClients(ctx context.Context) ([]storage.Client, error) { rows, err := c.Query(` select - id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors + id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors, mfa_chain from client; `) if err != nil { @@ -583,9 +590,10 @@ func (c *conn) ListClients(ctx context.Context) ([]storage.Client, error) { func scanClient(s scanner) (cli storage.Client, err error) { var allowedConnectors []byte + var mfaChain []byte err = s.Scan( &cli.ID, &cli.Secret, decoder(&cli.RedirectURIs), decoder(&cli.TrustedPeers), - &cli.Public, &cli.Name, &cli.LogoURL, &allowedConnectors, + &cli.Public, &cli.Name, &cli.LogoURL, &allowedConnectors, &mfaChain, ) if err != nil { if err == sql.ErrNoRows { @@ -598,6 +606,11 @@ func scanClient(s scanner) (cli storage.Client, err error) { return cli, fmt.Errorf("unmarshal client allowed connectors: %v", err) } } + if len(mfaChain) > 0 { + if err := json.Unmarshal(mfaChain, &cli.MFAChain); err != nil { + return cli, fmt.Errorf("unmarshal client mfa chain: %v", err) + } + } return cli, nil } @@ -781,17 +794,17 @@ func (c *conn) CreateUserIdentity(ctx context.Context, u storage.UserIdentity) e user_id, connector_id, claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, - consents, + consents, mfa_secrets, created_at, last_login, blocked_until ) values ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12 + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13 ); `, u.UserID, u.ConnectorID, u.Claims.UserID, u.Claims.Username, u.Claims.PreferredUsername, u.Claims.Email, u.Claims.EmailVerified, encoder(u.Claims.Groups), - encoder(u.Consents), + encoder(u.Consents), encoder(u.MFASecrets), u.CreatedAt, u.LastLogin, u.BlockedUntil, ) if err != nil { @@ -824,14 +837,15 @@ func (c *conn) UpdateUserIdentity(ctx context.Context, userID, connectorID strin claims_email_verified = $5, claims_groups = $6, consents = $7, - created_at = $8, - last_login = $9, - blocked_until = $10 - where user_id = $11 AND connector_id = $12; + mfa_secrets = $8, + created_at = $9, + last_login = $10, + blocked_until = $11 + where user_id = $12 AND connector_id = $13; `, newIdentity.Claims.UserID, newIdentity.Claims.Username, newIdentity.Claims.PreferredUsername, newIdentity.Claims.Email, newIdentity.Claims.EmailVerified, encoder(newIdentity.Claims.Groups), - encoder(newIdentity.Consents), + encoder(newIdentity.Consents), encoder(newIdentity.MFASecrets), newIdentity.CreatedAt, newIdentity.LastLogin, newIdentity.BlockedUntil, u.UserID, u.ConnectorID, ) @@ -852,7 +866,7 @@ func getUserIdentity(ctx context.Context, q querier, userID, connectorID string) user_id, connector_id, claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, - consents, + consents, mfa_secrets, created_at, last_login, blocked_until from user_identity where user_id = $1 AND connector_id = $2; @@ -865,7 +879,7 @@ func (c *conn) ListUserIdentities(ctx context.Context) ([]storage.UserIdentity, user_id, connector_id, claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, - consents, + consents, mfa_secrets, created_at, last_login, blocked_until from user_identity; `) @@ -889,11 +903,12 @@ func (c *conn) ListUserIdentities(ctx context.Context) ([]storage.UserIdentity, } func scanUserIdentity(s scanner) (u storage.UserIdentity, err error) { + var mfaSecrets []byte err = s.Scan( &u.UserID, &u.ConnectorID, &u.Claims.UserID, &u.Claims.Username, &u.Claims.PreferredUsername, &u.Claims.Email, &u.Claims.EmailVerified, decoder(&u.Claims.Groups), - decoder(&u.Consents), + decoder(&u.Consents), &mfaSecrets, &u.CreatedAt, &u.LastLogin, &u.BlockedUntil, ) if err != nil { @@ -905,6 +920,11 @@ func scanUserIdentity(s scanner) (u storage.UserIdentity, err error) { if u.Consents == nil { u.Consents = make(map[string][]string) } + if len(mfaSecrets) > 0 { + if err := json.Unmarshal(mfaSecrets, &u.MFASecrets); err != nil { + return u, fmt.Errorf("unmarshal user identity mfa secrets: %v", err) + } + } return u, nil } diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 7561d146..d1c6d837 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -422,4 +422,17 @@ var migrations = []migration{ );`, }, }, + { + stmts: []string{ + ` + alter table auth_request + add column mfa_validated boolean not null default false;`, + ` + alter table user_identity + add column mfa_secrets bytea;`, + ` + alter table client + add column mfa_chain bytea;`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index 963c7c67..5c25dce2 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -185,6 +185,10 @@ type Client struct { // AllowedConnectors is a list of connector IDs that the client is allowed to use for authentication. // If empty, all connectors are allowed. AllowedConnectors []string `json:"allowedConnectors"` + + // MFAChain is an ordered list of MFA authenticator IDs that a user must complete + // during login. Empty means no MFA required. + MFAChain []string `json:"mfaChain"` } // Claims represents the ID Token claims supported by the server. @@ -247,6 +251,9 @@ type AuthRequest struct { // HMACKey is used when generating an AuthRequest-specific HMAC HMACKey []byte + + // MFAValidated is set to true if the user has completed multi-factor authentication. + MFAValidated bool } // AuthCode represents a code which can be exchanged for an OAuth2 token response. @@ -330,12 +337,22 @@ type RefreshTokenRef struct { LastUsed time.Time } +// MFASecret stores the enrollment state and secret for an MFA authenticator. +type MFASecret struct { + AuthenticatorID string `json:"authenticatorID"` + Type string `json:"type"` + Secret string `json:"secret"` + Confirmed bool `json:"confirmed"` + CreatedAt time.Time `json:"createdAt"` +} + // UserIdentity represents persistent per-user identity data. type UserIdentity struct { UserID string ConnectorID string Claims Claims - Consents map[string][]string // clientID -> approved scopes + Consents map[string][]string // clientID -> approved scopes + MFASecrets map[string]*MFASecret // authenticatorID -> secret CreatedAt time.Time LastLogin time.Time BlockedUntil time.Time diff --git a/t b/t new file mode 100644 index 00000000..91d9073d --- /dev/null +++ b/t @@ -0,0 +1,5 @@ +export DEX_POSTGRES_DATABASE=dex +export DEX_POSTGRES_USER=postgres +export DEX_POSTGRES_PASSWORD=postgres +export DEX_POSTGRES_HOST=0.0.0.0 +export DEX_POSTGRES_PORT=5432 diff --git a/web/templates/totp_verify.html b/web/templates/totp_verify.html new file mode 100644 index 00000000..4aa7f2ef --- /dev/null +++ b/web/templates/totp_verify.html @@ -0,0 +1,35 @@ +{{ template "header.html" . }} + +
+

Two-factor authentication

+ {{ if not (eq .QRCode "") }} +

Scan the QR code below using your authenticator app, then enter the code.

+
+ QR code +
+ {{ else }} +

Enter the code from your authenticator app for {{ .Issuer }}.

+ {{ end }} +
+
+
+ +
+ +
+ + {{ if .Invalid }} +
+ Invalid code. Please try again. +
+ {{ end }} + + +
+
+ +{{ template "footer.html" . }}