package server

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"strings"
	"time"

	"github.com/dexidp/dex/connector"
	"github.com/dexidp/dex/server/internal"
	"github.com/dexidp/dex/storage"
)

// RefreshTokenPolicy holds the rotation and expiration settings that are
// applied to refresh tokens. A zero duration disables the corresponding check.
type RefreshTokenPolicy struct {
	rotateRefreshTokens bool // enable rotation

	absoluteLifetime  time.Duration // interval from token creation to the end of its life
	validIfNotUsedFor time.Duration // interval from last token update to the end of its life
	reuseInterval     time.Duration // interval within which old refresh token is allowed to be reused

	now func() time.Time // clock used by the expiration checks

	logger *slog.Logger
}

// NewRefreshTokenPolicy builds a RefreshTokenPolicy from configuration values.
//
// validIfNotUsedFor, absoluteLifetime and reuseInterval are parsed with
// time.ParseDuration; an empty string leaves the corresponding duration at
// zero. A value that fails to parse yields an error that wraps the parse
// error.
//
// NOTE(review): the rotation argument is negated before it is stored, so
// callers presumably pass a "disable rotation" flag here — confirm against
// the call site before renaming or changing it.
func NewRefreshTokenPolicy(logger *slog.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) {
	r := RefreshTokenPolicy{now: time.Now, logger: logger}
	var err error

	if validIfNotUsedFor != "" {
		r.validIfNotUsedFor, err = time.ParseDuration(validIfNotUsedFor)
		if err != nil {
			return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %w", validIfNotUsedFor, err)
		}
		logger.Info("config refresh tokens", "valid_if_not_used_for", validIfNotUsedFor)
	}

	if absoluteLifetime != "" {
		r.absoluteLifetime, err = time.ParseDuration(absoluteLifetime)
		if err != nil {
			return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %w", absoluteLifetime, err)
		}
		logger.Info("config refresh tokens", "absolute_lifetime", absoluteLifetime)
	}

	if reuseInterval != "" {
		r.reuseInterval, err = time.ParseDuration(reuseInterval)
		if err != nil {
			return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %w", reuseInterval, err)
		}
		logger.Info("config refresh tokens", "reuse_interval", reuseInterval)
	}

	r.rotateRefreshTokens = !rotation
	logger.Info("config refresh tokens rotation", "enabled", r.rotateRefreshTokens)
	return &r, nil
}

// RotationEnabled reports whether refresh token rotation is turned on.
func (r *RefreshTokenPolicy) RotationEnabled() bool {
	return r.rotateRefreshTokens
}

// CompletelyExpired reports whether the absolute lifetime has elapsed since
// the given time. It always returns false when no absolute lifetime is set.
//
// Despite the parameter name, the caller in this file passes the token's
// creation time.
func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool {
	if r.absoluteLifetime == 0 {
		return false // expiration disabled
	}
	return r.now().After(lastUsed.Add(r.absoluteLifetime))
}

// ExpiredBecauseUnused reports whether more than validIfNotUsedFor has passed
// since lastUsed. It always returns false when validIfNotUsedFor is not set.
func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool {
	if r.validIfNotUsedFor == 0 {
		return false // expiration disabled
	}
	return r.now().After(lastUsed.Add(r.validIfNotUsedFor))
}

// AllowedToReuse reports whether the current time is still within the reuse
// interval that started at lastUsed. It always returns false when no reuse
// interval is set.
func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool {
	if r.reuseInterval == 0 {
		return false // reuse disabled
	}
	return !r.now().After(lastUsed.Add(r.reuseInterval))
}

// contains reports whether item is present in arr.
func contains(arr []string, item string) bool {
	for _, itemFromArray := range arr {
		if itemFromArray == item {
			return true
		}
	}
	return false
}

// refreshError describes a failed refresh token request: msg is the error
// identifier, desc the human-readable description and code the HTTP status
// passed to the token error helper.
type refreshError struct {
	msg  string
	code int
	desc string
}

// Error implements the error interface.
func (r *refreshError) Error() string {
	return fmt.Sprintf("refresh token error: status %d, %q %s", r.code, r.msg, r.desc)
}

// newInternalServerError returns a refreshError with an empty description and
// HTTP status 500.
func newInternalServerError() *refreshError {
	return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
}

// newBadRequestError returns a refreshError with the given description and
// HTTP status 400.
func newBadRequestError(desc string) *refreshError {
	return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest}
}

// Shared errors returned for invalid and expired refresh tokens.
var (
	invalidErr = newBadRequestError("Refresh token is invalid or has already been claimed by another client.")
	expiredErr = newBadRequestError("Refresh token expired.")
)

// refreshTokenErrHelper writes err to w using the server's token error helper.
func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) {
	s.tokenErrHelper(w, err.msg, err.desc, err.code)
}

// extractRefreshTokenFromRequest reads the "refresh_token" form value and
// decodes it. A missing value results in a bad request error.
func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) {
	code := r.PostFormValue("refresh_token")
	if code == "" {
		return nil, newBadRequestError("No refresh token is found in request.")
	}

	token := new(internal.RefreshToken)
	if err := internal.Unmarshal(code, token); err != nil {
		// For backward compatibility, assume the refresh_token is a raw refresh token ID
		// if it fails to decode.
		//
		// Because refresh_token values that aren't unmarshable were generated by servers
		// that don't have a Token value, we'll still reject any attempts to claim a
		// refresh_token twice.
		token = &internal.RefreshToken{RefreshId: code, Token: ""}
	}

	return token, nil
}

// refreshContext carries the state collected while handling a single refresh
// token request.
type refreshContext struct {
	storageToken *storage.RefreshToken
	requestToken *internal.RefreshToken

	connector     Connector
	connectorData []byte

	scopes []string
}

// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info
//
// It loads the stored token, verifies the client (only when clientID is
// non-nil), the token value and both expiration rules, then resolves the
// connector and the connector data to use for the refresh.
func (s *Server) getRefreshTokenFromStorage(ctx context.Context, clientID *string, token *internal.RefreshToken) (*refreshContext, *refreshError) {
	refreshCtx := refreshContext{requestToken: token}

	// Get RefreshToken
	refresh, err := s.storage.GetRefresh(ctx, token.RefreshId)
	if err != nil {
		if !errors.Is(err, storage.ErrNotFound) {
			s.logger.ErrorContext(ctx, "failed to get refresh token", "err", err)
			return nil, newInternalServerError()
		}
		return nil, invalidErr
	}

	// Only check ClientID if it was provided;
	if clientID != nil && (refresh.ClientID != *clientID) {
		// Log the client ID itself rather than the pointer's address.
		s.logger.ErrorContext(ctx, "trying to claim token for different client", "client_id", *clientID, "refresh_client_id", refresh.ClientID)
		// According to https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 Dex should respond with an
		// invalid grant error if token has already been claimed by another client.
		return nil, &refreshError{msg: errInvalidGrant, desc: invalidErr.desc, code: http.StatusBadRequest}
	}

	if refresh.Token != token.Token {
		// A token that is not the current one is only accepted when it is the
		// non-empty obsolete token and the reuse interval has not passed yet.
		if !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) ||
			refresh.ObsoleteToken != token.Token ||
			refresh.ObsoleteToken == "" {
			s.logger.ErrorContext(ctx, "refresh token claimed twice", "token_id", refresh.ID)
			return nil, invalidErr
		}
	}

	if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) {
		s.logger.ErrorContext(ctx, "refresh token expired", "token_id", refresh.ID)
		return nil, expiredErr
	}
	if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) {
		s.logger.ErrorContext(ctx, "refresh token expired due to inactivity", "token_id", refresh.ID)
		return nil, expiredErr
	}

	refreshCtx.storageToken = &refresh

	// Get Connector
	refreshCtx.connector, err = s.getConnector(ctx, refresh.ConnectorID)
	if err != nil {
		s.logger.ErrorContext(ctx, "connector not found", "connector_id", refresh.ConnectorID, "err", err)
		return nil, newInternalServerError()
	}

	if !GrantTypeAllowed(refreshCtx.connector.GrantTypes, grantTypeRefreshToken) {
		s.logger.ErrorContext(ctx, "connector does not allow refresh token grant", "connector_id", refresh.ConnectorID)
		return nil, &refreshError{msg: errInvalidRequest, desc: "Connector does not support refresh tokens.", code: http.StatusBadRequest}
	}

	// Get Connector Data
	session, err := s.storage.GetOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID)
	switch {
	case err != nil:
		// A missing offline session is tolerated: connector data stays empty.
		if !errors.Is(err, storage.ErrNotFound) {
			s.logger.ErrorContext(ctx, "failed to get offline session", "err", err)
			return nil, newInternalServerError()
		}
	case len(refresh.ConnectorData) > 0:
		// Use the old connector data if it exists, should be deleted once used
		refreshCtx.connectorData = refresh.ConnectorData
	default:
		refreshCtx.connectorData = session.ConnectorData
	}

	return &refreshCtx, nil
}

// getRefreshScopes returns the scopes to use for the refreshed tokens. When
// the request carries no "scope" value the originally granted scopes are
// returned; otherwise every requested scope must be among the granted ones.
func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) {
	// Per the OAuth2 spec, if the client has omitted the scopes, default to the original
	// authorized scopes.
	//
	// https://tools.ietf.org/html/rfc6749#section-6
	scope := r.PostFormValue("scope")

	if scope == "" {
		return refresh.Scopes, nil
	}

	requestedScopes := strings.Fields(scope)
	var unauthorizedScopes []string

	// Collect every requested scope that was not part of the original grant.
	for _, requestScope := range requestedScopes {
		if !contains(refresh.Scopes, requestScope) {
			unauthorizedScopes = append(unauthorizedScopes, requestScope)
		}
	}

	if len(unauthorizedScopes) > 0 {
		desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes)
		return nil, newBadRequestError(desc)
	}

	return requestedScopes, nil
}

// refreshWithConnector asks the connector to refresh ident when the connector
// implements connector.RefreshConnector. Otherwise ident is returned unchanged.
func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext, ident connector.Identity) (connector.Identity, *refreshError) {
	// Can the connector refresh the identity? If so, attempt to refresh the data
	// in the connector.
	//
	// TODO(ericchiang): We may want a strict mode where connectors that don't implement
	// this interface can't perform refreshing.
	if refreshConn, ok := rCtx.connector.Connector.(connector.RefreshConnector); ok {
		// Set connector data to the one received from an offline session
		ident.ConnectorData = rCtx.connectorData
		s.logger.DebugContext(ctx, "connector data before refresh", "connector_data", ident.ConnectorData)

		newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident)
		if err != nil {
			s.logger.ErrorContext(ctx, "failed to refresh identity", "err", err)
			return ident, newInternalServerError()
		}

		return newIdent, nil
	}
	return ident, nil
}

// updateOfflineSession updates offline session in the storage
//
// The session's reference for refresh.ClientID gets lastUsed as its new
// LastUsed value, and the session's connector data is replaced when ident
// carries any.
func (s *Server) updateOfflineSession(ctx context.Context, refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError {
	offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
		// A missing entry yields a nil reference; treat it like a mismatching
		// ID instead of dereferencing it.
		ref := old.Refresh[refresh.ClientID]
		if ref == nil || ref.ID != refresh.ID {
			return old, errors.New("refresh token invalid")
		}

		ref.LastUsed = lastUsed
		if len(ident.ConnectorData) > 0 {
			old.ConnectorData = ident.ConnectorData
		}

		s.logger.DebugContext(ctx, "saved connector data", "user_id", ident.UserID, "connector_data", ident.ConnectorData)

		return old, nil
	}

	// Update LastUsed time stamp in refresh token reference object
	// in offline session for the user.
	err := s.storage.UpdateOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater)
	if err != nil {
		s.logger.ErrorContext(ctx, "failed to update offline session", "err", err)
		return newInternalServerError()
	}

	return nil
}

// updateRefreshToken updates refresh token and offline session in the storage
//
// It returns the refresh token to hand back to the client together with the
// identity built from the stored claims, refreshed through the connector
// unless the request falls into the reuse interval.
func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (*internal.RefreshToken, connector.Identity, *refreshError) {
	var rerr *refreshError

	newToken := &internal.RefreshToken{
		Token:     rCtx.requestToken.Token,
		RefreshId: rCtx.requestToken.RefreshId,
	}

	lastUsed := s.now()

	ident := connector.Identity{
		UserID:            rCtx.storageToken.Claims.UserID,
		Username:          rCtx.storageToken.Claims.Username,
		PreferredUsername: rCtx.storageToken.Claims.PreferredUsername,
		Email:             rCtx.storageToken.Claims.Email,
		EmailVerified:     rCtx.storageToken.Claims.EmailVerified,
		Groups:            rCtx.storageToken.Claims.Groups,
	}

	refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
		rotationEnabled := s.refreshTokenPolicy.RotationEnabled()
		reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(old.LastUsed)

		switch {
		case !rotationEnabled && reusingAllowed:
			// If rotation is disabled and the offline session was updated not so long ago - skip further actions.
			old.ConnectorData = nil
			return old, nil

		case rotationEnabled && reusingAllowed:
			if old.Token != rCtx.requestToken.Token && old.ObsoleteToken != rCtx.requestToken.Token {
				return old, errors.New("refresh token claimed twice")
			}

			// Return previously generated token for all requests with an obsolete tokens
			if old.ObsoleteToken == rCtx.requestToken.Token {
				newToken.Token = old.Token
			}

			// Do not update last used time for offline session if token is allowed to be reused
			lastUsed = old.LastUsed
			old.ConnectorData = nil
			return old, nil

		case rotationEnabled && !reusingAllowed:
			if old.Token != rCtx.requestToken.Token {
				return old, errors.New("refresh token claimed twice")
			}

			// Issue new refresh token
			old.ObsoleteToken = old.Token
			newToken.Token = storage.NewID()
		}

		old.Token = newToken.Token
		old.LastUsed = lastUsed

		// ConnectorData has been moved to OfflineSession
		old.ConnectorData = nil

		// Call only once if there is a request which is not in the reuse interval.
		// This is required to avoid multiple calls to the external IdP for concurrent requests.
		// Dex will call the connector's Refresh method only once if request is not in reuse interval.
		ident, rerr = s.refreshWithConnector(ctx, rCtx, ident)
		if rerr != nil {
			return old, rerr
		}

		// Update the claims of the refresh token.
		//
		// UserID intentionally ignored for now.
		old.Claims.Username = ident.Username
		old.Claims.PreferredUsername = ident.PreferredUsername
		old.Claims.Email = ident.Email
		old.Claims.EmailVerified = ident.EmailVerified
		old.Claims.Groups = ident.Groups

		return old, nil
	}

	// Update refresh token in the storage.
	err := s.storage.UpdateRefreshToken(ctx, rCtx.storageToken.ID, refreshTokenUpdater)
	if err != nil {
		s.logger.ErrorContext(ctx, "failed to update refresh token", "err", err)
		return nil, ident, newInternalServerError()
	}

	rerr = s.updateOfflineSession(ctx, rCtx.storageToken, ident, lastUsed)
	if rerr != nil {
		return nil, ident, rerr
	}

	return newToken, ident, nil
}

// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6
// this method is the entrypoint for refresh tokens handling
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
	token, rerr := s.extractRefreshTokenFromRequest(r)
	if rerr != nil {
		s.refreshTokenErrHelper(w, rerr)
		return
	}

	rCtx, rerr := s.getRefreshTokenFromStorage(r.Context(), &client.ID, token)
	if rerr != nil {
		s.refreshTokenErrHelper(w, rerr)
		return
	}

	rCtx.scopes, rerr = s.getRefreshScopes(r, rCtx.storageToken)
	if rerr != nil {
		s.refreshTokenErrHelper(w, rerr)
		return
	}

	newToken, ident, rerr := s.updateRefreshToken(r.Context(), rCtx)
	if rerr != nil {
		s.refreshTokenErrHelper(w, rerr)
		return
	}

	claims := storage.Claims{
		UserID:            ident.UserID,
		Username:          ident.Username,
		PreferredUsername: ident.PreferredUsername,
		Email:             ident.Email,
		EmailVerified:     ident.EmailVerified,
		Groups:            ident.Groups,
	}

	accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID)
	if err != nil {
		s.logger.ErrorContext(r.Context(), "failed to create new access token", "err", err)
		s.refreshTokenErrHelper(w, newInternalServerError())
		return
	}

	idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID)
	if err != nil {
		s.logger.ErrorContext(r.Context(), "failed to create ID token", "err", err)
		s.refreshTokenErrHelper(w, newInternalServerError())
		return
	}

	rawNewToken, err := internal.Marshal(newToken)
	if err != nil {
		s.logger.ErrorContext(r.Context(), "failed to marshal refresh token", "err", err)
		s.refreshTokenErrHelper(w, newInternalServerError())
		return
	}

	resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
	s.writeAccessToken(w, resp)
}