diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index de8d9b7b..8ea5a8b0 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -232,7 +232,9 @@ func (s *Server) updateOfflineSession(ctx context.Context, refresh *storage.Refr return nil } -// updateRefreshToken updates refresh token and offline session in the storage +// updateRefreshToken updates refresh token and offline session in the storage. +// Connector refresh is guarded by a per-refresh-ID mutex so only one concurrent +// caller hits the IdP. func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (*internal.RefreshToken, connector.Identity, *refreshError) { var rerr *refreshError @@ -240,7 +242,6 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( Token: rCtx.requestToken.Token, RefreshId: rCtx.requestToken.RefreshId, } - lastUsed := s.now() ident := connector.Identity{ @@ -250,6 +251,31 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( Email: rCtx.storageToken.Claims.Email, EmailVerified: rCtx.storageToken.Claims.EmailVerified, Groups: rCtx.storageToken.Claims.Groups, + ConnectorData: rCtx.connectorData, + } + + rotationEnabled := s.refreshTokenPolicy.RotationEnabled() + reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(rCtx.storageToken.LastUsed) + needConnectorRefresh := rotationEnabled && !reusingAllowed + + if needConnectorRefresh { + // Serialize concurrent refreshes for the same refresh ID. + lock := s.getRefreshLock(rCtx.storageToken.ID) + lock.Lock() + s.logger.Debug("Acquired refresh lock", "refreshID", rCtx.storageToken.ID) + defer func() { + lock.Unlock() + s.logger.Debug("Released refresh lock", "refreshID", rCtx.storageToken.ID) + }() + + // Double-check if another goroutine already refreshed while we waited: + if !s.refreshTokenPolicy.AllowedToReuse(rCtx.storageToken.LastUsed) { + var rerr *refreshError + ident, rerr = s.refreshWithConnector(ctx, rCtx, ident) + if rerr != nil { + return nil, ident, rerr + } + } } refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { @@ -293,14 +319,6 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) ( // ConnectorData has been moved to OfflineSession old.ConnectorData = nil - // Call only once if there is a request which is not in the reuse interval. - // This is required to avoid multiple calls to the external IdP for concurrent requests. - // Dex will call the connector's Refresh method only once if request is not in reuse interval. - ident, rerr = s.refreshWithConnector(ctx, rCtx, ident) - if rerr != nil { - return old, rerr - } - // Update the claims of the refresh token. // // UserID intentionally ignored for now. diff --git a/server/server.go b/server/server.go index 70e8ae75..c2425053 100644 --- a/server/server.go +++ b/server/server.go @@ -198,6 +198,8 @@ type Server struct { deviceRequestsValidFor time.Duration refreshTokenPolicy *RefreshTokenPolicy + // mutex to refresh the same token only once for concurrent requests + refreshLocks sync.Map logger *slog.Logger } @@ -758,6 +760,12 @@ func (s *Server) getConnector(ctx context.Context, id string) (Connector, error) return conn, nil } +// getRefreshLock returns a per-refresh-ID mutex. +func (s *Server) getRefreshLock(refreshID string) *sync.Mutex { + m, _ := s.refreshLocks.LoadOrStore(refreshID, &sync.Mutex{}) + return m.(*sync.Mutex) +} + type logRequestKey string const (