mirror of https://github.com/dexidp/dex.git
6 changed files with 390 additions and 97 deletions
@ -0,0 +1,249 @@
|
||||
package server |
||||
|
||||
import ( |
||||
"context" |
||||
"crypto/rand" |
||||
"crypto/rsa" |
||||
"encoding/hex" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"log/slog" |
||||
"time" |
||||
|
||||
"github.com/go-jose/go-jose/v4" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
var errAlreadyRotated = errors.New("keys already rotated by another server instance") |
||||
|
||||
// rotationStrategy describes a strategy for generating cryptographic keys, how
|
||||
// often to rotate them, and how long they can validate signatures after rotation.
|
||||
type rotationStrategy struct { |
||||
// Time between rotations.
|
||||
rotationFrequency time.Duration |
||||
|
||||
// After being rotated how long should the key be kept around for validating
|
||||
// signatures?
|
||||
idTokenValidFor time.Duration |
||||
|
||||
// Keys are always RSA keys. Though cryptopasta recommends ECDSA keys, not every
|
||||
// client may support these (e.g. github.com/coreos/go-oidc/oidc).
|
||||
key func() (*rsa.PrivateKey, error) |
||||
} |
||||
|
||||
// staticRotationStrategy returns a strategy which never rotates keys.
|
||||
func staticRotationStrategy(key *rsa.PrivateKey) rotationStrategy { |
||||
return rotationStrategy{ |
||||
// Setting these values to 100 years is easier than having a flag indicating no rotation.
|
||||
rotationFrequency: time.Hour * 8760 * 100, |
||||
idTokenValidFor: time.Hour * 8760 * 100, |
||||
key: func() (*rsa.PrivateKey, error) { return key, nil }, |
||||
} |
||||
} |
||||
|
||||
// defaultRotationStrategy returns a strategy which rotates keys every provided period,
|
||||
// holding onto the public parts for some specified amount of time.
|
||||
func defaultRotationStrategy(rotationFrequency, idTokenValidFor time.Duration) rotationStrategy { |
||||
return rotationStrategy{ |
||||
rotationFrequency: rotationFrequency, |
||||
idTokenValidFor: idTokenValidFor, |
||||
key: func() (*rsa.PrivateKey, error) { |
||||
return rsa.GenerateKey(rand.Reader, 2048) |
||||
}, |
||||
} |
||||
} |
||||
|
||||
type keyRotator struct { |
||||
storage.Storage |
||||
|
||||
strategy rotationStrategy |
||||
now func() time.Time |
||||
|
||||
logger *slog.Logger |
||||
} |
||||
|
||||
// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled.
|
||||
//
|
||||
// The method blocks until after the first attempt to rotate keys has completed. That way
|
||||
// healthy storages will return from this call with valid keys.
|
||||
func (s *Server) startKeyRotation(ctx context.Context, strategy rotationStrategy, now func() time.Time) { |
||||
rotator := keyRotator{s.storage, strategy, now, s.logger} |
||||
|
||||
// Try to rotate immediately so properly configured storages will have keys.
|
||||
if err := rotator.rotate(); err != nil { |
||||
if err == errAlreadyRotated { |
||||
s.logger.Info("key rotation not needed", "err", err) |
||||
} else { |
||||
s.logger.Error("failed to rotate keys", "err", err) |
||||
} |
||||
} |
||||
|
||||
go func() { |
||||
for { |
||||
select { |
||||
case <-ctx.Done(): |
||||
return |
||||
case <-time.After(time.Second * 30): |
||||
if err := rotator.rotate(); err != nil { |
||||
s.logger.Error("failed to rotate keys", "err", err) |
||||
} |
||||
} |
||||
} |
||||
}() |
||||
} |
||||
|
||||
func (k keyRotator) rotate() error { |
||||
keys, err := k.GetKeys(context.Background()) |
||||
if err != nil && err != storage.ErrNotFound { |
||||
return fmt.Errorf("get keys: %v", err) |
||||
} |
||||
if k.now().Before(keys.NextRotation) { |
||||
return nil |
||||
} |
||||
k.logger.Info("keys expired, rotating") |
||||
|
||||
// Generate the key outside of a storage transaction.
|
||||
key, err := k.strategy.key() |
||||
if err != nil { |
||||
return fmt.Errorf("generate key: %v", err) |
||||
} |
||||
b := make([]byte, 20) |
||||
if _, err := io.ReadFull(rand.Reader, b); err != nil { |
||||
panic(err) |
||||
} |
||||
keyID := hex.EncodeToString(b) |
||||
priv := &jose.JSONWebKey{ |
||||
Key: key, |
||||
KeyID: keyID, |
||||
Algorithm: "RS256", |
||||
Use: "sig", |
||||
} |
||||
pub := &jose.JSONWebKey{ |
||||
Key: key.Public(), |
||||
KeyID: keyID, |
||||
Algorithm: "RS256", |
||||
Use: "sig", |
||||
} |
||||
|
||||
var nextRotation time.Time |
||||
err = k.UpdateKeys(context.Background(), func(keys storage.Keys) (storage.Keys, error) { |
||||
tNow := k.now() |
||||
|
||||
// if you are running multiple instances of dex, another instance
|
||||
// could have already rotated the keys.
|
||||
if tNow.Before(keys.NextRotation) { |
||||
return storage.Keys{}, errAlreadyRotated |
||||
} |
||||
|
||||
expired := func(key storage.VerificationKey) bool { |
||||
return tNow.After(key.Expiry) |
||||
} |
||||
|
||||
// Remove any verification keys that have expired.
|
||||
i := 0 |
||||
for _, key := range keys.VerificationKeys { |
||||
if !expired(key) { |
||||
keys.VerificationKeys[i] = key |
||||
i++ |
||||
} |
||||
} |
||||
keys.VerificationKeys = keys.VerificationKeys[:i] |
||||
|
||||
if keys.SigningKeyPub != nil { |
||||
// Move current signing key to a verification only key, throwing
|
||||
// away the private part.
|
||||
verificationKey := storage.VerificationKey{ |
||||
PublicKey: keys.SigningKeyPub, |
||||
// After demoting the signing key, keep the token around for at least
|
||||
// the amount of time an ID Token is valid for. This ensures the
|
||||
// verification key won't expire until all ID Tokens it's signed
|
||||
// expired as well.
|
||||
Expiry: tNow.Add(k.strategy.idTokenValidFor), |
||||
} |
||||
keys.VerificationKeys = append(keys.VerificationKeys, verificationKey) |
||||
} |
||||
|
||||
nextRotation = k.now().Add(k.strategy.rotationFrequency) |
||||
keys.SigningKey = priv |
||||
keys.SigningKeyPub = pub |
||||
keys.NextRotation = nextRotation |
||||
return keys, nil |
||||
}) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
k.logger.Info("keys rotated", "next_rotation", nextRotation) |
||||
return nil |
||||
} |
||||
|
||||
type RefreshTokenPolicy struct { |
||||
rotateRefreshTokens bool // enable rotation
|
||||
|
||||
absoluteLifetime time.Duration // interval from token creation to the end of its life
|
||||
validIfNotUsedFor time.Duration // interval from last token update to the end of its life
|
||||
reuseInterval time.Duration // interval within which old refresh token is allowed to be reused
|
||||
|
||||
now func() time.Time |
||||
|
||||
logger *slog.Logger |
||||
} |
||||
|
||||
func NewRefreshTokenPolicy(logger *slog.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) { |
||||
r := RefreshTokenPolicy{now: time.Now, logger: logger} |
||||
var err error |
||||
|
||||
if validIfNotUsedFor != "" { |
||||
r.validIfNotUsedFor, err = time.ParseDuration(validIfNotUsedFor) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err) |
||||
} |
||||
logger.Info("config refresh tokens", "valid_if_not_used_for", validIfNotUsedFor) |
||||
} |
||||
|
||||
if absoluteLifetime != "" { |
||||
r.absoluteLifetime, err = time.ParseDuration(absoluteLifetime) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err) |
||||
} |
||||
logger.Info("config refresh tokens", "absolute_lifetime", absoluteLifetime) |
||||
} |
||||
|
||||
if reuseInterval != "" { |
||||
r.reuseInterval, err = time.ParseDuration(reuseInterval) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err) |
||||
} |
||||
logger.Info("config refresh tokens", "reuse_interval", reuseInterval) |
||||
} |
||||
|
||||
r.rotateRefreshTokens = !rotation |
||||
logger.Info("config refresh tokens rotation", "enabled", r.rotateRefreshTokens) |
||||
return &r, nil |
||||
} |
||||
|
||||
func (r *RefreshTokenPolicy) RotationEnabled() bool { |
||||
return r.rotateRefreshTokens |
||||
} |
||||
|
||||
func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool { |
||||
if r.absoluteLifetime == 0 { |
||||
return false // expiration disabled
|
||||
} |
||||
return r.now().After(lastUsed.Add(r.absoluteLifetime)) |
||||
} |
||||
|
||||
func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool { |
||||
if r.validIfNotUsedFor == 0 { |
||||
return false // expiration disabled
|
||||
} |
||||
return r.now().After(lastUsed.Add(r.validIfNotUsedFor)) |
||||
} |
||||
|
||||
func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool { |
||||
if r.reuseInterval == 0 { |
||||
return false // expiration disabled
|
||||
} |
||||
return !r.now().After(lastUsed.Add(r.reuseInterval)) |
||||
} |
||||
Loading…
Reference in new issue