Browse Source

feat: add multi-factor authentication support (TOTP)

Signed-off-by: maksim.nabokikh <max.nabokih@gmail.com>
pull/3712/head
maksim.nabokikh 15 hours ago
parent
commit
bfb0ddce4c
  1. 79
      cmd/dex/config.go
  2. 24
      cmd/dex/serve.go
  3. 24
      examples/config-dev.yaml
  4. 2
      go.mod
  5. 4
      go.sum
  6. 60
      server/handlers.go
  7. 1
      server/handlers_approval_test.go
  8. 295
      server/mfa.go
  9. 12
      server/server.go
  10. 18
      server/templates.go
  11. 2
      storage/ent/client/authrequest.go
  12. 2
      storage/ent/client/client.go
  13. 16
      storage/ent/client/types.go
  14. 19
      storage/ent/client/useridentity.go
  15. 15
      storage/ent/db/authrequest.go
  16. 10
      storage/ent/db/authrequest/authrequest.go
  17. 15
      storage/ent/db/authrequest/where.go
  18. 25
      storage/ent/db/authrequest_create.go
  19. 34
      storage/ent/db/authrequest_update.go
  20. 3
      storage/ent/db/migrate/schema.go
  21. 223
      storage/ent/db/mutation.go
  22. 17
      storage/ent/db/oauth2client.go
  23. 3
      storage/ent/db/oauth2client/oauth2client.go
  24. 10
      storage/ent/db/oauth2client/where.go
  25. 10
      storage/ent/db/oauth2client_create.go
  26. 58
      storage/ent/db/oauth2client_update.go
  27. 4
      storage/ent/db/runtime.go
  28. 15
      storage/ent/db/useridentity.go
  29. 3
      storage/ent/db/useridentity/useridentity.go
  30. 55
      storage/ent/db/useridentity/where.go
  31. 10
      storage/ent/db/useridentity_create.go
  32. 36
      storage/ent/db/useridentity_update.go
  33. 2
      storage/ent/schema/authrequest.go
  34. 2
      storage/ent/schema/client.go
  35. 3
      storage/ent/schema/useridentity.go
  36. 23
      storage/etcd/types.go
  37. 27
      storage/kubernetes/types.go
  38. 70
      storage/sql/crud.go
  39. 13
      storage/sql/migrate.go
  40. 19
      storage/storage.go
  41. 5
      t
  42. 35
      web/templates/totp_verify.html

79
cmd/dex/config.go

@ -65,6 +65,19 @@ type Config struct {
// querying the storage. Cannot be specified without enabling a passwords // querying the storage. Cannot be specified without enabling a passwords
// database. // database.
StaticPasswords []password `json:"staticPasswords"` StaticPasswords []password `json:"staticPasswords"`
// MFA holds multi-factor authentication configuration.
MFA MFAConfig `json:"mfa"`
}
// MFAConfig holds multi-factor authentication settings.
type MFAConfig struct {
// Authenticators defines MFA providers available for clients to reference.
Authenticators []MFAAuthenticator `json:"authenticators"`
// DefaultMFAChain is the default ordered list of authenticator IDs applied
// to clients that don't specify their own mfaChain. Empty means no MFA by default.
DefaultMFAChain []string `json:"defaultMFAChain"`
} }
// Validate the configuration // Validate the configuration
@ -103,6 +116,55 @@ func (c Config) Validate() error {
if len(checkErrors) != 0 { if len(checkErrors) != 0 {
return fmt.Errorf("invalid Config:\n\t-\t%s", strings.Join(checkErrors, "\n\t-\t")) return fmt.Errorf("invalid Config:\n\t-\t%s", strings.Join(checkErrors, "\n\t-\t"))
} }
if err := c.validateMFA(); err != nil {
return err
}
return nil
}
func (c Config) validateMFA() error {
mfa := c.MFA
if len(mfa.Authenticators) == 0 && len(mfa.DefaultMFAChain) == 0 {
return nil
}
if !featureflags.SessionsEnabled.Enabled() {
return fmt.Errorf("mfa requires sessions to be enabled (DEX_SESSIONS_ENABLED=true)")
}
knownTypes := map[string]bool{"TOTP": true}
ids := make(map[string]bool, len(mfa.Authenticators))
for _, auth := range mfa.Authenticators {
if auth.ID == "" {
return fmt.Errorf("mfa.authenticators: authenticator must have an id")
}
if ids[auth.ID] {
return fmt.Errorf("mfa.authenticators: duplicate authenticator id %q", auth.ID)
}
ids[auth.ID] = true
if !knownTypes[auth.Type] {
return fmt.Errorf("mfa.authenticators: unknown type %q for authenticator %q", auth.Type, auth.ID)
}
}
for _, authID := range mfa.DefaultMFAChain {
if !ids[authID] {
return fmt.Errorf("mfa.defaultMFAChain: references unknown authenticator %q", authID)
}
}
for _, client := range c.StaticClients {
for _, authID := range client.MFAChain {
if !ids[authID] {
return fmt.Errorf("staticClients: client %q references unknown MFA authenticator %q", client.ID, authID)
}
}
}
return nil return nil
} }
@ -585,3 +647,20 @@ type RefreshToken struct {
AbsoluteLifetime string `json:"absoluteLifetime"` AbsoluteLifetime string `json:"absoluteLifetime"`
ValidIfNotUsedFor string `json:"validIfNotUsedFor"` ValidIfNotUsedFor string `json:"validIfNotUsedFor"`
} }
// MFAAuthenticator defines a multi-factor authentication provider.
type MFAAuthenticator struct {
ID string `json:"id"`
Type string `json:"type"`
Config json.RawMessage `json:"config"`
// ConnectorTypes limits this authenticator to specific connector types (e.g., "ldap", "oidc", "saml").
// If empty, the authenticator applies to all connector types.
ConnectorTypes []string `json:"connectorTypes"`
}
// TOTPConfig holds configuration for a TOTP authenticator.
type TOTPConfig struct {
// Issuer is the name of the service shown in the authenticator app.
Issuer string `json:"issuer"`
}

24
cmd/dex/serve.go

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
@ -384,6 +385,8 @@ func runServe(options serveOptions) error {
ContinueOnConnectorFailure: featureflags.ContinueOnConnectorFailure.Enabled(), ContinueOnConnectorFailure: featureflags.ContinueOnConnectorFailure.Enabled(),
Signer: signerInstance, Signer: signerInstance,
IDTokensValidFor: idTokensValidFor, IDTokensValidFor: idTokensValidFor,
MFAProviders: buildMFAProviders(c.MFA.Authenticators, logger),
DefaultMFAChain: c.MFA.DefaultMFAChain,
} }
if c.Expiry.AuthRequests != "" { if c.Expiry.AuthRequests != "" {
@ -759,3 +762,24 @@ func loadTLSConfig(certFile, keyFile, caFile string, baseConfig *tls.Config) (*t
func recordBuildInfo() { func recordBuildInfo() {
buildInfo.WithLabelValues(version, runtime.Version(), fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)).Set(1) buildInfo.WithLabelValues(version, runtime.Version(), fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)).Set(1)
} }
func buildMFAProviders(authenticators []MFAAuthenticator, logger *slog.Logger) map[string]server.MFAProvider {
if len(authenticators) == 0 {
return nil
}
providers := make(map[string]server.MFAProvider, len(authenticators))
for _, auth := range authenticators {
switch auth.Type {
case "TOTP":
var cfg TOTPConfig
if err := json.Unmarshal(auth.Config, &cfg); err != nil {
logger.Error("failed to parse TOTP config", "id", auth.ID, "err", err)
continue
}
providers[auth.ID] = server.NewTOTPProvider(cfg.Issuer, auth.ConnectorTypes)
logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type)
}
}
return providers
}

24
examples/config-dev.yaml

@ -130,6 +130,25 @@ telemetry:
# # Supported code challenge methods. Defaults to ["S256", "plain"]. # # Supported code challenge methods. Defaults to ["S256", "plain"].
# codeChallengeMethodsSupported: ["S256", "plain"] # codeChallengeMethodsSupported: ["S256", "plain"]
# Multi-factor authentication configuration.
# Requires DEX_SESSIONS_ENABLED=true feature flag.
# mfa:
# authenticators:
# - id: totp-1
# type: TOTP
# config:
# issuer: "dex"
# # Optional: limit this authenticator to specific connector types (e.g., ldap, oidc, saml).
# # If omitted or empty, applies to all connector types.
# # It is recommended to use this option to prevent MFA from being used for connectors
# # with their own MFA mechanisms, e.g., OIDC, Google, etc. (but technically, it is possible).
# connectorTypes:
# - mockCallback
# # Default MFA chain applied to clients that don't specify their own mfaChain.
# # If omitted or empty, no MFA is required by default.
# defaultMFAChain:
# - totp-1
# Instead of reading from an external storage, use this list of clients. # Instead of reading from an external storage, use this list of clients.
# #
# If this option isn't chosen clients may be added through the gRPC API. # If this option isn't chosen clients may be added through the gRPC API.
@ -144,6 +163,11 @@ staticClients:
# If omitted or empty, all connectors are allowed. # If omitted or empty, all connectors are allowed.
# allowedConnectors: # allowedConnectors:
# - mock # - mock
# Optional: ordered list of MFA authenticator IDs the user must complete during login.
# References authenticator IDs from mfa.authenticators.
# If omitted, mfa.defaultMFAChain is used.
# mfaChain:
# - totp-1
# Example using environment variables # Example using environment variables
# Set DEX_CLIENT_ID and DEX_SECURE_CLIENT_SECRET before starting Dex # Set DEX_CLIENT_ID and DEX_SECURE_CLIENT_SECRET before starting Dex

2
go.mod

@ -28,6 +28,7 @@ require (
github.com/oklog/run v1.2.0 github.com/oklog/run v1.2.0
github.com/openbao/openbao/api/v2 v2.5.1 github.com/openbao/openbao/api/v2 v2.5.1
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.5.0
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
github.com/russellhaering/goxmldsig v1.5.0 github.com/russellhaering/goxmldsig v1.5.0
github.com/spf13/cobra v1.10.2 github.com/spf13/cobra v1.10.2
@ -58,6 +59,7 @@ require (
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/coreos/go-semver v0.3.1 // indirect github.com/coreos/go-semver v0.3.1 // indirect

4
go.sum

@ -42,6 +42,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0= github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
@ -201,6 +203,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=

60
server/handlers.go

@ -760,6 +760,36 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
userIdentity = &ui userIdentity = &ui
} }
// an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original
// flow would be unable to poll for the result at the /approval endpoint
h := hmac.New(sha256.New, authReq.HMACKey)
h.Write([]byte(authReq.ID))
mac := h.Sum(nil)
hmacParam := base64.RawURLEncoding.EncodeToString(mac)
// Check if the client requires MFA.
mfaChain, err := s.mfaChainForClient(ctx, authReq.ClientID, authReq.ConnectorID)
if err != nil {
return "", false, fmt.Errorf("failed to get MFA chain for client: %v", err)
}
if len(mfaChain) > 0 {
// Redirect to MFA verification for the first authenticator in the chain.
v := url.Values{}
v.Set("req", authReq.ID)
v.Set("hmac", hmacParam)
v.Set("authenticator", mfaChain[0])
returnURL := path.Join(s.issuerURL.Path, "/totp/verify") + "?" + v.Encode()
return returnURL, false, nil
}
// No MFA required — mark as validated.
if err := s.storage.UpdateAuthRequest(ctx, authReq.ID, func(a storage.AuthRequest) (storage.AuthRequest, error) {
a.MFAValidated = true
return a, nil
}); err != nil {
return "", false, fmt.Errorf("failed to update auth request MFA status: %v", err)
}
// we can skip the redirect to /approval and go ahead and send code if it's not required // we can skip the redirect to /approval and go ahead and send code if it's not required
if s.skipApproval && !authReq.ForceApprovalPrompt { if s.skipApproval && !authReq.ForceApprovalPrompt {
return "", true, nil return "", true, nil
@ -772,13 +802,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
} }
} }
// an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID + "&hmac=" + hmacParam
// flow would be unable to poll for the result at the /approval endpoint
h := hmac.New(sha256.New, authReq.HMACKey)
h.Write([]byte(authReq.ID))
mac := h.Sum(nil)
returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID + "&hmac=" + base64.RawURLEncoding.EncodeToString(mac)
return returnURL, false, nil return returnURL, false, nil
} }
@ -810,6 +834,28 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
s.renderError(r, w, http.StatusInternalServerError, "Login process not yet finalized.") s.renderError(r, w, http.StatusInternalServerError, "Login process not yet finalized.")
return return
} }
if !authReq.MFAValidated {
// Check if MFA is actually required — if so, redirect to TOTP instead of blocking.
// This handles the case where MFA was enabled after the auth flow started.
mfaChain, err := s.mfaChainForClient(ctx, authReq.ClientID, authReq.ConnectorID)
if err != nil {
s.logger.ErrorContext(ctx, "failed to get MFA chain", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
}
if len(mfaChain) > 0 {
h := hmac.New(sha256.New, authReq.HMACKey)
h.Write([]byte(authReq.ID))
v := url.Values{}
v.Set("req", authReq.ID)
v.Set("hmac", base64.RawURLEncoding.EncodeToString(h.Sum(nil)))
v.Set("authenticator", mfaChain[0])
totpURL := path.Join(s.issuerURL.Path, "/totp/verify") + "?" + v.Encode()
http.Redirect(w, r, totpURL, http.StatusSeeOther)
return
}
// No MFA required but flag not set — allow through (backward compat).
}
// build expected hmac with secret key // build expected hmac with secret key
h := hmac.New(sha256.New, authReq.HMACKey) h := hmac.New(sha256.New, authReq.HMACKey)

1
server/handlers_approval_test.go

@ -84,6 +84,7 @@ func TestHandleApprovalDoubleSubmitPOST(t *testing.T) {
RedirectURI: "https://client.example/callback", RedirectURI: "https://client.example/callback",
Expiry: time.Now().Add(time.Minute), Expiry: time.Now().Add(time.Minute),
LoggedIn: true, LoggedIn: true,
MFAValidated: true,
HMACKey: []byte("approval-double-submit-key"), HMACKey: []byte("approval-double-submit-key"),
} }
require.NoError(t, server.storage.CreateAuthRequest(ctx, authReq)) require.NoError(t, server.storage.CreateAuthRequest(ctx, authReq))

295
server/mfa.go

@ -0,0 +1,295 @@
package server
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"image/png"
"net/http"
"strings"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"github.com/dexidp/dex/storage"
)
// MFAProvider is a pluggable multi-factor authentication method.
type MFAProvider interface {
// Type returns the authenticator type identifier (e.g., "TOTP").
Type() string
// EnabledForConnectorType returns true if this provider applies to the given connector type.
// If no connector types are configured, the provider applies to all.
EnabledForConnectorType(connectorType string) bool
}
// TOTPProvider implements TOTP-based multi-factor authentication.
type TOTPProvider struct {
issuer string
connectorTypes map[string]struct{}
}
// NewTOTPProvider creates a new TOTP MFA provider.
func NewTOTPProvider(issuer string, connectorTypes []string) *TOTPProvider {
m := make(map[string]struct{}, len(connectorTypes))
for _, t := range connectorTypes {
m[t] = struct{}{}
}
return &TOTPProvider{issuer: issuer, connectorTypes: m}
}
func (p *TOTPProvider) EnabledForConnectorType(connectorType string) bool {
if len(p.connectorTypes) == 0 {
return true
}
_, ok := p.connectorTypes[connectorType]
return ok
}
func (p *TOTPProvider) Type() string { return "TOTP" }
func (p *TOTPProvider) generate(connID, email string) (*otp.Key, error) {
return totp.Generate(totp.GenerateOpts{
Issuer: p.issuer,
AccountName: fmt.Sprintf("(%s) %s", connID, email),
})
}
func (s *Server) handleMFAVerify(w http.ResponseWriter, r *http.Request) {
macEncoded := r.FormValue("hmac")
if macEncoded == "" {
s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request.")
return
}
mac, err := base64.RawURLEncoding.DecodeString(macEncoded)
if err != nil {
s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request.")
return
}
ctx := r.Context()
authReq, err := s.storage.GetAuthRequest(ctx, r.FormValue("req"))
if err != nil {
s.logger.ErrorContext(ctx, "failed to get auth request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Database error.")
return
}
if !authReq.LoggedIn {
s.logger.ErrorContext(ctx, "auth request does not have an identity for MFA verification")
s.renderError(r, w, http.StatusInternalServerError, "Login process not yet finalized.")
return
}
// Verify HMAC
h := hmac.New(sha256.New, authReq.HMACKey)
h.Write([]byte(authReq.ID))
if !hmac.Equal(mac, h.Sum(nil)) {
s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request.")
return
}
authenticatorID := r.FormValue("authenticator")
provider, ok := s.mfaProviders[authenticatorID]
if !ok {
s.renderError(r, w, http.StatusBadRequest, "Unknown authenticator.")
return
}
totpProvider, ok := provider.(*TOTPProvider)
if !ok {
s.renderError(r, w, http.StatusInternalServerError, "Unsupported authenticator type.")
return
}
identity, err := s.storage.GetUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID)
if err != nil {
s.logger.ErrorContext(ctx, "failed to get user identity", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Database error.")
return
}
returnURL := strings.Replace(r.URL.String(), "/totp/verify", "/approval", 1)
if authReq.MFAValidated {
http.Redirect(w, r, returnURL, http.StatusSeeOther)
return
}
secret := identity.MFASecrets[authenticatorID]
switch r.Method {
case http.MethodGet:
if secret == nil {
// First-time enrollment: generate a new TOTP key.
generated, err := totpProvider.generate(authReq.ConnectorID, authReq.Claims.Email)
if err != nil {
s.logger.ErrorContext(ctx, "failed to generate TOTP key", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
}
secret = &storage.MFASecret{
AuthenticatorID: authenticatorID,
Type: "TOTP",
Secret: generated.String(),
Confirmed: false,
CreatedAt: s.now(),
}
if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
if old.MFASecrets == nil {
old.MFASecrets = make(map[string]*storage.MFASecret)
}
old.MFASecrets[authenticatorID] = secret
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to store MFA secret", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
}
}
s.renderTOTPPage(secret, false, totpProvider.issuer, authReq.ConnectorID, w, r)
case http.MethodPost:
if secret == nil || secret.Secret == "" {
s.renderError(r, w, http.StatusBadRequest, "MFA not enrolled.")
return
}
code := r.FormValue("totp")
generated, err := otp.NewKeyFromURL(secret.Secret)
if err != nil {
s.logger.ErrorContext(ctx, "failed to load TOTP key", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
}
if !totp.Validate(code, generated.Secret()) {
s.renderTOTPPage(secret, true, totpProvider.issuer, authReq.ConnectorID, w, r)
return
}
// Mark MFA secret as confirmed.
if !secret.Confirmed {
if err := s.storage.UpdateUserIdentity(ctx, authReq.Claims.UserID, authReq.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
if s := old.MFASecrets[authenticatorID]; s != nil {
s.Confirmed = true
}
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to confirm MFA secret", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
}
}
// Mark auth request as MFA-validated.
if err := s.storage.UpdateAuthRequest(ctx, authReq.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
old.MFAValidated = true
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to update auth request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
}
// Skip approval if configured.
if s.skipApproval && !authReq.ForceApprovalPrompt {
authReq, err = s.storage.GetAuthRequest(ctx, authReq.ID)
if err != nil {
s.logger.ErrorContext(ctx, "failed to get finalized auth request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.")
return
}
s.sendCodeResponse(w, r, authReq)
return
}
http.Redirect(w, r, returnURL, http.StatusSeeOther)
default:
s.renderError(r, w, http.StatusBadRequest, "Unsupported request method.")
}
}
func (s *Server) renderTOTPPage(secret *storage.MFASecret, lastFail bool, issuer, connectorID string, w http.ResponseWriter, r *http.Request) {
var qrCode string
if !secret.Confirmed {
var err error
qrCode, err = generateTOTPQRCode(secret.Secret)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to generate QR code", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
}
}
if err := s.templates.totpVerify(r, w, r.URL.String(), issuer, connectorID, qrCode, lastFail); err != nil {
s.logger.ErrorContext(r.Context(), "server template error", "err", err)
}
}
func generateTOTPQRCode(keyURL string) (string, error) {
generated, err := otp.NewKeyFromURL(keyURL)
if err != nil {
return "", fmt.Errorf("failed to load TOTP key: %w", err)
}
qrCodeImage, err := generated.Image(300, 300)
if err != nil {
return "", fmt.Errorf("failed to generate TOTP QR code: %w", err)
}
var buf bytes.Buffer
if err := png.Encode(&buf, qrCodeImage); err != nil {
return "", fmt.Errorf("failed to encode TOTP QR code: %w", err)
}
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}
// mfaChainForClient returns the MFA chain for a client filtered by connector type,
// falling back to the server's defaultMFAChain if the client has none.
// Returns nil if no MFA is configured/applicable.
func (s *Server) mfaChainForClient(ctx context.Context, clientID, connectorID string) ([]string, error) {
if len(s.mfaProviders) == 0 {
return nil, nil
}
client, err := s.storage.GetClient(ctx, clientID)
if err != nil {
return nil, err
}
// nil means "not set" — fall back to default.
// Explicit empty slice ([]string{}) means "no MFA" — don't fall back.
source := client.MFAChain
if source == nil {
source = s.defaultMFAChain
}
// Resolve connector type from connector ID.
connectorType := s.getConnectorType(connectorID)
var chain []string
for _, authID := range source {
provider, ok := s.mfaProviders[authID]
if ok && provider.EnabledForConnectorType(connectorType) {
chain = append(chain, authID)
}
}
return chain, nil
}
// getConnectorType returns the type of the connector with the given ID.
func (s *Server) getConnectorType(connectorID string) string {
conn, err := s.storage.GetConnector(context.Background(), connectorID)
if err != nil {
return ""
}
return conn.Type
}

12
server/server.go

@ -137,6 +137,12 @@ type Config struct {
// If enabled, the server will continue starting even if some connectors fail to initialize. // If enabled, the server will continue starting even if some connectors fail to initialize.
// This allows the server to operate with a subset of connectors if some are misconfigured. // This allows the server to operate with a subset of connectors if some are misconfigured.
ContinueOnConnectorFailure bool ContinueOnConnectorFailure bool
// MFAProviders maps authenticator IDs to their provider implementations.
MFAProviders map[string]MFAProvider
// DefaultMFAChain is applied to clients that don't specify their own mfaChain.
DefaultMFAChain []string
} }
// WebConfig holds the server's frontend templates and asset configuration. // WebConfig holds the server's frontend templates and asset configuration.
@ -226,6 +232,9 @@ type Server struct {
logger *slog.Logger logger *slog.Logger
signer signer.Signer signer signer.Signer
mfaProviders map[string]MFAProvider
defaultMFAChain []string
} }
// NewServer constructs a server from the provided config. // NewServer constructs a server from the provided config.
@ -349,6 +358,8 @@ func newServer(ctx context.Context, c Config) (*Server, error) {
passwordConnector: c.PasswordConnector, passwordConnector: c.PasswordConnector,
logger: c.Logger, logger: c.Logger,
signer: c.Signer, signer: c.Signer,
mfaProviders: c.MFAProviders,
defaultMFAChain: c.DefaultMFAChain,
} }
// Retrieves connector objects in backend storage. This list includes the static connectors // Retrieves connector objects in backend storage. This list includes the static connectors
@ -537,6 +548,7 @@ func newServer(ctx context.Context, c Config) (*Server, error) {
// "authproxy" connector. // "authproxy" connector.
handleFunc("/callback/{connector}", s.handleConnectorCallback) handleFunc("/callback/{connector}", s.handleConnectorCallback)
handleFunc("/approval", s.handleApproval) handleFunc("/approval", s.handleApproval)
handleFunc("/totp/verify", s.handleMFAVerify)
handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !c.HealthChecker.IsHealthy() { if !c.HealthChecker.IsHealthy() {
s.renderError(r, w, http.StatusInternalServerError, "Health check failed.") s.renderError(r, w, http.StatusInternalServerError, "Health check failed.")

18
server/templates.go

@ -22,6 +22,7 @@ const (
tmplError = "error.html" tmplError = "error.html"
tmplDevice = "device.html" tmplDevice = "device.html"
tmplDeviceSuccess = "device_success.html" tmplDeviceSuccess = "device_success.html"
tmplTOTPVerify = "totp_verify.html"
) )
var requiredTmpls = []string{ var requiredTmpls = []string{
@ -42,6 +43,7 @@ type templates struct {
errorTmpl *template.Template errorTmpl *template.Template
deviceTmpl *template.Template deviceTmpl *template.Template
deviceSuccessTmpl *template.Template deviceSuccessTmpl *template.Template
totpVerifyTmpl *template.Template
} }
type webConfig struct { type webConfig struct {
@ -169,6 +171,7 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
errorTmpl: tmpls.Lookup(tmplError), errorTmpl: tmpls.Lookup(tmplError),
deviceTmpl: tmpls.Lookup(tmplDevice), deviceTmpl: tmpls.Lookup(tmplDevice),
deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess), deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess),
totpVerifyTmpl: tmpls.Lookup(tmplTOTPVerify),
}, nil }, nil
} }
@ -325,6 +328,21 @@ func (t *templates) approval(r *http.Request, w http.ResponseWriter, authReqID,
return renderTemplate(w, t.approvalTmpl, data) return renderTemplate(w, t.approvalTmpl, data)
} }
func (t *templates) totpVerify(r *http.Request, w http.ResponseWriter, postURL, issuer, connector, qrCode string, lastWasInvalid bool) error {
if lastWasInvalid {
w.WriteHeader(http.StatusUnauthorized)
}
data := struct {
PostURL string
Invalid bool
Issuer string
Connector string
QRCode string
ReqPath string
}{postURL, lastWasInvalid, issuer, connector, qrCode, r.URL.Path}
return renderTemplate(w, t.totpVerifyTmpl, data)
}
func (t *templates) oob(r *http.Request, w http.ResponseWriter, code string) error { func (t *templates) oob(r *http.Request, w http.ResponseWriter, code string) error {
data := struct { data := struct {
Code string Code string

2
storage/ent/client/authrequest.go

@ -32,6 +32,7 @@ func (d *Database) CreateAuthRequest(ctx context.Context, authRequest storage.Au
SetConnectorID(authRequest.ConnectorID). SetConnectorID(authRequest.ConnectorID).
SetConnectorData(authRequest.ConnectorData). SetConnectorData(authRequest.ConnectorData).
SetHmacKey(authRequest.HMACKey). SetHmacKey(authRequest.HMACKey).
SetMfaValidated(authRequest.MFAValidated).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return convertDBError("create auth request: %w", err) return convertDBError("create auth request: %w", err)
@ -96,6 +97,7 @@ func (d *Database) UpdateAuthRequest(ctx context.Context, id string, updater fun
SetConnectorID(newAuthRequest.ConnectorID). SetConnectorID(newAuthRequest.ConnectorID).
SetConnectorData(newAuthRequest.ConnectorData). SetConnectorData(newAuthRequest.ConnectorData).
SetHmacKey(newAuthRequest.HMACKey). SetHmacKey(newAuthRequest.HMACKey).
SetMfaValidated(newAuthRequest.MFAValidated).
Save(context.TODO()) Save(context.TODO())
if err != nil { if err != nil {
return rollback(tx, "update auth request uploading: %w", err) return rollback(tx, "update auth request uploading: %w", err)

2
storage/ent/client/client.go

@ -17,6 +17,7 @@ func (d *Database) CreateClient(ctx context.Context, client storage.Client) erro
SetRedirectUris(client.RedirectURIs). SetRedirectUris(client.RedirectURIs).
SetTrustedPeers(client.TrustedPeers). SetTrustedPeers(client.TrustedPeers).
SetAllowedConnectors(client.AllowedConnectors). SetAllowedConnectors(client.AllowedConnectors).
SetMfaChain(client.MFAChain).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return convertDBError("create oauth2 client: %w", err) return convertDBError("create oauth2 client: %w", err)
@ -81,6 +82,7 @@ func (d *Database) UpdateClient(ctx context.Context, id string, updater func(old
SetRedirectUris(newClient.RedirectURIs). SetRedirectUris(newClient.RedirectURIs).
SetTrustedPeers(newClient.TrustedPeers). SetTrustedPeers(newClient.TrustedPeers).
SetAllowedConnectors(newClient.AllowedConnectors). SetAllowedConnectors(newClient.AllowedConnectors).
SetMfaChain(newClient.MFAChain).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return rollback(tx, "update client uploading: %w", err) return rollback(tx, "update client uploading: %w", err)

16
storage/ent/client/types.go

@ -45,7 +45,8 @@ func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest {
CodeChallenge: a.CodeChallenge, CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod, CodeChallengeMethod: a.CodeChallengeMethod,
}, },
HMACKey: a.HmacKey, HMACKey: a.HmacKey,
MFAValidated: a.MfaValidated,
} }
} }
@ -84,6 +85,7 @@ func toStorageClient(c *db.OAuth2Client) storage.Client {
Name: c.Name, Name: c.Name,
LogoURL: c.LogoURL, LogoURL: c.LogoURL,
AllowedConnectors: c.AllowedConnectors, AllowedConnectors: c.AllowedConnectors,
MFAChain: c.MfaChain,
} }
} }
@ -193,6 +195,18 @@ func toStorageUserIdentity(u *db.UserIdentity) storage.UserIdentity {
// Server code assumes this will be non-nil. // Server code assumes this will be non-nil.
s.Consents = make(map[string][]string) s.Consents = make(map[string][]string)
} }
if u.MfaSecrets != nil {
if err := json.Unmarshal(*u.MfaSecrets, &s.MFASecrets); err != nil {
// Correctness of json structure is guaranteed on uploading
panic(err)
}
if s.MFASecrets == nil {
s.MFASecrets = make(map[string]*storage.MFASecret)
}
} else {
s.MFASecrets = make(map[string]*storage.MFASecret)
}
return s return s
} }

19
storage/ent/client/useridentity.go

@ -18,6 +18,14 @@ func (d *Database) CreateUserIdentity(ctx context.Context, identity storage.User
return fmt.Errorf("encode consents user identity: %w", err) return fmt.Errorf("encode consents user identity: %w", err)
} }
if identity.MFASecrets == nil {
identity.MFASecrets = make(map[string]*storage.MFASecret)
}
encodedMFASecrets, err := json.Marshal(identity.MFASecrets)
if err != nil {
return fmt.Errorf("encode mfa secrets user identity: %w", err)
}
id := compositeKeyID(identity.UserID, identity.ConnectorID, d.hasher) id := compositeKeyID(identity.UserID, identity.ConnectorID, d.hasher)
_, err = d.client.UserIdentity.Create(). _, err = d.client.UserIdentity.Create().
SetID(id). SetID(id).
@ -30,6 +38,7 @@ func (d *Database) CreateUserIdentity(ctx context.Context, identity storage.User
SetClaimsEmailVerified(identity.Claims.EmailVerified). SetClaimsEmailVerified(identity.Claims.EmailVerified).
SetClaimsGroups(identity.Claims.Groups). SetClaimsGroups(identity.Claims.Groups).
SetConsents(encodedConsents). SetConsents(encodedConsents).
SetMfaSecrets(encodedMFASecrets).
SetCreatedAt(identity.CreatedAt). SetCreatedAt(identity.CreatedAt).
SetLastLogin(identity.LastLogin). SetLastLogin(identity.LastLogin).
SetBlockedUntil(identity.BlockedUntil). SetBlockedUntil(identity.BlockedUntil).
@ -90,6 +99,15 @@ func (d *Database) UpdateUserIdentity(ctx context.Context, userID string, connec
return rollback(tx, "encode consents user identity: %w", err) return rollback(tx, "encode consents user identity: %w", err)
} }
if newUserIdentity.MFASecrets == nil {
newUserIdentity.MFASecrets = make(map[string]*storage.MFASecret)
}
encodedMFASecrets, err := json.Marshal(newUserIdentity.MFASecrets)
if err != nil {
return rollback(tx, "encode mfa secrets user identity: %w", err)
}
_, err = tx.UserIdentity.UpdateOneID(id). _, err = tx.UserIdentity.UpdateOneID(id).
SetUserID(newUserIdentity.UserID). SetUserID(newUserIdentity.UserID).
SetConnectorID(newUserIdentity.ConnectorID). SetConnectorID(newUserIdentity.ConnectorID).
@ -100,6 +118,7 @@ func (d *Database) UpdateUserIdentity(ctx context.Context, userID string, connec
SetClaimsEmailVerified(newUserIdentity.Claims.EmailVerified). SetClaimsEmailVerified(newUserIdentity.Claims.EmailVerified).
SetClaimsGroups(newUserIdentity.Claims.Groups). SetClaimsGroups(newUserIdentity.Claims.Groups).
SetConsents(encodedConsents). SetConsents(encodedConsents).
SetMfaSecrets(encodedMFASecrets).
SetCreatedAt(newUserIdentity.CreatedAt). SetCreatedAt(newUserIdentity.CreatedAt).
SetLastLogin(newUserIdentity.LastLogin). SetLastLogin(newUserIdentity.LastLogin).
SetBlockedUntil(newUserIdentity.BlockedUntil). SetBlockedUntil(newUserIdentity.BlockedUntil).

15
storage/ent/db/authrequest.go

@ -57,7 +57,9 @@ type AuthRequest struct {
// CodeChallengeMethod holds the value of the "code_challenge_method" field. // CodeChallengeMethod holds the value of the "code_challenge_method" field.
CodeChallengeMethod string `json:"code_challenge_method,omitempty"` CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
// HmacKey holds the value of the "hmac_key" field. // HmacKey holds the value of the "hmac_key" field.
HmacKey []byte `json:"hmac_key,omitempty"` HmacKey []byte `json:"hmac_key,omitempty"`
// MfaValidated holds the value of the "mfa_validated" field.
MfaValidated bool `json:"mfa_validated,omitempty"`
selectValues sql.SelectValues selectValues sql.SelectValues
} }
@ -68,7 +70,7 @@ func (*AuthRequest) scanValues(columns []string) ([]any, error) {
switch columns[i] { switch columns[i] {
case authrequest.FieldScopes, authrequest.FieldResponseTypes, authrequest.FieldClaimsGroups, authrequest.FieldConnectorData, authrequest.FieldHmacKey: case authrequest.FieldScopes, authrequest.FieldResponseTypes, authrequest.FieldClaimsGroups, authrequest.FieldConnectorData, authrequest.FieldHmacKey:
values[i] = new([]byte) values[i] = new([]byte)
case authrequest.FieldForceApprovalPrompt, authrequest.FieldLoggedIn, authrequest.FieldClaimsEmailVerified: case authrequest.FieldForceApprovalPrompt, authrequest.FieldLoggedIn, authrequest.FieldClaimsEmailVerified, authrequest.FieldMfaValidated:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case authrequest.FieldID, authrequest.FieldClientID, authrequest.FieldRedirectURI, authrequest.FieldNonce, authrequest.FieldState, authrequest.FieldClaimsUserID, authrequest.FieldClaimsUsername, authrequest.FieldClaimsEmail, authrequest.FieldClaimsPreferredUsername, authrequest.FieldConnectorID, authrequest.FieldCodeChallenge, authrequest.FieldCodeChallengeMethod: case authrequest.FieldID, authrequest.FieldClientID, authrequest.FieldRedirectURI, authrequest.FieldNonce, authrequest.FieldState, authrequest.FieldClaimsUserID, authrequest.FieldClaimsUsername, authrequest.FieldClaimsEmail, authrequest.FieldClaimsPreferredUsername, authrequest.FieldConnectorID, authrequest.FieldCodeChallenge, authrequest.FieldCodeChallengeMethod:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
@ -221,6 +223,12 @@ func (_m *AuthRequest) assignValues(columns []string, values []any) error {
} else if value != nil { } else if value != nil {
_m.HmacKey = *value _m.HmacKey = *value
} }
case authrequest.FieldMfaValidated:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field mfa_validated", values[i])
} else if value.Valid {
_m.MfaValidated = value.Bool
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }
@ -318,6 +326,9 @@ func (_m *AuthRequest) String() string {
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("hmac_key=") builder.WriteString("hmac_key=")
builder.WriteString(fmt.Sprintf("%v", _m.HmacKey)) builder.WriteString(fmt.Sprintf("%v", _m.HmacKey))
builder.WriteString(", ")
builder.WriteString("mfa_validated=")
builder.WriteString(fmt.Sprintf("%v", _m.MfaValidated))
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }

10
storage/ent/db/authrequest/authrequest.go

@ -51,6 +51,8 @@ const (
FieldCodeChallengeMethod = "code_challenge_method" FieldCodeChallengeMethod = "code_challenge_method"
// FieldHmacKey holds the string denoting the hmac_key field in the database. // FieldHmacKey holds the string denoting the hmac_key field in the database.
FieldHmacKey = "hmac_key" FieldHmacKey = "hmac_key"
// FieldMfaValidated holds the string denoting the mfa_validated field in the database.
FieldMfaValidated = "mfa_validated"
// Table holds the table name of the authrequest in the database. // Table holds the table name of the authrequest in the database.
Table = "auth_requests" Table = "auth_requests"
) )
@ -78,6 +80,7 @@ var Columns = []string{
FieldCodeChallenge, FieldCodeChallenge,
FieldCodeChallengeMethod, FieldCodeChallengeMethod,
FieldHmacKey, FieldHmacKey,
FieldMfaValidated,
} }
// ValidColumn reports if the column name is valid (part of the table columns). // ValidColumn reports if the column name is valid (part of the table columns).
@ -97,6 +100,8 @@ var (
DefaultCodeChallenge string DefaultCodeChallenge string
// DefaultCodeChallengeMethod holds the default value on creation for the "code_challenge_method" field. // DefaultCodeChallengeMethod holds the default value on creation for the "code_challenge_method" field.
DefaultCodeChallengeMethod string DefaultCodeChallengeMethod string
// DefaultMfaValidated holds the default value on creation for the "mfa_validated" field.
DefaultMfaValidated bool
// IDValidator is a validator for the "id" field. It is called by the builders before save. // IDValidator is a validator for the "id" field. It is called by the builders before save.
IDValidator func(string) error IDValidator func(string) error
) )
@ -183,3 +188,8 @@ func ByCodeChallenge(opts ...sql.OrderTermOption) OrderOption {
func ByCodeChallengeMethod(opts ...sql.OrderTermOption) OrderOption { func ByCodeChallengeMethod(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCodeChallengeMethod, opts...).ToFunc() return sql.OrderByField(FieldCodeChallengeMethod, opts...).ToFunc()
} }
// ByMfaValidated orders the results by the mfa_validated field.
func ByMfaValidated(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldMfaValidated, opts...).ToFunc()
}

15
storage/ent/db/authrequest/where.go

@ -149,6 +149,11 @@ func HmacKey(v []byte) predicate.AuthRequest {
return predicate.AuthRequest(sql.FieldEQ(FieldHmacKey, v)) return predicate.AuthRequest(sql.FieldEQ(FieldHmacKey, v))
} }
// MfaValidated applies equality check predicate on the "mfa_validated" field. It's identical to MfaValidatedEQ.
func MfaValidated(v bool) predicate.AuthRequest {
return predicate.AuthRequest(sql.FieldEQ(FieldMfaValidated, v))
}
// ClientIDEQ applies the EQ predicate on the "client_id" field. // ClientIDEQ applies the EQ predicate on the "client_id" field.
func ClientIDEQ(v string) predicate.AuthRequest { func ClientIDEQ(v string) predicate.AuthRequest {
return predicate.AuthRequest(sql.FieldEQ(FieldClientID, v)) return predicate.AuthRequest(sql.FieldEQ(FieldClientID, v))
@ -1054,6 +1059,16 @@ func HmacKeyLTE(v []byte) predicate.AuthRequest {
return predicate.AuthRequest(sql.FieldLTE(FieldHmacKey, v)) return predicate.AuthRequest(sql.FieldLTE(FieldHmacKey, v))
} }
// MfaValidatedEQ applies the EQ predicate on the "mfa_validated" field.
func MfaValidatedEQ(v bool) predicate.AuthRequest {
return predicate.AuthRequest(sql.FieldEQ(FieldMfaValidated, v))
}
// MfaValidatedNEQ applies the NEQ predicate on the "mfa_validated" field.
func MfaValidatedNEQ(v bool) predicate.AuthRequest {
return predicate.AuthRequest(sql.FieldNEQ(FieldMfaValidated, v))
}
// And groups predicates with the AND operator between them. // And groups predicates with the AND operator between them.
func And(predicates ...predicate.AuthRequest) predicate.AuthRequest { func And(predicates ...predicate.AuthRequest) predicate.AuthRequest {
return predicate.AuthRequest(sql.AndPredicates(predicates...)) return predicate.AuthRequest(sql.AndPredicates(predicates...))

25
storage/ent/db/authrequest_create.go

@ -164,6 +164,20 @@ func (_c *AuthRequestCreate) SetHmacKey(v []byte) *AuthRequestCreate {
return _c return _c
} }
// SetMfaValidated sets the "mfa_validated" field.
func (_c *AuthRequestCreate) SetMfaValidated(v bool) *AuthRequestCreate {
_c.mutation.SetMfaValidated(v)
return _c
}
// SetNillableMfaValidated sets the "mfa_validated" field if the given value is not nil.
func (_c *AuthRequestCreate) SetNillableMfaValidated(v *bool) *AuthRequestCreate {
if v != nil {
_c.SetMfaValidated(*v)
}
return _c
}
// SetID sets the "id" field. // SetID sets the "id" field.
func (_c *AuthRequestCreate) SetID(v string) *AuthRequestCreate { func (_c *AuthRequestCreate) SetID(v string) *AuthRequestCreate {
_c.mutation.SetID(v) _c.mutation.SetID(v)
@ -217,6 +231,10 @@ func (_c *AuthRequestCreate) defaults() {
v := authrequest.DefaultCodeChallengeMethod v := authrequest.DefaultCodeChallengeMethod
_c.mutation.SetCodeChallengeMethod(v) _c.mutation.SetCodeChallengeMethod(v)
} }
if _, ok := _c.mutation.MfaValidated(); !ok {
v := authrequest.DefaultMfaValidated
_c.mutation.SetMfaValidated(v)
}
} }
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
@ -269,6 +287,9 @@ func (_c *AuthRequestCreate) check() error {
if _, ok := _c.mutation.HmacKey(); !ok { if _, ok := _c.mutation.HmacKey(); !ok {
return &ValidationError{Name: "hmac_key", err: errors.New(`db: missing required field "AuthRequest.hmac_key"`)} return &ValidationError{Name: "hmac_key", err: errors.New(`db: missing required field "AuthRequest.hmac_key"`)}
} }
if _, ok := _c.mutation.MfaValidated(); !ok {
return &ValidationError{Name: "mfa_validated", err: errors.New(`db: missing required field "AuthRequest.mfa_validated"`)}
}
if v, ok := _c.mutation.ID(); ok { if v, ok := _c.mutation.ID(); ok {
if err := authrequest.IDValidator(v); err != nil { if err := authrequest.IDValidator(v); err != nil {
return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "AuthRequest.id": %w`, err)} return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "AuthRequest.id": %w`, err)}
@ -389,6 +410,10 @@ func (_c *AuthRequestCreate) createSpec() (*AuthRequest, *sqlgraph.CreateSpec) {
_spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value) _spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value)
_node.HmacKey = value _node.HmacKey = value
} }
if value, ok := _c.mutation.MfaValidated(); ok {
_spec.SetField(authrequest.FieldMfaValidated, field.TypeBool, value)
_node.MfaValidated = value
}
return _node, _spec return _node, _spec
} }

34
storage/ent/db/authrequest_update.go

@ -311,6 +311,20 @@ func (_u *AuthRequestUpdate) SetHmacKey(v []byte) *AuthRequestUpdate {
return _u return _u
} }
// SetMfaValidated sets the "mfa_validated" field.
func (_u *AuthRequestUpdate) SetMfaValidated(v bool) *AuthRequestUpdate {
_u.mutation.SetMfaValidated(v)
return _u
}
// SetNillableMfaValidated sets the "mfa_validated" field if the given value is not nil.
func (_u *AuthRequestUpdate) SetNillableMfaValidated(v *bool) *AuthRequestUpdate {
if v != nil {
_u.SetMfaValidated(*v)
}
return _u
}
// Mutation returns the AuthRequestMutation object of the builder. // Mutation returns the AuthRequestMutation object of the builder.
func (_u *AuthRequestUpdate) Mutation() *AuthRequestMutation { func (_u *AuthRequestUpdate) Mutation() *AuthRequestMutation {
return _u.mutation return _u.mutation
@ -439,6 +453,9 @@ func (_u *AuthRequestUpdate) sqlSave(ctx context.Context) (_node int, err error)
if value, ok := _u.mutation.HmacKey(); ok { if value, ok := _u.mutation.HmacKey(); ok {
_spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value) _spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value)
} }
if value, ok := _u.mutation.MfaValidated(); ok {
_spec.SetField(authrequest.FieldMfaValidated, field.TypeBool, value)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok { if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{authrequest.Label} err = &NotFoundError{authrequest.Label}
@ -741,6 +758,20 @@ func (_u *AuthRequestUpdateOne) SetHmacKey(v []byte) *AuthRequestUpdateOne {
return _u return _u
} }
// SetMfaValidated sets the "mfa_validated" field.
func (_u *AuthRequestUpdateOne) SetMfaValidated(v bool) *AuthRequestUpdateOne {
_u.mutation.SetMfaValidated(v)
return _u
}
// SetNillableMfaValidated sets the "mfa_validated" field if the given value is not nil.
func (_u *AuthRequestUpdateOne) SetNillableMfaValidated(v *bool) *AuthRequestUpdateOne {
if v != nil {
_u.SetMfaValidated(*v)
}
return _u
}
// Mutation returns the AuthRequestMutation object of the builder. // Mutation returns the AuthRequestMutation object of the builder.
func (_u *AuthRequestUpdateOne) Mutation() *AuthRequestMutation { func (_u *AuthRequestUpdateOne) Mutation() *AuthRequestMutation {
return _u.mutation return _u.mutation
@ -899,6 +930,9 @@ func (_u *AuthRequestUpdateOne) sqlSave(ctx context.Context) (_node *AuthRequest
if value, ok := _u.mutation.HmacKey(); ok { if value, ok := _u.mutation.HmacKey(); ok {
_spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value) _spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value)
} }
if value, ok := _u.mutation.MfaValidated(); ok {
_spec.SetField(authrequest.FieldMfaValidated, field.TypeBool, value)
}
_node = &AuthRequest{config: _u.config} _node = &AuthRequest{config: _u.config}
_spec.Assign = _node.assignValues _spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues _spec.ScanValues = _node.scanValues

3
storage/ent/db/migrate/schema.go

@ -56,6 +56,7 @@ var (
{Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
{Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
{Name: "hmac_key", Type: field.TypeBytes}, {Name: "hmac_key", Type: field.TypeBytes},
{Name: "mfa_validated", Type: field.TypeBool, Default: false},
} }
// AuthRequestsTable holds the schema information for the "auth_requests" table. // AuthRequestsTable holds the schema information for the "auth_requests" table.
AuthRequestsTable = &schema.Table{ AuthRequestsTable = &schema.Table{
@ -151,6 +152,7 @@ var (
{Name: "name", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "name", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
{Name: "logo_url", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "logo_url", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
{Name: "allowed_connectors", Type: field.TypeJSON, Nullable: true}, {Name: "allowed_connectors", Type: field.TypeJSON, Nullable: true},
{Name: "mfa_chain", Type: field.TypeJSON, Nullable: true},
} }
// Oauth2clientsTable holds the schema information for the "oauth2clients" table. // Oauth2clientsTable holds the schema information for the "oauth2clients" table.
Oauth2clientsTable = &schema.Table{ Oauth2clientsTable = &schema.Table{
@ -227,6 +229,7 @@ var (
{Name: "claims_email_verified", Type: field.TypeBool, Default: false}, {Name: "claims_email_verified", Type: field.TypeBool, Default: false},
{Name: "claims_groups", Type: field.TypeJSON, Nullable: true}, {Name: "claims_groups", Type: field.TypeJSON, Nullable: true},
{Name: "consents", Type: field.TypeBytes}, {Name: "consents", Type: field.TypeBytes},
{Name: "mfa_secrets", Type: field.TypeBytes, Nullable: true},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}},
{Name: "last_login", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "last_login", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}},
{Name: "blocked_until", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "blocked_until", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}},

223
storage/ent/db/mutation.go

@ -1262,6 +1262,7 @@ type AuthRequestMutation struct {
code_challenge *string code_challenge *string
code_challenge_method *string code_challenge_method *string
hmac_key *[]byte hmac_key *[]byte
mfa_validated *bool
clearedFields map[string]struct{} clearedFields map[string]struct{}
done bool done bool
oldValue func(context.Context) (*AuthRequest, error) oldValue func(context.Context) (*AuthRequest, error)
@ -2192,6 +2193,42 @@ func (m *AuthRequestMutation) ResetHmacKey() {
m.hmac_key = nil m.hmac_key = nil
} }
// SetMfaValidated sets the "mfa_validated" field.
func (m *AuthRequestMutation) SetMfaValidated(b bool) {
m.mfa_validated = &b
}
// MfaValidated returns the value of the "mfa_validated" field in the mutation.
func (m *AuthRequestMutation) MfaValidated() (r bool, exists bool) {
v := m.mfa_validated
if v == nil {
return
}
return *v, true
}
// OldMfaValidated returns the old "mfa_validated" field's value of the AuthRequest entity.
// If the AuthRequest object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *AuthRequestMutation) OldMfaValidated(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldMfaValidated is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldMfaValidated requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldMfaValidated: %w", err)
}
return oldValue.MfaValidated, nil
}
// ResetMfaValidated resets all changes to the "mfa_validated" field.
func (m *AuthRequestMutation) ResetMfaValidated() {
m.mfa_validated = nil
}
// Where appends a list predicates to the AuthRequestMutation builder. // Where appends a list predicates to the AuthRequestMutation builder.
func (m *AuthRequestMutation) Where(ps ...predicate.AuthRequest) { func (m *AuthRequestMutation) Where(ps ...predicate.AuthRequest) {
m.predicates = append(m.predicates, ps...) m.predicates = append(m.predicates, ps...)
@ -2226,7 +2263,7 @@ func (m *AuthRequestMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *AuthRequestMutation) Fields() []string { func (m *AuthRequestMutation) Fields() []string {
fields := make([]string, 0, 20) fields := make([]string, 0, 21)
if m.client_id != nil { if m.client_id != nil {
fields = append(fields, authrequest.FieldClientID) fields = append(fields, authrequest.FieldClientID)
} }
@ -2287,6 +2324,9 @@ func (m *AuthRequestMutation) Fields() []string {
if m.hmac_key != nil { if m.hmac_key != nil {
fields = append(fields, authrequest.FieldHmacKey) fields = append(fields, authrequest.FieldHmacKey)
} }
if m.mfa_validated != nil {
fields = append(fields, authrequest.FieldMfaValidated)
}
return fields return fields
} }
@ -2335,6 +2375,8 @@ func (m *AuthRequestMutation) Field(name string) (ent.Value, bool) {
return m.CodeChallengeMethod() return m.CodeChallengeMethod()
case authrequest.FieldHmacKey: case authrequest.FieldHmacKey:
return m.HmacKey() return m.HmacKey()
case authrequest.FieldMfaValidated:
return m.MfaValidated()
} }
return nil, false return nil, false
} }
@ -2384,6 +2426,8 @@ func (m *AuthRequestMutation) OldField(ctx context.Context, name string) (ent.Va
return m.OldCodeChallengeMethod(ctx) return m.OldCodeChallengeMethod(ctx)
case authrequest.FieldHmacKey: case authrequest.FieldHmacKey:
return m.OldHmacKey(ctx) return m.OldHmacKey(ctx)
case authrequest.FieldMfaValidated:
return m.OldMfaValidated(ctx)
} }
return nil, fmt.Errorf("unknown AuthRequest field %s", name) return nil, fmt.Errorf("unknown AuthRequest field %s", name)
} }
@ -2533,6 +2577,13 @@ func (m *AuthRequestMutation) SetField(name string, value ent.Value) error {
} }
m.SetHmacKey(v) m.SetHmacKey(v)
return nil return nil
case authrequest.FieldMfaValidated:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetMfaValidated(v)
return nil
} }
return fmt.Errorf("unknown AuthRequest field %s", name) return fmt.Errorf("unknown AuthRequest field %s", name)
} }
@ -2669,6 +2720,9 @@ func (m *AuthRequestMutation) ResetField(name string) error {
case authrequest.FieldHmacKey: case authrequest.FieldHmacKey:
m.ResetHmacKey() m.ResetHmacKey()
return nil return nil
case authrequest.FieldMfaValidated:
m.ResetMfaValidated()
return nil
} }
return fmt.Errorf("unknown AuthRequest field %s", name) return fmt.Errorf("unknown AuthRequest field %s", name)
} }
@ -5779,6 +5833,8 @@ type OAuth2ClientMutation struct {
logo_url *string logo_url *string
allowed_connectors *[]string allowed_connectors *[]string
appendallowed_connectors []string appendallowed_connectors []string
mfa_chain *[]string
appendmfa_chain []string
clearedFields map[string]struct{} clearedFields map[string]struct{}
done bool done bool
oldValue func(context.Context) (*OAuth2Client, error) oldValue func(context.Context) (*OAuth2Client, error)
@ -6228,6 +6284,71 @@ func (m *OAuth2ClientMutation) ResetAllowedConnectors() {
delete(m.clearedFields, oauth2client.FieldAllowedConnectors) delete(m.clearedFields, oauth2client.FieldAllowedConnectors)
} }
// SetMfaChain sets the "mfa_chain" field.
func (m *OAuth2ClientMutation) SetMfaChain(s []string) {
m.mfa_chain = &s
m.appendmfa_chain = nil
}
// MfaChain returns the value of the "mfa_chain" field in the mutation.
func (m *OAuth2ClientMutation) MfaChain() (r []string, exists bool) {
v := m.mfa_chain
if v == nil {
return
}
return *v, true
}
// OldMfaChain returns the old "mfa_chain" field's value of the OAuth2Client entity.
// If the OAuth2Client object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *OAuth2ClientMutation) OldMfaChain(ctx context.Context) (v []string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldMfaChain is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldMfaChain requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldMfaChain: %w", err)
}
return oldValue.MfaChain, nil
}
// AppendMfaChain adds s to the "mfa_chain" field.
func (m *OAuth2ClientMutation) AppendMfaChain(s []string) {
m.appendmfa_chain = append(m.appendmfa_chain, s...)
}
// AppendedMfaChain returns the list of values that were appended to the "mfa_chain" field in this mutation.
func (m *OAuth2ClientMutation) AppendedMfaChain() ([]string, bool) {
if len(m.appendmfa_chain) == 0 {
return nil, false
}
return m.appendmfa_chain, true
}
// ClearMfaChain clears the value of the "mfa_chain" field.
func (m *OAuth2ClientMutation) ClearMfaChain() {
m.mfa_chain = nil
m.appendmfa_chain = nil
m.clearedFields[oauth2client.FieldMfaChain] = struct{}{}
}
// MfaChainCleared returns if the "mfa_chain" field was cleared in this mutation.
func (m *OAuth2ClientMutation) MfaChainCleared() bool {
_, ok := m.clearedFields[oauth2client.FieldMfaChain]
return ok
}
// ResetMfaChain resets all changes to the "mfa_chain" field.
func (m *OAuth2ClientMutation) ResetMfaChain() {
m.mfa_chain = nil
m.appendmfa_chain = nil
delete(m.clearedFields, oauth2client.FieldMfaChain)
}
// Where appends a list predicates to the OAuth2ClientMutation builder. // Where appends a list predicates to the OAuth2ClientMutation builder.
func (m *OAuth2ClientMutation) Where(ps ...predicate.OAuth2Client) { func (m *OAuth2ClientMutation) Where(ps ...predicate.OAuth2Client) {
m.predicates = append(m.predicates, ps...) m.predicates = append(m.predicates, ps...)
@ -6262,7 +6383,7 @@ func (m *OAuth2ClientMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *OAuth2ClientMutation) Fields() []string { func (m *OAuth2ClientMutation) Fields() []string {
fields := make([]string, 0, 7) fields := make([]string, 0, 8)
if m.secret != nil { if m.secret != nil {
fields = append(fields, oauth2client.FieldSecret) fields = append(fields, oauth2client.FieldSecret)
} }
@ -6284,6 +6405,9 @@ func (m *OAuth2ClientMutation) Fields() []string {
if m.allowed_connectors != nil { if m.allowed_connectors != nil {
fields = append(fields, oauth2client.FieldAllowedConnectors) fields = append(fields, oauth2client.FieldAllowedConnectors)
} }
if m.mfa_chain != nil {
fields = append(fields, oauth2client.FieldMfaChain)
}
return fields return fields
} }
@ -6306,6 +6430,8 @@ func (m *OAuth2ClientMutation) Field(name string) (ent.Value, bool) {
return m.LogoURL() return m.LogoURL()
case oauth2client.FieldAllowedConnectors: case oauth2client.FieldAllowedConnectors:
return m.AllowedConnectors() return m.AllowedConnectors()
case oauth2client.FieldMfaChain:
return m.MfaChain()
} }
return nil, false return nil, false
} }
@ -6329,6 +6455,8 @@ func (m *OAuth2ClientMutation) OldField(ctx context.Context, name string) (ent.V
return m.OldLogoURL(ctx) return m.OldLogoURL(ctx)
case oauth2client.FieldAllowedConnectors: case oauth2client.FieldAllowedConnectors:
return m.OldAllowedConnectors(ctx) return m.OldAllowedConnectors(ctx)
case oauth2client.FieldMfaChain:
return m.OldMfaChain(ctx)
} }
return nil, fmt.Errorf("unknown OAuth2Client field %s", name) return nil, fmt.Errorf("unknown OAuth2Client field %s", name)
} }
@ -6387,6 +6515,13 @@ func (m *OAuth2ClientMutation) SetField(name string, value ent.Value) error {
} }
m.SetAllowedConnectors(v) m.SetAllowedConnectors(v)
return nil return nil
case oauth2client.FieldMfaChain:
v, ok := value.([]string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetMfaChain(v)
return nil
} }
return fmt.Errorf("unknown OAuth2Client field %s", name) return fmt.Errorf("unknown OAuth2Client field %s", name)
} }
@ -6426,6 +6561,9 @@ func (m *OAuth2ClientMutation) ClearedFields() []string {
if m.FieldCleared(oauth2client.FieldAllowedConnectors) { if m.FieldCleared(oauth2client.FieldAllowedConnectors) {
fields = append(fields, oauth2client.FieldAllowedConnectors) fields = append(fields, oauth2client.FieldAllowedConnectors)
} }
if m.FieldCleared(oauth2client.FieldMfaChain) {
fields = append(fields, oauth2client.FieldMfaChain)
}
return fields return fields
} }
@ -6449,6 +6587,9 @@ func (m *OAuth2ClientMutation) ClearField(name string) error {
case oauth2client.FieldAllowedConnectors: case oauth2client.FieldAllowedConnectors:
m.ClearAllowedConnectors() m.ClearAllowedConnectors()
return nil return nil
case oauth2client.FieldMfaChain:
m.ClearMfaChain()
return nil
} }
return fmt.Errorf("unknown OAuth2Client nullable field %s", name) return fmt.Errorf("unknown OAuth2Client nullable field %s", name)
} }
@ -6478,6 +6619,9 @@ func (m *OAuth2ClientMutation) ResetField(name string) error {
case oauth2client.FieldAllowedConnectors: case oauth2client.FieldAllowedConnectors:
m.ResetAllowedConnectors() m.ResetAllowedConnectors()
return nil return nil
case oauth2client.FieldMfaChain:
m.ResetMfaChain()
return nil
} }
return fmt.Errorf("unknown OAuth2Client field %s", name) return fmt.Errorf("unknown OAuth2Client field %s", name)
} }
@ -9006,6 +9150,7 @@ type UserIdentityMutation struct {
claims_groups *[]string claims_groups *[]string
appendclaims_groups []string appendclaims_groups []string
consents *[]byte consents *[]byte
mfa_secrets *[]byte
created_at *time.Time created_at *time.Time
last_login *time.Time last_login *time.Time
blocked_until *time.Time blocked_until *time.Time
@ -9472,6 +9617,55 @@ func (m *UserIdentityMutation) ResetConsents() {
m.consents = nil m.consents = nil
} }
// SetMfaSecrets sets the "mfa_secrets" field.
func (m *UserIdentityMutation) SetMfaSecrets(b []byte) {
m.mfa_secrets = &b
}
// MfaSecrets returns the value of the "mfa_secrets" field in the mutation.
func (m *UserIdentityMutation) MfaSecrets() (r []byte, exists bool) {
v := m.mfa_secrets
if v == nil {
return
}
return *v, true
}
// OldMfaSecrets returns the old "mfa_secrets" field's value of the UserIdentity entity.
// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserIdentityMutation) OldMfaSecrets(ctx context.Context) (v *[]byte, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldMfaSecrets is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldMfaSecrets requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldMfaSecrets: %w", err)
}
return oldValue.MfaSecrets, nil
}
// ClearMfaSecrets clears the value of the "mfa_secrets" field.
func (m *UserIdentityMutation) ClearMfaSecrets() {
m.mfa_secrets = nil
m.clearedFields[useridentity.FieldMfaSecrets] = struct{}{}
}
// MfaSecretsCleared returns if the "mfa_secrets" field was cleared in this mutation.
func (m *UserIdentityMutation) MfaSecretsCleared() bool {
_, ok := m.clearedFields[useridentity.FieldMfaSecrets]
return ok
}
// ResetMfaSecrets resets all changes to the "mfa_secrets" field.
func (m *UserIdentityMutation) ResetMfaSecrets() {
m.mfa_secrets = nil
delete(m.clearedFields, useridentity.FieldMfaSecrets)
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (m *UserIdentityMutation) SetCreatedAt(t time.Time) { func (m *UserIdentityMutation) SetCreatedAt(t time.Time) {
m.created_at = &t m.created_at = &t
@ -9614,7 +9808,7 @@ func (m *UserIdentityMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UserIdentityMutation) Fields() []string { func (m *UserIdentityMutation) Fields() []string {
fields := make([]string, 0, 12) fields := make([]string, 0, 13)
if m.user_id != nil { if m.user_id != nil {
fields = append(fields, useridentity.FieldUserID) fields = append(fields, useridentity.FieldUserID)
} }
@ -9642,6 +9836,9 @@ func (m *UserIdentityMutation) Fields() []string {
if m.consents != nil { if m.consents != nil {
fields = append(fields, useridentity.FieldConsents) fields = append(fields, useridentity.FieldConsents)
} }
if m.mfa_secrets != nil {
fields = append(fields, useridentity.FieldMfaSecrets)
}
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, useridentity.FieldCreatedAt) fields = append(fields, useridentity.FieldCreatedAt)
} }
@ -9677,6 +9874,8 @@ func (m *UserIdentityMutation) Field(name string) (ent.Value, bool) {
return m.ClaimsGroups() return m.ClaimsGroups()
case useridentity.FieldConsents: case useridentity.FieldConsents:
return m.Consents() return m.Consents()
case useridentity.FieldMfaSecrets:
return m.MfaSecrets()
case useridentity.FieldCreatedAt: case useridentity.FieldCreatedAt:
return m.CreatedAt() return m.CreatedAt()
case useridentity.FieldLastLogin: case useridentity.FieldLastLogin:
@ -9710,6 +9909,8 @@ func (m *UserIdentityMutation) OldField(ctx context.Context, name string) (ent.V
return m.OldClaimsGroups(ctx) return m.OldClaimsGroups(ctx)
case useridentity.FieldConsents: case useridentity.FieldConsents:
return m.OldConsents(ctx) return m.OldConsents(ctx)
case useridentity.FieldMfaSecrets:
return m.OldMfaSecrets(ctx)
case useridentity.FieldCreatedAt: case useridentity.FieldCreatedAt:
return m.OldCreatedAt(ctx) return m.OldCreatedAt(ctx)
case useridentity.FieldLastLogin: case useridentity.FieldLastLogin:
@ -9788,6 +9989,13 @@ func (m *UserIdentityMutation) SetField(name string, value ent.Value) error {
} }
m.SetConsents(v) m.SetConsents(v)
return nil return nil
case useridentity.FieldMfaSecrets:
v, ok := value.([]byte)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetMfaSecrets(v)
return nil
case useridentity.FieldCreatedAt: case useridentity.FieldCreatedAt:
v, ok := value.(time.Time) v, ok := value.(time.Time)
if !ok { if !ok {
@ -9842,6 +10050,9 @@ func (m *UserIdentityMutation) ClearedFields() []string {
if m.FieldCleared(useridentity.FieldClaimsGroups) { if m.FieldCleared(useridentity.FieldClaimsGroups) {
fields = append(fields, useridentity.FieldClaimsGroups) fields = append(fields, useridentity.FieldClaimsGroups)
} }
if m.FieldCleared(useridentity.FieldMfaSecrets) {
fields = append(fields, useridentity.FieldMfaSecrets)
}
return fields return fields
} }
@ -9859,6 +10070,9 @@ func (m *UserIdentityMutation) ClearField(name string) error {
case useridentity.FieldClaimsGroups: case useridentity.FieldClaimsGroups:
m.ClearClaimsGroups() m.ClearClaimsGroups()
return nil return nil
case useridentity.FieldMfaSecrets:
m.ClearMfaSecrets()
return nil
} }
return fmt.Errorf("unknown UserIdentity nullable field %s", name) return fmt.Errorf("unknown UserIdentity nullable field %s", name)
} }
@ -9894,6 +10108,9 @@ func (m *UserIdentityMutation) ResetField(name string) error {
case useridentity.FieldConsents: case useridentity.FieldConsents:
m.ResetConsents() m.ResetConsents()
return nil return nil
case useridentity.FieldMfaSecrets:
m.ResetMfaSecrets()
return nil
case useridentity.FieldCreatedAt: case useridentity.FieldCreatedAt:
m.ResetCreatedAt() m.ResetCreatedAt()
return nil return nil

17
storage/ent/db/oauth2client.go

@ -31,7 +31,9 @@ type OAuth2Client struct {
LogoURL string `json:"logo_url,omitempty"` LogoURL string `json:"logo_url,omitempty"`
// AllowedConnectors holds the value of the "allowed_connectors" field. // AllowedConnectors holds the value of the "allowed_connectors" field.
AllowedConnectors []string `json:"allowed_connectors,omitempty"` AllowedConnectors []string `json:"allowed_connectors,omitempty"`
selectValues sql.SelectValues // MfaChain holds the value of the "mfa_chain" field.
MfaChain []string `json:"mfa_chain,omitempty"`
selectValues sql.SelectValues
} }
// scanValues returns the types for scanning values from sql.Rows. // scanValues returns the types for scanning values from sql.Rows.
@ -39,7 +41,7 @@ func (*OAuth2Client) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case oauth2client.FieldRedirectUris, oauth2client.FieldTrustedPeers, oauth2client.FieldAllowedConnectors: case oauth2client.FieldRedirectUris, oauth2client.FieldTrustedPeers, oauth2client.FieldAllowedConnectors, oauth2client.FieldMfaChain:
values[i] = new([]byte) values[i] = new([]byte)
case oauth2client.FieldPublic: case oauth2client.FieldPublic:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
@ -114,6 +116,14 @@ func (_m *OAuth2Client) assignValues(columns []string, values []any) error {
return fmt.Errorf("unmarshal field allowed_connectors: %w", err) return fmt.Errorf("unmarshal field allowed_connectors: %w", err)
} }
} }
case oauth2client.FieldMfaChain:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field mfa_chain", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.MfaChain); err != nil {
return fmt.Errorf("unmarshal field mfa_chain: %w", err)
}
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }
@ -170,6 +180,9 @@ func (_m *OAuth2Client) String() string {
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("allowed_connectors=") builder.WriteString("allowed_connectors=")
builder.WriteString(fmt.Sprintf("%v", _m.AllowedConnectors)) builder.WriteString(fmt.Sprintf("%v", _m.AllowedConnectors))
builder.WriteString(", ")
builder.WriteString("mfa_chain=")
builder.WriteString(fmt.Sprintf("%v", _m.MfaChain))
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }

3
storage/ent/db/oauth2client/oauth2client.go

@ -25,6 +25,8 @@ const (
FieldLogoURL = "logo_url" FieldLogoURL = "logo_url"
// FieldAllowedConnectors holds the string denoting the allowed_connectors field in the database. // FieldAllowedConnectors holds the string denoting the allowed_connectors field in the database.
FieldAllowedConnectors = "allowed_connectors" FieldAllowedConnectors = "allowed_connectors"
// FieldMfaChain holds the string denoting the mfa_chain field in the database.
FieldMfaChain = "mfa_chain"
// Table holds the table name of the oauth2client in the database. // Table holds the table name of the oauth2client in the database.
Table = "oauth2clients" Table = "oauth2clients"
) )
@ -39,6 +41,7 @@ var Columns = []string{
FieldName, FieldName,
FieldLogoURL, FieldLogoURL,
FieldAllowedConnectors, FieldAllowedConnectors,
FieldMfaChain,
} }
// ValidColumn reports if the column name is valid (part of the table columns). // ValidColumn reports if the column name is valid (part of the table columns).

10
storage/ent/db/oauth2client/where.go

@ -317,6 +317,16 @@ func AllowedConnectorsNotNil() predicate.OAuth2Client {
return predicate.OAuth2Client(sql.FieldNotNull(FieldAllowedConnectors)) return predicate.OAuth2Client(sql.FieldNotNull(FieldAllowedConnectors))
} }
// MfaChainIsNil applies the IsNil predicate on the "mfa_chain" field.
func MfaChainIsNil() predicate.OAuth2Client {
return predicate.OAuth2Client(sql.FieldIsNull(FieldMfaChain))
}
// MfaChainNotNil applies the NotNil predicate on the "mfa_chain" field.
func MfaChainNotNil() predicate.OAuth2Client {
return predicate.OAuth2Client(sql.FieldNotNull(FieldMfaChain))
}
// And groups predicates with the AND operator between them. // And groups predicates with the AND operator between them.
func And(predicates ...predicate.OAuth2Client) predicate.OAuth2Client { func And(predicates ...predicate.OAuth2Client) predicate.OAuth2Client {
return predicate.OAuth2Client(sql.AndPredicates(predicates...)) return predicate.OAuth2Client(sql.AndPredicates(predicates...))

10
storage/ent/db/oauth2client_create.go

@ -61,6 +61,12 @@ func (_c *OAuth2ClientCreate) SetAllowedConnectors(v []string) *OAuth2ClientCrea
return _c return _c
} }
// SetMfaChain sets the "mfa_chain" field.
func (_c *OAuth2ClientCreate) SetMfaChain(v []string) *OAuth2ClientCreate {
_c.mutation.SetMfaChain(v)
return _c
}
// SetID sets the "id" field. // SetID sets the "id" field.
func (_c *OAuth2ClientCreate) SetID(v string) *OAuth2ClientCreate { func (_c *OAuth2ClientCreate) SetID(v string) *OAuth2ClientCreate {
_c.mutation.SetID(v) _c.mutation.SetID(v)
@ -196,6 +202,10 @@ func (_c *OAuth2ClientCreate) createSpec() (*OAuth2Client, *sqlgraph.CreateSpec)
_spec.SetField(oauth2client.FieldAllowedConnectors, field.TypeJSON, value) _spec.SetField(oauth2client.FieldAllowedConnectors, field.TypeJSON, value)
_node.AllowedConnectors = value _node.AllowedConnectors = value
} }
if value, ok := _c.mutation.MfaChain(); ok {
_spec.SetField(oauth2client.FieldMfaChain, field.TypeJSON, value)
_node.MfaChain = value
}
return _node, _spec return _node, _spec
} }

58
storage/ent/db/oauth2client_update.go

@ -138,6 +138,24 @@ func (_u *OAuth2ClientUpdate) ClearAllowedConnectors() *OAuth2ClientUpdate {
return _u return _u
} }
// SetMfaChain sets the "mfa_chain" field.
func (_u *OAuth2ClientUpdate) SetMfaChain(v []string) *OAuth2ClientUpdate {
_u.mutation.SetMfaChain(v)
return _u
}
// AppendMfaChain appends value to the "mfa_chain" field.
func (_u *OAuth2ClientUpdate) AppendMfaChain(v []string) *OAuth2ClientUpdate {
_u.mutation.AppendMfaChain(v)
return _u
}
// ClearMfaChain clears the value of the "mfa_chain" field.
func (_u *OAuth2ClientUpdate) ClearMfaChain() *OAuth2ClientUpdate {
_u.mutation.ClearMfaChain()
return _u
}
// Mutation returns the OAuth2ClientMutation object of the builder. // Mutation returns the OAuth2ClientMutation object of the builder.
func (_u *OAuth2ClientUpdate) Mutation() *OAuth2ClientMutation { func (_u *OAuth2ClientUpdate) Mutation() *OAuth2ClientMutation {
return _u.mutation return _u.mutation
@ -247,6 +265,17 @@ func (_u *OAuth2ClientUpdate) sqlSave(ctx context.Context) (_node int, err error
if _u.mutation.AllowedConnectorsCleared() { if _u.mutation.AllowedConnectorsCleared() {
_spec.ClearField(oauth2client.FieldAllowedConnectors, field.TypeJSON) _spec.ClearField(oauth2client.FieldAllowedConnectors, field.TypeJSON)
} }
if value, ok := _u.mutation.MfaChain(); ok {
_spec.SetField(oauth2client.FieldMfaChain, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedMfaChain(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, oauth2client.FieldMfaChain, value)
})
}
if _u.mutation.MfaChainCleared() {
_spec.ClearField(oauth2client.FieldMfaChain, field.TypeJSON)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok { if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{oauth2client.Label} err = &NotFoundError{oauth2client.Label}
@ -377,6 +406,24 @@ func (_u *OAuth2ClientUpdateOne) ClearAllowedConnectors() *OAuth2ClientUpdateOne
return _u return _u
} }
// SetMfaChain sets the "mfa_chain" field.
func (_u *OAuth2ClientUpdateOne) SetMfaChain(v []string) *OAuth2ClientUpdateOne {
_u.mutation.SetMfaChain(v)
return _u
}
// AppendMfaChain appends value to the "mfa_chain" field.
func (_u *OAuth2ClientUpdateOne) AppendMfaChain(v []string) *OAuth2ClientUpdateOne {
_u.mutation.AppendMfaChain(v)
return _u
}
// ClearMfaChain clears the value of the "mfa_chain" field.
func (_u *OAuth2ClientUpdateOne) ClearMfaChain() *OAuth2ClientUpdateOne {
_u.mutation.ClearMfaChain()
return _u
}
// Mutation returns the OAuth2ClientMutation object of the builder. // Mutation returns the OAuth2ClientMutation object of the builder.
func (_u *OAuth2ClientUpdateOne) Mutation() *OAuth2ClientMutation { func (_u *OAuth2ClientUpdateOne) Mutation() *OAuth2ClientMutation {
return _u.mutation return _u.mutation
@ -516,6 +563,17 @@ func (_u *OAuth2ClientUpdateOne) sqlSave(ctx context.Context) (_node *OAuth2Clie
if _u.mutation.AllowedConnectorsCleared() { if _u.mutation.AllowedConnectorsCleared() {
_spec.ClearField(oauth2client.FieldAllowedConnectors, field.TypeJSON) _spec.ClearField(oauth2client.FieldAllowedConnectors, field.TypeJSON)
} }
if value, ok := _u.mutation.MfaChain(); ok {
_spec.SetField(oauth2client.FieldMfaChain, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedMfaChain(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, oauth2client.FieldMfaChain, value)
})
}
if _u.mutation.MfaChainCleared() {
_spec.ClearField(oauth2client.FieldMfaChain, field.TypeJSON)
}
_node = &OAuth2Client{config: _u.config} _node = &OAuth2Client{config: _u.config}
_spec.Assign = _node.assignValues _spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues _spec.ScanValues = _node.scanValues

4
storage/ent/db/runtime.go

@ -84,6 +84,10 @@ func init() {
authrequestDescCodeChallengeMethod := authrequestFields[19].Descriptor() authrequestDescCodeChallengeMethod := authrequestFields[19].Descriptor()
// authrequest.DefaultCodeChallengeMethod holds the default value on creation for the code_challenge_method field. // authrequest.DefaultCodeChallengeMethod holds the default value on creation for the code_challenge_method field.
authrequest.DefaultCodeChallengeMethod = authrequestDescCodeChallengeMethod.Default.(string) authrequest.DefaultCodeChallengeMethod = authrequestDescCodeChallengeMethod.Default.(string)
// authrequestDescMfaValidated is the schema descriptor for mfa_validated field.
authrequestDescMfaValidated := authrequestFields[21].Descriptor()
// authrequest.DefaultMfaValidated holds the default value on creation for the mfa_validated field.
authrequest.DefaultMfaValidated = authrequestDescMfaValidated.Default.(bool)
// authrequestDescID is the schema descriptor for id field. // authrequestDescID is the schema descriptor for id field.
authrequestDescID := authrequestFields[0].Descriptor() authrequestDescID := authrequestFields[0].Descriptor()
// authrequest.IDValidator is a validator for the "id" field. It is called by the builders before save. // authrequest.IDValidator is a validator for the "id" field. It is called by the builders before save.

15
storage/ent/db/useridentity.go

@ -36,6 +36,8 @@ type UserIdentity struct {
ClaimsGroups []string `json:"claims_groups,omitempty"` ClaimsGroups []string `json:"claims_groups,omitempty"`
// Consents holds the value of the "consents" field. // Consents holds the value of the "consents" field.
Consents []byte `json:"consents,omitempty"` Consents []byte `json:"consents,omitempty"`
// MfaSecrets holds the value of the "mfa_secrets" field.
MfaSecrets *[]byte `json:"mfa_secrets,omitempty"`
// CreatedAt holds the value of the "created_at" field. // CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"` CreatedAt time.Time `json:"created_at,omitempty"`
// LastLogin holds the value of the "last_login" field. // LastLogin holds the value of the "last_login" field.
@ -50,7 +52,7 @@ func (*UserIdentity) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case useridentity.FieldClaimsGroups, useridentity.FieldConsents: case useridentity.FieldClaimsGroups, useridentity.FieldConsents, useridentity.FieldMfaSecrets:
values[i] = new([]byte) values[i] = new([]byte)
case useridentity.FieldClaimsEmailVerified: case useridentity.FieldClaimsEmailVerified:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
@ -135,6 +137,12 @@ func (_m *UserIdentity) assignValues(columns []string, values []any) error {
} else if value != nil { } else if value != nil {
_m.Consents = *value _m.Consents = *value
} }
case useridentity.FieldMfaSecrets:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field mfa_secrets", values[i])
} else if value != nil {
_m.MfaSecrets = value
}
case useridentity.FieldCreatedAt: case useridentity.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok { if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i]) return fmt.Errorf("unexpected type %T for field created_at", values[i])
@ -216,6 +224,11 @@ func (_m *UserIdentity) String() string {
builder.WriteString("consents=") builder.WriteString("consents=")
builder.WriteString(fmt.Sprintf("%v", _m.Consents)) builder.WriteString(fmt.Sprintf("%v", _m.Consents))
builder.WriteString(", ") builder.WriteString(", ")
if v := _m.MfaSecrets; v != nil {
builder.WriteString("mfa_secrets=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("created_at=") builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ") builder.WriteString(", ")

3
storage/ent/db/useridentity/useridentity.go

@ -29,6 +29,8 @@ const (
FieldClaimsGroups = "claims_groups" FieldClaimsGroups = "claims_groups"
// FieldConsents holds the string denoting the consents field in the database. // FieldConsents holds the string denoting the consents field in the database.
FieldConsents = "consents" FieldConsents = "consents"
// FieldMfaSecrets holds the string denoting the mfa_secrets field in the database.
FieldMfaSecrets = "mfa_secrets"
// FieldCreatedAt holds the string denoting the created_at field in the database. // FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at" FieldCreatedAt = "created_at"
// FieldLastLogin holds the string denoting the last_login field in the database. // FieldLastLogin holds the string denoting the last_login field in the database.
@ -51,6 +53,7 @@ var Columns = []string{
FieldClaimsEmailVerified, FieldClaimsEmailVerified,
FieldClaimsGroups, FieldClaimsGroups,
FieldConsents, FieldConsents,
FieldMfaSecrets,
FieldCreatedAt, FieldCreatedAt,
FieldLastLogin, FieldLastLogin,
FieldBlockedUntil, FieldBlockedUntil,

55
storage/ent/db/useridentity/where.go

@ -104,6 +104,11 @@ func Consents(v []byte) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldEQ(FieldConsents, v)) return predicate.UserIdentity(sql.FieldEQ(FieldConsents, v))
} }
// MfaSecrets applies equality check predicate on the "mfa_secrets" field. It's identical to MfaSecretsEQ.
func MfaSecrets(v []byte) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldEQ(FieldMfaSecrets, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.UserIdentity { func CreatedAt(v time.Time) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldEQ(FieldCreatedAt, v)) return predicate.UserIdentity(sql.FieldEQ(FieldCreatedAt, v))
@ -569,6 +574,56 @@ func ConsentsLTE(v []byte) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldLTE(FieldConsents, v)) return predicate.UserIdentity(sql.FieldLTE(FieldConsents, v))
} }
// MfaSecretsEQ applies the EQ predicate on the "mfa_secrets" field.
func MfaSecretsEQ(v []byte) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldEQ(FieldMfaSecrets, v))
}
// MfaSecretsNEQ applies the NEQ predicate on the "mfa_secrets" field.
func MfaSecretsNEQ(v []byte) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldNEQ(FieldMfaSecrets, v))
}
// MfaSecretsIn applies the In predicate on the "mfa_secrets" field.
func MfaSecretsIn(vs ...[]byte) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldIn(FieldMfaSecrets, vs...))
}
// MfaSecretsNotIn applies the NotIn predicate on the "mfa_secrets" field.
func MfaSecretsNotIn(vs ...[]byte) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldNotIn(FieldMfaSecrets, vs...))
}
// MfaSecretsGT applies the GT predicate on the "mfa_secrets" field.
func MfaSecretsGT(v []byte) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldGT(FieldMfaSecrets, v))
}
// MfaSecretsGTE applies the GTE predicate on the "mfa_secrets" field.
func MfaSecretsGTE(v []byte) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldGTE(FieldMfaSecrets, v))
}
// MfaSecretsLT applies the LT predicate on the "mfa_secrets" field.
func MfaSecretsLT(v []byte) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldLT(FieldMfaSecrets, v))
}
// MfaSecretsLTE applies the LTE predicate on the "mfa_secrets" field.
func MfaSecretsLTE(v []byte) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldLTE(FieldMfaSecrets, v))
}
// MfaSecretsIsNil applies the IsNil predicate on the "mfa_secrets" field.
func MfaSecretsIsNil() predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldIsNull(FieldMfaSecrets))
}
// MfaSecretsNotNil applies the NotNil predicate on the "mfa_secrets" field.
func MfaSecretsNotNil() predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldNotNull(FieldMfaSecrets))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.UserIdentity { func CreatedAtEQ(v time.Time) predicate.UserIdentity {
return predicate.UserIdentity(sql.FieldEQ(FieldCreatedAt, v)) return predicate.UserIdentity(sql.FieldEQ(FieldCreatedAt, v))

10
storage/ent/db/useridentity_create.go

@ -114,6 +114,12 @@ func (_c *UserIdentityCreate) SetConsents(v []byte) *UserIdentityCreate {
return _c return _c
} }
// SetMfaSecrets sets the "mfa_secrets" field.
func (_c *UserIdentityCreate) SetMfaSecrets(v []byte) *UserIdentityCreate {
_c.mutation.SetMfaSecrets(v)
return _c
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (_c *UserIdentityCreate) SetCreatedAt(v time.Time) *UserIdentityCreate { func (_c *UserIdentityCreate) SetCreatedAt(v time.Time) *UserIdentityCreate {
_c.mutation.SetCreatedAt(v) _c.mutation.SetCreatedAt(v)
@ -316,6 +322,10 @@ func (_c *UserIdentityCreate) createSpec() (*UserIdentity, *sqlgraph.CreateSpec)
_spec.SetField(useridentity.FieldConsents, field.TypeBytes, value) _spec.SetField(useridentity.FieldConsents, field.TypeBytes, value)
_node.Consents = value _node.Consents = value
} }
if value, ok := _c.mutation.MfaSecrets(); ok {
_spec.SetField(useridentity.FieldMfaSecrets, field.TypeBytes, value)
_node.MfaSecrets = &value
}
if value, ok := _c.mutation.CreatedAt(); ok { if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value) _spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value _node.CreatedAt = value

36
storage/ent/db/useridentity_update.go

@ -151,6 +151,18 @@ func (_u *UserIdentityUpdate) SetConsents(v []byte) *UserIdentityUpdate {
return _u return _u
} }
// SetMfaSecrets sets the "mfa_secrets" field.
func (_u *UserIdentityUpdate) SetMfaSecrets(v []byte) *UserIdentityUpdate {
_u.mutation.SetMfaSecrets(v)
return _u
}
// ClearMfaSecrets clears the value of the "mfa_secrets" field.
func (_u *UserIdentityUpdate) ClearMfaSecrets() *UserIdentityUpdate {
_u.mutation.ClearMfaSecrets()
return _u
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (_u *UserIdentityUpdate) SetCreatedAt(v time.Time) *UserIdentityUpdate { func (_u *UserIdentityUpdate) SetCreatedAt(v time.Time) *UserIdentityUpdate {
_u.mutation.SetCreatedAt(v) _u.mutation.SetCreatedAt(v)
@ -287,6 +299,12 @@ func (_u *UserIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error
if value, ok := _u.mutation.Consents(); ok { if value, ok := _u.mutation.Consents(); ok {
_spec.SetField(useridentity.FieldConsents, field.TypeBytes, value) _spec.SetField(useridentity.FieldConsents, field.TypeBytes, value)
} }
if value, ok := _u.mutation.MfaSecrets(); ok {
_spec.SetField(useridentity.FieldMfaSecrets, field.TypeBytes, value)
}
if _u.mutation.MfaSecretsCleared() {
_spec.ClearField(useridentity.FieldMfaSecrets, field.TypeBytes)
}
if value, ok := _u.mutation.CreatedAt(); ok { if value, ok := _u.mutation.CreatedAt(); ok {
_spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value) _spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value)
} }
@ -438,6 +456,18 @@ func (_u *UserIdentityUpdateOne) SetConsents(v []byte) *UserIdentityUpdateOne {
return _u return _u
} }
// SetMfaSecrets sets the "mfa_secrets" field.
func (_u *UserIdentityUpdateOne) SetMfaSecrets(v []byte) *UserIdentityUpdateOne {
_u.mutation.SetMfaSecrets(v)
return _u
}
// ClearMfaSecrets clears the value of the "mfa_secrets" field.
func (_u *UserIdentityUpdateOne) ClearMfaSecrets() *UserIdentityUpdateOne {
_u.mutation.ClearMfaSecrets()
return _u
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (_u *UserIdentityUpdateOne) SetCreatedAt(v time.Time) *UserIdentityUpdateOne { func (_u *UserIdentityUpdateOne) SetCreatedAt(v time.Time) *UserIdentityUpdateOne {
_u.mutation.SetCreatedAt(v) _u.mutation.SetCreatedAt(v)
@ -604,6 +634,12 @@ func (_u *UserIdentityUpdateOne) sqlSave(ctx context.Context) (_node *UserIdenti
if value, ok := _u.mutation.Consents(); ok { if value, ok := _u.mutation.Consents(); ok {
_spec.SetField(useridentity.FieldConsents, field.TypeBytes, value) _spec.SetField(useridentity.FieldConsents, field.TypeBytes, value)
} }
if value, ok := _u.mutation.MfaSecrets(); ok {
_spec.SetField(useridentity.FieldMfaSecrets, field.TypeBytes, value)
}
if _u.mutation.MfaSecretsCleared() {
_spec.ClearField(useridentity.FieldMfaSecrets, field.TypeBytes)
}
if value, ok := _u.mutation.CreatedAt(); ok { if value, ok := _u.mutation.CreatedAt(); ok {
_spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value) _spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value)
} }

2
storage/ent/schema/authrequest.go

@ -88,6 +88,8 @@ func (AuthRequest) Fields() []ent.Field {
SchemaType(textSchema). SchemaType(textSchema).
Default(""), Default(""),
field.Bytes("hmac_key"), field.Bytes("hmac_key"),
field.Bool("mfa_validated").
Default(false),
} }
} }

2
storage/ent/schema/client.go

@ -47,6 +47,8 @@ func (OAuth2Client) Fields() []ent.Field {
NotEmpty(), NotEmpty(),
field.JSON("allowed_connectors", []string{}). field.JSON("allowed_connectors", []string{}).
Optional(), Optional(),
field.JSON("mfa_chain", []string{}).
Optional(),
} }
} }

3
storage/ent/schema/useridentity.go

@ -41,6 +41,9 @@ func (UserIdentity) Fields() []ent.Field {
field.JSON("claims_groups", []string{}). field.JSON("claims_groups", []string{}).
Optional(), Optional(),
field.Bytes("consents"), field.Bytes("consents"),
field.Bytes("mfa_secrets").
Nillable().
Optional(),
field.Time("created_at"). field.Time("created_at").
SchemaType(timeSchema), SchemaType(timeSchema),
field.Time("last_login"). field.Time("last_login").

23
storage/etcd/types.go

@ -86,6 +86,8 @@ type AuthRequest struct {
CodeChallengeMethod string `json:"code_challenge_method,omitempty"` CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
HMACKey []byte `json:"hmac_key"` HMACKey []byte `json:"hmac_key"`
MFAValidated bool `json:"mfa_validated"`
} }
func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
@ -106,6 +108,7 @@ func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
CodeChallenge: a.PKCE.CodeChallenge, CodeChallenge: a.PKCE.CodeChallenge,
CodeChallengeMethod: a.PKCE.CodeChallengeMethod, CodeChallengeMethod: a.PKCE.CodeChallengeMethod,
HMACKey: a.HMACKey, HMACKey: a.HMACKey,
MFAValidated: a.MFAValidated,
} }
} }
@ -128,7 +131,8 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest {
CodeChallenge: a.CodeChallenge, CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod, CodeChallengeMethod: a.CodeChallengeMethod,
}, },
HMACKey: a.HMACKey, HMACKey: a.HMACKey,
MFAValidated: a.MFAValidated,
} }
} }
@ -258,13 +262,14 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
// UserIdentity is a mirrored struct from storage with JSON struct tags // UserIdentity is a mirrored struct from storage with JSON struct tags
type UserIdentity struct { type UserIdentity struct {
UserID string `json:"user_id,omitempty"` UserID string `json:"user_id,omitempty"`
ConnectorID string `json:"connector_id,omitempty"` ConnectorID string `json:"connector_id,omitempty"`
Claims Claims `json:"claims,omitempty"` Claims Claims `json:"claims,omitempty"`
Consents map[string][]string `json:"consents,omitempty"` Consents map[string][]string `json:"consents,omitempty"`
CreatedAt time.Time `json:"created_at"` MFASecrets map[string]*storage.MFASecret `json:"mfa_secrets,omitempty"`
LastLogin time.Time `json:"last_login"` CreatedAt time.Time `json:"created_at"`
BlockedUntil time.Time `json:"blocked_until"` LastLogin time.Time `json:"last_login"`
BlockedUntil time.Time `json:"blocked_until"`
} }
func fromStorageUserIdentity(u storage.UserIdentity) UserIdentity { func fromStorageUserIdentity(u storage.UserIdentity) UserIdentity {
@ -273,6 +278,7 @@ func fromStorageUserIdentity(u storage.UserIdentity) UserIdentity {
ConnectorID: u.ConnectorID, ConnectorID: u.ConnectorID,
Claims: fromStorageClaims(u.Claims), Claims: fromStorageClaims(u.Claims),
Consents: u.Consents, Consents: u.Consents,
MFASecrets: u.MFASecrets,
CreatedAt: u.CreatedAt, CreatedAt: u.CreatedAt,
LastLogin: u.LastLogin, LastLogin: u.LastLogin,
BlockedUntil: u.BlockedUntil, BlockedUntil: u.BlockedUntil,
@ -285,6 +291,7 @@ func toStorageUserIdentity(u UserIdentity) storage.UserIdentity {
ConnectorID: u.ConnectorID, ConnectorID: u.ConnectorID,
Claims: toStorageClaims(u.Claims), Claims: toStorageClaims(u.Claims),
Consents: u.Consents, Consents: u.Consents,
MFASecrets: u.MFASecrets,
CreatedAt: u.CreatedAt, CreatedAt: u.CreatedAt,
LastLogin: u.LastLogin, LastLogin: u.LastLogin,
BlockedUntil: u.BlockedUntil, BlockedUntil: u.BlockedUntil,

27
storage/kubernetes/types.go

@ -287,6 +287,8 @@ type Client struct {
LogoURL string `json:"logoURL,omitempty"` LogoURL string `json:"logoURL,omitempty"`
AllowedConnectors []string `json:"allowedConnectors,omitempty"` AllowedConnectors []string `json:"allowedConnectors,omitempty"`
MFAChain []string `json:"mfaChain,omitempty"`
} }
// ClientList is a list of Clients. // ClientList is a list of Clients.
@ -314,6 +316,7 @@ func (cli *client) fromStorageClient(c storage.Client) Client {
Name: c.Name, Name: c.Name,
LogoURL: c.LogoURL, LogoURL: c.LogoURL,
AllowedConnectors: c.AllowedConnectors, AllowedConnectors: c.AllowedConnectors,
MFAChain: c.MFAChain,
} }
} }
@ -327,6 +330,7 @@ func toStorageClient(c Client) storage.Client {
Name: c.Name, Name: c.Name,
LogoURL: c.LogoURL, LogoURL: c.LogoURL,
AllowedConnectors: c.AllowedConnectors, AllowedConnectors: c.AllowedConnectors,
MFAChain: c.MFAChain,
} }
} }
@ -396,6 +400,8 @@ type AuthRequest struct {
CodeChallengeMethod string `json:"code_challenge_method,omitempty"` CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
HMACKey []byte `json:"hmac_key"` HMACKey []byte `json:"hmac_key"`
MFAValidated bool `json:"mfa_validated"`
} }
// AuthRequestList is a list of AuthRequests. // AuthRequestList is a list of AuthRequests.
@ -424,7 +430,8 @@ func toStorageAuthRequest(req AuthRequest) storage.AuthRequest {
CodeChallenge: req.CodeChallenge, CodeChallenge: req.CodeChallenge,
CodeChallengeMethod: req.CodeChallengeMethod, CodeChallengeMethod: req.CodeChallengeMethod,
}, },
HMACKey: req.HMACKey, HMACKey: req.HMACKey,
MFAValidated: req.MFAValidated,
} }
return a return a
} }
@ -454,6 +461,7 @@ func (cli *client) fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
CodeChallenge: a.PKCE.CodeChallenge, CodeChallenge: a.PKCE.CodeChallenge,
CodeChallengeMethod: a.PKCE.CodeChallengeMethod, CodeChallengeMethod: a.PKCE.CodeChallengeMethod,
HMACKey: a.HMACKey, HMACKey: a.HMACKey,
MFAValidated: a.MFAValidated,
} }
return req return req
} }
@ -913,13 +921,14 @@ type UserIdentity struct {
k8sapi.TypeMeta `json:",inline"` k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"` k8sapi.ObjectMeta `json:"metadata,omitempty"`
UserID string `json:"userID,omitempty"` UserID string `json:"userID,omitempty"`
ConnectorID string `json:"connectorID,omitempty"` ConnectorID string `json:"connectorID,omitempty"`
Claims Claims `json:"claims,omitempty"` Claims Claims `json:"claims,omitempty"`
Consents map[string][]string `json:"consents,omitempty"` Consents map[string][]string `json:"consents,omitempty"`
CreatedAt time.Time `json:"createdAt,omitempty"` MFASecrets map[string]*storage.MFASecret `json:"mfaSecrets,omitempty"`
LastLogin time.Time `json:"lastLogin,omitempty"` CreatedAt time.Time `json:"createdAt,omitempty"`
BlockedUntil time.Time `json:"blockedUntil,omitempty"` LastLogin time.Time `json:"lastLogin,omitempty"`
BlockedUntil time.Time `json:"blockedUntil,omitempty"`
} }
// UserIdentityList is a list of UserIdentities. // UserIdentityList is a list of UserIdentities.
@ -943,6 +952,7 @@ func (cli *client) fromStorageUserIdentity(u storage.UserIdentity) UserIdentity
ConnectorID: u.ConnectorID, ConnectorID: u.ConnectorID,
Claims: fromStorageClaims(u.Claims), Claims: fromStorageClaims(u.Claims),
Consents: u.Consents, Consents: u.Consents,
MFASecrets: u.MFASecrets,
CreatedAt: u.CreatedAt, CreatedAt: u.CreatedAt,
LastLogin: u.LastLogin, LastLogin: u.LastLogin,
BlockedUntil: u.BlockedUntil, BlockedUntil: u.BlockedUntil,
@ -955,6 +965,7 @@ func toStorageUserIdentity(u UserIdentity) storage.UserIdentity {
ConnectorID: u.ConnectorID, ConnectorID: u.ConnectorID,
Claims: toStorageClaims(u.Claims), Claims: toStorageClaims(u.Claims),
Consents: u.Consents, Consents: u.Consents,
MFASecrets: u.MFASecrets,
CreatedAt: u.CreatedAt, CreatedAt: u.CreatedAt,
LastLogin: u.LastLogin, LastLogin: u.LastLogin,
BlockedUntil: u.BlockedUntil, BlockedUntil: u.BlockedUntil,

70
storage/sql/crud.go

@ -134,10 +134,11 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err
connector_id, connector_data, connector_id, connector_data,
expiry, expiry,
code_challenge, code_challenge_method, code_challenge, code_challenge_method,
hmac_key hmac_key,
mfa_validated
) )
values ( values (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21 $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22
); );
`, `,
a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
@ -148,6 +149,7 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err
a.Expiry, a.Expiry,
a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod,
a.HMACKey, a.HMACKey,
a.MFAValidated,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) { if c.alreadyExistsCheck(err) {
@ -180,8 +182,9 @@ func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a
connector_id = $15, connector_data = $16, connector_id = $15, connector_data = $16,
expiry = $17, expiry = $17,
code_challenge = $18, code_challenge_method = $19, code_challenge = $18, code_challenge_method = $19,
hmac_key = $20 hmac_key = $20,
where id = $21; mfa_validated = $21
where id = $22;
`, `,
a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
a.ForceApprovalPrompt, a.LoggedIn, a.ForceApprovalPrompt, a.LoggedIn,
@ -191,6 +194,7 @@ func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a
a.ConnectorID, a.ConnectorData, a.ConnectorID, a.ConnectorData,
a.Expiry, a.Expiry,
a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, a.HMACKey, a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, a.HMACKey,
a.MFAValidated,
r.ID, r.ID,
) )
if err != nil { if err != nil {
@ -212,7 +216,8 @@ func getAuthRequest(ctx context.Context, q querier, id string) (a storage.AuthRe
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, expiry, connector_id, connector_data, expiry,
code_challenge, code_challenge_method, hmac_key code_challenge, code_challenge_method, hmac_key,
mfa_validated
from auth_request where id = $1; from auth_request where id = $1;
`, id).Scan( `, id).Scan(
&a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, &a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State,
@ -222,6 +227,7 @@ func getAuthRequest(ctx context.Context, q querier, id string) (a storage.AuthRe
decoder(&a.Claims.Groups), decoder(&a.Claims.Groups),
&a.ConnectorID, &a.ConnectorData, &a.Expiry, &a.ConnectorID, &a.ConnectorData, &a.Expiry,
&a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, &a.HMACKey, &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, &a.HMACKey,
&a.MFAValidated,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -514,9 +520,10 @@ func (c *conn) UpdateClient(ctx context.Context, id string, updater func(old sto
public = $4, public = $4,
name = $5, name = $5,
logo_url = $6, logo_url = $6,
allowed_connectors = $7 allowed_connectors = $7,
where id = $8; mfa_chain = $8
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, encoder(nc.AllowedConnectors), id, where id = $9;
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, encoder(nc.AllowedConnectors), encoder(nc.MFAChain), id,
) )
if err != nil { if err != nil {
return fmt.Errorf("update client: %v", err) return fmt.Errorf("update client: %v", err)
@ -528,12 +535,12 @@ func (c *conn) UpdateClient(ctx context.Context, id string, updater func(old sto
func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error { func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error {
_, err := c.Exec(` _, err := c.Exec(`
insert into client ( insert into client (
id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors, mfa_chain
) )
values ($1, $2, $3, $4, $5, $6, $7, $8); values ($1, $2, $3, $4, $5, $6, $7, $8, $9);
`, `,
cli.ID, cli.Secret, encoder(cli.RedirectURIs), encoder(cli.TrustedPeers), cli.ID, cli.Secret, encoder(cli.RedirectURIs), encoder(cli.TrustedPeers),
cli.Public, cli.Name, cli.LogoURL, encoder(cli.AllowedConnectors), cli.Public, cli.Name, cli.LogoURL, encoder(cli.AllowedConnectors), encoder(cli.MFAChain),
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) { if c.alreadyExistsCheck(err) {
@ -547,7 +554,7 @@ func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error {
func getClient(ctx context.Context, q querier, id string) (storage.Client, error) { func getClient(ctx context.Context, q querier, id string) (storage.Client, error) {
return scanClient(q.QueryRow(` return scanClient(q.QueryRow(`
select select
id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors, mfa_chain
from client where id = $1; from client where id = $1;
`, id)) `, id))
} }
@ -559,7 +566,7 @@ func (c *conn) GetClient(ctx context.Context, id string) (storage.Client, error)
func (c *conn) ListClients(ctx context.Context) ([]storage.Client, error) { func (c *conn) ListClients(ctx context.Context) ([]storage.Client, error) {
rows, err := c.Query(` rows, err := c.Query(`
select select
id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors id, secret, redirect_uris, trusted_peers, public, name, logo_url, allowed_connectors, mfa_chain
from client; from client;
`) `)
if err != nil { if err != nil {
@ -583,9 +590,10 @@ func (c *conn) ListClients(ctx context.Context) ([]storage.Client, error) {
func scanClient(s scanner) (cli storage.Client, err error) { func scanClient(s scanner) (cli storage.Client, err error) {
var allowedConnectors []byte var allowedConnectors []byte
var mfaChain []byte
err = s.Scan( err = s.Scan(
&cli.ID, &cli.Secret, decoder(&cli.RedirectURIs), decoder(&cli.TrustedPeers), &cli.ID, &cli.Secret, decoder(&cli.RedirectURIs), decoder(&cli.TrustedPeers),
&cli.Public, &cli.Name, &cli.LogoURL, &allowedConnectors, &cli.Public, &cli.Name, &cli.LogoURL, &allowedConnectors, &mfaChain,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -598,6 +606,11 @@ func scanClient(s scanner) (cli storage.Client, err error) {
return cli, fmt.Errorf("unmarshal client allowed connectors: %v", err) return cli, fmt.Errorf("unmarshal client allowed connectors: %v", err)
} }
} }
if len(mfaChain) > 0 {
if err := json.Unmarshal(mfaChain, &cli.MFAChain); err != nil {
return cli, fmt.Errorf("unmarshal client mfa chain: %v", err)
}
}
return cli, nil return cli, nil
} }
@ -781,17 +794,17 @@ func (c *conn) CreateUserIdentity(ctx context.Context, u storage.UserIdentity) e
user_id, connector_id, user_id, connector_id,
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
consents, consents, mfa_secrets,
created_at, last_login, blocked_until created_at, last_login, blocked_until
) )
values ( values (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12 $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13
); );
`, `,
u.UserID, u.ConnectorID, u.UserID, u.ConnectorID,
u.Claims.UserID, u.Claims.Username, u.Claims.PreferredUsername, u.Claims.UserID, u.Claims.Username, u.Claims.PreferredUsername,
u.Claims.Email, u.Claims.EmailVerified, encoder(u.Claims.Groups), u.Claims.Email, u.Claims.EmailVerified, encoder(u.Claims.Groups),
encoder(u.Consents), encoder(u.Consents), encoder(u.MFASecrets),
u.CreatedAt, u.LastLogin, u.BlockedUntil, u.CreatedAt, u.LastLogin, u.BlockedUntil,
) )
if err != nil { if err != nil {
@ -824,14 +837,15 @@ func (c *conn) UpdateUserIdentity(ctx context.Context, userID, connectorID strin
claims_email_verified = $5, claims_email_verified = $5,
claims_groups = $6, claims_groups = $6,
consents = $7, consents = $7,
created_at = $8, mfa_secrets = $8,
last_login = $9, created_at = $9,
blocked_until = $10 last_login = $10,
where user_id = $11 AND connector_id = $12; blocked_until = $11
where user_id = $12 AND connector_id = $13;
`, `,
newIdentity.Claims.UserID, newIdentity.Claims.Username, newIdentity.Claims.PreferredUsername, newIdentity.Claims.UserID, newIdentity.Claims.Username, newIdentity.Claims.PreferredUsername,
newIdentity.Claims.Email, newIdentity.Claims.EmailVerified, encoder(newIdentity.Claims.Groups), newIdentity.Claims.Email, newIdentity.Claims.EmailVerified, encoder(newIdentity.Claims.Groups),
encoder(newIdentity.Consents), encoder(newIdentity.Consents), encoder(newIdentity.MFASecrets),
newIdentity.CreatedAt, newIdentity.LastLogin, newIdentity.BlockedUntil, newIdentity.CreatedAt, newIdentity.LastLogin, newIdentity.BlockedUntil,
u.UserID, u.ConnectorID, u.UserID, u.ConnectorID,
) )
@ -852,7 +866,7 @@ func getUserIdentity(ctx context.Context, q querier, userID, connectorID string)
user_id, connector_id, user_id, connector_id,
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
consents, consents, mfa_secrets,
created_at, last_login, blocked_until created_at, last_login, blocked_until
from user_identity from user_identity
where user_id = $1 AND connector_id = $2; where user_id = $1 AND connector_id = $2;
@ -865,7 +879,7 @@ func (c *conn) ListUserIdentities(ctx context.Context) ([]storage.UserIdentity,
user_id, connector_id, user_id, connector_id,
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
consents, consents, mfa_secrets,
created_at, last_login, blocked_until created_at, last_login, blocked_until
from user_identity; from user_identity;
`) `)
@ -889,11 +903,12 @@ func (c *conn) ListUserIdentities(ctx context.Context) ([]storage.UserIdentity,
} }
func scanUserIdentity(s scanner) (u storage.UserIdentity, err error) { func scanUserIdentity(s scanner) (u storage.UserIdentity, err error) {
var mfaSecrets []byte
err = s.Scan( err = s.Scan(
&u.UserID, &u.ConnectorID, &u.UserID, &u.ConnectorID,
&u.Claims.UserID, &u.Claims.Username, &u.Claims.PreferredUsername, &u.Claims.UserID, &u.Claims.Username, &u.Claims.PreferredUsername,
&u.Claims.Email, &u.Claims.EmailVerified, decoder(&u.Claims.Groups), &u.Claims.Email, &u.Claims.EmailVerified, decoder(&u.Claims.Groups),
decoder(&u.Consents), decoder(&u.Consents), &mfaSecrets,
&u.CreatedAt, &u.LastLogin, &u.BlockedUntil, &u.CreatedAt, &u.LastLogin, &u.BlockedUntil,
) )
if err != nil { if err != nil {
@ -905,6 +920,11 @@ func scanUserIdentity(s scanner) (u storage.UserIdentity, err error) {
if u.Consents == nil { if u.Consents == nil {
u.Consents = make(map[string][]string) u.Consents = make(map[string][]string)
} }
if len(mfaSecrets) > 0 {
if err := json.Unmarshal(mfaSecrets, &u.MFASecrets); err != nil {
return u, fmt.Errorf("unmarshal user identity mfa secrets: %v", err)
}
}
return u, nil return u, nil
} }

13
storage/sql/migrate.go

@ -422,4 +422,17 @@ var migrations = []migration{
);`, );`,
}, },
}, },
{
stmts: []string{
`
alter table auth_request
add column mfa_validated boolean not null default false;`,
`
alter table user_identity
add column mfa_secrets bytea;`,
`
alter table client
add column mfa_chain bytea;`,
},
},
} }

19
storage/storage.go

@ -185,6 +185,10 @@ type Client struct {
// AllowedConnectors is a list of connector IDs that the client is allowed to use for authentication. // AllowedConnectors is a list of connector IDs that the client is allowed to use for authentication.
// If empty, all connectors are allowed. // If empty, all connectors are allowed.
AllowedConnectors []string `json:"allowedConnectors"` AllowedConnectors []string `json:"allowedConnectors"`
// MFAChain is an ordered list of MFA authenticator IDs that a user must complete
// during login. Empty means no MFA required.
MFAChain []string `json:"mfaChain"`
} }
// Claims represents the ID Token claims supported by the server. // Claims represents the ID Token claims supported by the server.
@ -247,6 +251,9 @@ type AuthRequest struct {
// HMACKey is used when generating an AuthRequest-specific HMAC // HMACKey is used when generating an AuthRequest-specific HMAC
HMACKey []byte HMACKey []byte
// MFAValidated is set to true if the user has completed multi-factor authentication.
MFAValidated bool
} }
// AuthCode represents a code which can be exchanged for an OAuth2 token response. // AuthCode represents a code which can be exchanged for an OAuth2 token response.
@ -330,12 +337,22 @@ type RefreshTokenRef struct {
LastUsed time.Time LastUsed time.Time
} }
// MFASecret stores the enrollment state and secret for an MFA authenticator.
type MFASecret struct {
AuthenticatorID string `json:"authenticatorID"`
Type string `json:"type"`
Secret string `json:"secret"`
Confirmed bool `json:"confirmed"`
CreatedAt time.Time `json:"createdAt"`
}
// UserIdentity represents persistent per-user identity data. // UserIdentity represents persistent per-user identity data.
type UserIdentity struct { type UserIdentity struct {
UserID string UserID string
ConnectorID string ConnectorID string
Claims Claims Claims Claims
Consents map[string][]string // clientID -> approved scopes Consents map[string][]string // clientID -> approved scopes
MFASecrets map[string]*MFASecret // authenticatorID -> secret
CreatedAt time.Time CreatedAt time.Time
LastLogin time.Time LastLogin time.Time
BlockedUntil time.Time BlockedUntil time.Time

5
t

@ -0,0 +1,5 @@
export DEX_POSTGRES_DATABASE=dex
export DEX_POSTGRES_USER=postgres
export DEX_POSTGRES_PASSWORD=postgres
export DEX_POSTGRES_HOST=0.0.0.0
export DEX_POSTGRES_PORT=5432

35
web/templates/totp_verify.html

@ -0,0 +1,35 @@
{{ template "header.html" . }}
<div class="theme-panel">
<h2 class="theme-heading">Two-factor authentication</h2>
{{ if not (eq .QRCode "") }}
<p>Scan the QR code below using your authenticator app, then enter the code.</p>
<div style="text-align: center; margin: 1em 0;">
<img src="data:image/png;base64,{{ .QRCode }}" alt="QR code" width="200" height="200"/>
</div>
{{ else }}
<p>Enter the code from your authenticator app for <b>{{ .Issuer }}</b>.</p>
{{ end }}
<form method="post" action="{{ .PostURL }}">
<div class="theme-form-row">
<div class="theme-form-label">
<label for="totp">One-time code</label>
</div>
<input tabindex="1" required id="totp" name="totp" type="text"
inputmode="numeric" pattern="[0-9]*" maxlength="6"
autocomplete="one-time-code"
class="theme-form-input" placeholder="000000"
autofocus/>
</div>
{{ if .Invalid }}
<div id="login-error" class="dex-error-box">
Invalid code. Please try again.
</div>
{{ end }}
<button tabindex="2" id="submit-login" type="submit" class="dex-btn theme-btn--primary">Verify</button>
</form>
</div>
{{ template "footer.html" . }}
Loading…
Cancel
Save