From 2cf82eec535c3bc493cb47ee75707717c5cd7f36 Mon Sep 17 00:00:00 2001 From: kanywst Date: Fri, 6 Mar 2026 01:28:06 +0900 Subject: [PATCH] feat(oauth2): implement ID-JAG token issuance via Token Exchange (DEP #4600) Add support for issuing Identity Assertion JWTs (ID-JAG) per draft-ietf-oauth-identity-assertion-authz-grant-01. - Token Exchange with requested_token_type=id-jag issues a signed JWT with typ "oauth-id-jag+jwt", preserving the subject_token's sub claim - Per-client idJAGPolicies in staticClients control allowed audiences and scopes (default-deny) - Reject public clients per Section 7.1 of the spec - Require connector_id parameter for ID-JAG requests - Advertise identity_chaining_requested_token_types_supported and id_jag_signing_alg_values_supported in OIDC Discovery (Section 6) - ID-JAG is opt-in via oauth2.tokenExchange.tokenTypes config Signed-off-by: kanywst --- cmd/dex/config.go | 19 ++++- cmd/dex/config_test.go | 12 +-- cmd/dex/serve.go | 36 +++++++- config.yaml.dist | 24 ++++++ server/handlers.go | 178 ++++++++++++++++++++++++++++++++++++---- server/handlers_test.go | 164 ++++++++++++++++++++++++++++++++++++ server/oauth2.go | 67 ++++++++++++++- server/policy.go | 70 ++++++++++++++++ server/policy_test.go | 102 +++++++++++++++++++++++ server/server.go | 30 +++++++ server/signer/local.go | 18 ++++ server/signer/mock.go | 4 + server/signer/signer.go | 2 + server/signer/utils.go | 17 ++++ server/signer/vault.go | 82 ++++++++++++++++++ 15 files changed, 795 insertions(+), 30 deletions(-) create mode 100644 server/policy.go create mode 100644 server/policy_test.go diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 913d4dfe..e99618c7 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -55,7 +55,7 @@ type Config struct { // StaticClients cause the server to use this list of clients rather than // querying the storage. Write operations, like creating a client, will fail. - StaticClients []storage.Client `json:"staticClients"` + StaticClients []staticClient `json:"staticClients"` // If enabled, the server will maintain a list of passwords which can be used // to identify a user. @@ -158,6 +158,18 @@ func (p *password) UnmarshalJSON(b []byte) error { return nil } +// staticClient wraps storage.Client with optional per-client ID-JAG policy. +type staticClient struct { + storage.Client + IDJAGPolicies *IDJAGClientPolicy `json:"idJAGPolicies,omitempty"` +} + +// IDJAGClientPolicy configures allowed audiences and scopes for ID-JAG exchange. +type IDJAGClientPolicy struct { + AllowedAudiences []string `json:"allowedAudiences"` + AllowedScopes []string `json:"allowedScopes"` +} + // OAuth2 describes enabled OAuth2 extensions. type OAuth2 struct { // list of allowed grant types, @@ -174,6 +186,8 @@ type OAuth2 struct { PasswordConnector string `json:"passwordConnector"` // PKCE configuration PKCE PKCE `json:"pkce"` + // TokenExchange configures Token Exchange support. + TokenExchange server.TokenExchangeConfig `json:"tokenExchange"` } // PKCE holds the PKCE (Proof Key for Code Exchange) configuration. @@ -554,6 +568,9 @@ type Expiry struct { // IdTokens defines the duration of time for which the IdTokens will be valid. IDTokens string `json:"idTokens"` + // IDJAGTokens defines the duration of time for which ID-JAG tokens will be valid. + IDJAGTokens string `json:"idJAGTokens"` + // AuthRequests defines the duration of time for which the AuthRequests will be valid. AuthRequests string `json:"authRequests"` diff --git a/cmd/dex/config_test.go b/cmd/dex/config_test.go index 26385f56..8451ba8f 100644 --- a/cmd/dex/config_test.go +++ b/cmd/dex/config_test.go @@ -182,15 +182,15 @@ additionalFeatures: [ "foo": "bar", }, }, - StaticClients: []storage.Client{ - { + StaticClients: []staticClient{ + {Client: storage.Client{ ID: "example-app", Secret: "ZXhhbXBsZS1hcHAtc2VjcmV0", Name: "Example App", RedirectURIs: []string{ "http://127.0.0.1:5555/callback", }, - }, + }}, }, OAuth2: OAuth2{ AlwaysShowLoginScreen: true, @@ -411,15 +411,15 @@ logger: "foo": "bar", }, }, - StaticClients: []storage.Client{ - { + StaticClients: []staticClient{ + {Client: storage.Client{ ID: "example-app", Secret: "ZXhhbXBsZS1hcHAtc2VjcmV0", Name: "Example App", RedirectURIs: []string{ "http://127.0.0.1:5555/callback", }, - }, + }}, }, OAuth2: OAuth2{ AlwaysShowLoginScreen: true, diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index cd9d3839..89aa3660 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -213,7 +213,9 @@ func runServe(options serveOptions) error { logger.Info("config storage", "storage_type", c.Storage.Type) if len(c.StaticClients) > 0 { - for i, client := range c.StaticClients { + storageClients := make([]storage.Client, len(c.StaticClients)) + for i, sc := range c.StaticClients { + client := sc.Client if client.Name == "" { return fmt.Errorf("invalid config: Name field is required for a client") } @@ -224,7 +226,7 @@ func runServe(options serveOptions) error { if client.ID != "" { return fmt.Errorf("invalid config: ID and IDEnv fields are exclusive for client %q", client.ID) } - c.StaticClients[i].ID = os.Getenv(client.IDEnv) + client.ID = os.Getenv(client.IDEnv) } if client.Secret == "" && client.SecretEnv == "" && !client.Public { return fmt.Errorf("invalid config: Secret or SecretEnv field is required for client %q", client.ID) @@ -233,11 +235,12 @@ func runServe(options serveOptions) error { if client.Secret != "" { return fmt.Errorf("invalid config: Secret and SecretEnv fields are exclusive for client %q", client.ID) } - c.StaticClients[i].Secret = os.Getenv(client.SecretEnv) + client.Secret = os.Getenv(client.SecretEnv) } logger.Info("config static client", "client_name", client.Name) + storageClients[i] = client } - s = storage.WithStaticClients(s, c.StaticClients) + s = storage.WithStaticClients(s, storageClients) } if len(c.StaticPasswords) > 0 { passwords := make([]storage.Password, len(c.StaticPasswords)) @@ -384,6 +387,7 @@ func runServe(options serveOptions) error { ContinueOnConnectorFailure: featureflags.ContinueOnConnectorFailure.Enabled(), Signer: signerInstance, IDTokensValidFor: idTokensValidFor, + TokenExchange: c.OAuth2.TokenExchange, } if c.Expiry.AuthRequests != "" { @@ -402,6 +406,30 @@ func runServe(options serveOptions) error { logger.Info("config device requests", "valid_for", deviceRequests) serverConfig.DeviceRequestsValidFor = deviceRequests } + if c.Expiry.IDJAGTokens != "" { + idJAGTokens, err := time.ParseDuration(c.Expiry.IDJAGTokens) + if err != nil { + return fmt.Errorf("invalid config value %q for ID-JAG token expiry: %v", c.Expiry.IDJAGTokens, err) + } + logger.Info("config ID-JAG tokens", "valid_for", idJAGTokens) + serverConfig.IDJAGTokensValidFor = idJAGTokens + } + + // Build per-client ID-JAG policies from static client config. + for _, sc := range c.StaticClients { + if sc.IDJAGPolicies != nil { + clientID := sc.Client.ID + if clientID == "" && sc.Client.IDEnv != "" { + clientID = os.Getenv(sc.Client.IDEnv) + } + serverConfig.IDJAGPolicies = append(serverConfig.IDJAGPolicies, server.TokenExchangePolicy{ + ClientID: clientID, + AllowedAudiences: sc.IDJAGPolicies.AllowedAudiences, + AllowedScopes: sc.IDJAGPolicies.AllowedScopes, + }) + } + } + refreshTokenPolicy, err := server.NewRefreshTokenPolicy( logger, c.Expiry.RefreshTokens.DisableRotation, diff --git a/config.yaml.dist b/config.yaml.dist index 5d2c37ea..094501ac 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -87,6 +87,7 @@ web: # deviceRequests: "5m" # signingKeys: "6h" # idTokens: "24h" +# idJAGTokens: "5m" # default: 5m; independent of idTokens # refreshTokens: # disableRotation: false # reuseInterval: "3s" @@ -116,6 +117,14 @@ web: # enforce: false # # Supported code challenge methods. Defaults to ["S256", "plain"]. # codeChallengeMethodsSupported: ["S256", "plain"] +# +# # Token Exchange configuration +# tokenExchange: +# # List of token types enabled for exchange. Adding id-jag enables ID-JAG support. +# # Omitting it (default) disables ID-JAG without affecting other token exchange flows. +# tokenTypes: +# - urn:ietf:params:oauth:token-type:id_token +# - urn:ietf:params:oauth:token-type:id-jag # Static clients registered in Dex by default. # @@ -151,6 +160,21 @@ web: # allowedConnectors: # - github # - google +# +# # Example of a client with ID-JAG token exchange policy +# - id: wiki-app +# secret: wiki-secret +# redirectURIs: +# - 'https://wiki.example/callback' +# name: 'Wiki Application' +# # Per-client ID-JAG policy. Clients without this section cannot obtain ID-JAG tokens. +# idJAGPolicies: +# allowedAudiences: +# - "https://chat.example/" +# - "https://calendar.example/" +# allowedScopes: +# - "chat.read" +# - "calendar.read" # Connectors are used to authenticate users against upstream identity providers. # diff --git a/server/handlers.go b/server/handlers.go index e60715d9..e9297dcc 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -73,21 +73,23 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { } type discovery struct { - Issuer string `json:"issuer"` - Auth string `json:"authorization_endpoint"` - Token string `json:"token_endpoint"` - Keys string `json:"jwks_uri"` - UserInfo string `json:"userinfo_endpoint"` - DeviceEndpoint string `json:"device_authorization_endpoint"` - Introspect string `json:"introspection_endpoint"` - GrantTypes []string `json:"grant_types_supported"` - ResponseTypes []string `json:"response_types_supported"` - Subjects []string `json:"subject_types_supported"` - IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` - CodeChallengeAlgs []string `json:"code_challenge_methods_supported"` - Scopes []string `json:"scopes_supported"` - AuthMethods []string `json:"token_endpoint_auth_methods_supported"` - Claims []string `json:"claims_supported"` + Issuer string `json:"issuer"` + Auth string `json:"authorization_endpoint"` + Token string `json:"token_endpoint"` + Keys string `json:"jwks_uri"` + UserInfo string `json:"userinfo_endpoint"` + DeviceEndpoint string `json:"device_authorization_endpoint"` + Introspect string `json:"introspection_endpoint"` + GrantTypes []string `json:"grant_types_supported"` + ResponseTypes []string `json:"response_types_supported"` + Subjects []string `json:"subject_types_supported"` + IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` + CodeChallengeAlgs []string `json:"code_challenge_methods_supported"` + Scopes []string `json:"scopes_supported"` + AuthMethods []string `json:"token_endpoint_auth_methods_supported"` + Claims []string `json:"claims_supported"` + IDJAGSigningAlgs []string `json:"id_jag_signing_alg_values_supported,omitempty"` + IdentityChainingTokenTypes []string `json:"identity_chaining_requested_token_types_supported,omitempty"` } func (s *Server) discoveryHandler(ctx context.Context) (http.HandlerFunc, error) { @@ -133,6 +135,11 @@ func (s *Server) constructDiscovery(ctx context.Context) discovery { d.IDTokenAlgs = []string{string(signingAlg)} } + if s.enableIDJAG { + d.IDJAGSigningAlgs = d.IDTokenAlgs + d.IdentityChainingTokenTypes = []string{tokenTypeIDJAG} + } + for responseType := range s.supportedResponseTypes { d.ResponseTypes = append(d.ResponseTypes, responseType) } @@ -1547,6 +1554,15 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli return } + if requestedTokenType == tokenTypeIDJAG { + if !s.enableIDJAG { + s.tokenErrHelper(w, errRequestNotSupported, "ID-JAG token exchange is not enabled on this server.", http.StatusBadRequest) + return + } + s.handleIDJAGExchange(w, r, client, subjectToken, subjectTokenType, connID, scopes) + return + } + conn, err := s.getConnector(ctx, connID) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get connector", "err", err) @@ -1607,6 +1623,138 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli json.NewEncoder(w).Encode(resp) } +// handleIDJAGExchange handles a Token Exchange request with requested_token_type=ID-JAG. +// See: https://datatracker.ietf.org/doc/draft-ietf-oauth-identity-assertion-authz-grant/ +func (s *Server) handleIDJAGExchange(w http.ResponseWriter, r *http.Request, client storage.Client, subjectToken, subjectTokenType string, connectorID string, scopes []string) { + ctx := r.Context() + q := r.Form + + audience := q.Get("audience") + resource := q.Get("resource") + + // Reject public clients (Section 7.1). + if client.Public { + s.tokenErrHelper(w, errUnauthorizedClient, "Public clients cannot use ID-JAG token exchange.", http.StatusBadRequest) + return + } + + // connector_id is required for identifying the upstream connector. + if connectorID == "" { + s.tokenErrHelper(w, errInvalidRequest, "Missing required parameter connector_id for ID-JAG token exchange.", http.StatusBadRequest) + return + } + + if _, err := s.getConnector(ctx, connectorID); err != nil { + s.logger.ErrorContext(ctx, "connector not found for ID-JAG exchange", "connector_id", connectorID, "err", err) + s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) + return + } + + // audience is required. + if audience == "" { + s.tokenErrHelper(w, errInvalidRequest, "Missing required parameter audience for ID-JAG token exchange.", http.StatusBadRequest) + return + } + + // subject_token_type must be id_token. + if subjectTokenType != tokenTypeID { + s.tokenErrHelper(w, errRequestNotSupported, "ID-JAG token exchange requires subject_token_type=urn:ietf:params:oauth:token-type:id_token.", http.StatusBadRequest) + return + } + + // Extract sub and aud from the subject_token. + sub, tokenAud, err := extractJWTSubAndAud(subjectToken) + if err != nil { + s.logger.ErrorContext(ctx, "failed to extract claims from subject_token", "err", err) + s.tokenErrHelper(w, errInvalidRequest, "Invalid subject_token: could not parse JWT claims.", http.StatusBadRequest) + return + } + if sub == "" { + s.tokenErrHelper(w, errInvalidRequest, "subject_token missing required sub claim.", http.StatusBadRequest) + return + } + + // Validate that the subject_token audience matches the requesting client (Section 4.3). + if !audContains(tokenAud, client.ID) { + s.logger.InfoContext(ctx, "subject_token audience does not match client_id", + "token_aud", tokenAud, "client_id", client.ID) + s.tokenErrHelper(w, errInvalidRequest, "subject_token audience does not match client_id.", http.StatusBadRequest) + return + } + + // Evaluate access policy. + if err := evaluateIDJAGPolicy(s.tokenExchangePolicies, client.ID, audience, scopes); err != nil { + s.logger.InfoContext(ctx, "ID-JAG policy denied", "client_id", client.ID, "audience", audience, "err", err) + s.tokenErrHelper(w, errAccessDenied, "", http.StatusForbidden) + return + } + + idJAGToken, expiry, err := s.newIDJAG(ctx, client.ID, sub, audience, resource, scopes) + if err != nil { + s.logger.ErrorContext(ctx, "failed to create ID-JAG token", "err", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + // RFC 8693 ยง2.2.1: token_type is "N_A" for non-access tokens. + resp := accessTokenResponse{ + AccessToken: idJAGToken, + IssuedTokenType: tokenTypeIDJAG, + TokenType: "N_A", + ExpiresIn: int(time.Until(expiry).Seconds()), + } + + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +// extractJWTSubAndAud extracts the "sub" and "aud" claims from a JWT without +// verifying the signature. The aud claim may be a string or []string. +func extractJWTSubAndAud(token string) (sub string, aud []string, err error) { + parts := strings.SplitN(token, ".", 3) + if len(parts) != 3 { + return "", nil, fmt.Errorf("malformed JWT: expected 3 parts, got %d", len(parts)) + } + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", nil, fmt.Errorf("failed to decode JWT payload: %v", err) + } + + var claims struct { + Sub string `json:"sub"` + Aud json.RawMessage `json:"aud"` + } + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + return "", nil, fmt.Errorf("failed to unmarshal JWT payload: %v", err) + } + + if len(claims.Aud) > 0 { + var single string + if err := json.Unmarshal(claims.Aud, &single); err == nil { + aud = []string{single} + } else { + var multi []string + if err := json.Unmarshal(claims.Aud, &multi); err == nil { + aud = multi + } + } + } + + return claims.Sub, aud, nil +} + +// audContains reports whether target is in aud. +func audContains(aud []string, target string) bool { + for _, a := range aud { + if a == target { + return true + } + } + return false +} + func (s *Server) handleClientCredentialsGrant(w http.ResponseWriter, r *http.Request, client storage.Client) { ctx := r.Context() diff --git a/server/handlers_test.go b/server/handlers_test.go index 12c664f5..b9b3e193 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -3,6 +3,7 @@ package server import ( "bytes" "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -1343,6 +1344,169 @@ func (m *mockSAMLRefreshConnector) Refresh(ctx context.Context, s connector.Scop return m.refreshIdentity, nil } +// makeTestJWT builds a minimal JWT with the given sub for testing. +func makeTestJWT(sub string) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"` + sub + `","iss":"https://issuer.example","aud":"client_1","exp":9999999999}`)) + return header + "." + payload + ".fakesig" +} + +// TestExtractJWTSubAndAud tests extractJWTSubAndAud. +func TestExtractJWTSubAndAud(t *testing.T) { + tests := []struct { + name string + token string + wantSub string + wantAud []string + wantErr bool + }{ + { + name: "valid JWT returns sub and aud", + token: makeTestJWT("user-abc-123"), + wantSub: "user-abc-123", + wantAud: []string{"client_1"}, + }, + { + name: "not a JWT (no dots)", + token: "notajwt", + wantErr: true, + }, + { + name: "invalid base64 payload", + token: "aGVhZGVy.!!!.c2ln", + wantErr: true, + }, + { + name: "valid JWT without sub returns empty string", + token: func() string { + h := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)) + p := base64.RawURLEncoding.EncodeToString([]byte(`{"iss":"https://issuer.example"}`)) + return h + "." + p + ".sig" + }(), + wantSub: "", + wantAud: nil, + wantErr: false, + }, + { + name: "aud as array", + token: func() string { + h := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`)) + p := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"u1","aud":["a","b"]}`)) + return h + "." + p + ".sig" + }(), + wantSub: "u1", + wantAud: []string{"a", "b"}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + sub, aud, err := extractJWTSubAndAud(tc.token) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.wantSub, sub) + require.Equal(t, tc.wantAud, aud) + }) + } +} + +// TestHandleIDJAGExchange tests the ID-JAG token exchange handler. +func TestHandleIDJAGExchange(t *testing.T) { + subjectToken := makeTestJWT("user-123") + + tests := []struct { + name string + audience string + subjectTokenType string + subjectToken string + policies []TokenExchangePolicy + wantCode int + wantTokenTypeNA bool + }{ + { + name: "happy path: valid ID-JAG issued", + audience: "https://resource-as.example.com", + subjectTokenType: tokenTypeID, + subjectToken: subjectToken, + wantCode: http.StatusOK, + wantTokenTypeNA: true, + }, + { + name: "missing audience returns 400", + audience: "", + subjectTokenType: tokenTypeID, + subjectToken: subjectToken, + wantCode: http.StatusBadRequest, + }, + { + name: "wrong subject_token_type returns 400", + audience: "https://resource-as.example.com", + subjectTokenType: tokenTypeAccess, // must be id_token for ID-JAG + subjectToken: subjectToken, + wantCode: http.StatusBadRequest, + }, + { + name: "policy denies the audience: 403", + audience: "https://resource-as.example.com", + subjectTokenType: tokenTypeID, + subjectToken: subjectToken, + policies: []TokenExchangePolicy{ + // client_1 may only reach other.example.com, not resource-as.example.com + {ClientID: "client_1", AllowedAudiences: []string{"https://other.example.com"}}, + }, + wantCode: http.StatusForbidden, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { + require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{ + ID: "client_1", + Secret: "secret_1", + })) + c.TokenExchange = TokenExchangeConfig{ + TokenTypes: []string{tokenTypeIDJAG}, + } + c.IDJAGPolicies = tc.policies + }) + defer httpServer.Close() + + vals := url.Values{} + vals.Set("grant_type", grantTypeTokenExchange) + vals.Set("requested_token_type", tokenTypeIDJAG) + vals.Set("subject_token_type", tc.subjectTokenType) + vals.Set("subject_token", tc.subjectToken) + vals.Set("connector_id", "mock") + if tc.audience != "" { + vals.Set("audience", tc.audience) + } + vals.Set("client_id", "client_1") + vals.Set("client_secret", "secret_1") + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) + req.Header.Set("content-type", "application/x-www-form-urlencoded") + s.handleToken(rr, req) + + require.Equal(t, tc.wantCode, rr.Code, "body: %s", rr.Body.String()) + + if tc.wantTokenTypeNA { + var res accessTokenResponse + require.NoError(t, json.NewDecoder(rr.Result().Body).Decode(&res)) + require.Equal(t, "N_A", res.TokenType) + require.Equal(t, tokenTypeIDJAG, res.IssuedTokenType) + require.NotEmpty(t, res.AccessToken) + require.Equal(t, 3, len(strings.Split(res.AccessToken, ".")), "expected compact JWT") + require.Greater(t, res.ExpiresIn, 0) + } + }) + } +} + func TestFilterConnectors(t *testing.T) { connectors := []storage.Connector{ {ID: "github", Type: "github", Name: "GitHub"}, diff --git a/server/oauth2.go b/server/oauth2.go index 9f12d1d0..027a909c 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -165,6 +165,8 @@ const ( tokenTypeSAML1 = "urn:ietf:params:oauth:token-type:saml1" tokenTypeSAML2 = "urn:ietf:params:oauth:token-type:saml2" tokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt" + // https://datatracker.ietf.org/doc/draft-ietf-oauth-identity-assertion-authz-grant/ + tokenTypeIDJAG = "urn:ietf:params:oauth:token-type:id-jag" ) const ( @@ -281,6 +283,65 @@ func (s *Server) newAccessToken(ctx context.Context, clientID string, claims sto return s.newIDToken(ctx, clientID, claims, scopes, nonce, storage.NewID(), "", connID) } +// idJAGTyp is the JWT "typ" header value for ID-JAG tokens. +const idJAGTyp = "oauth-id-jag+jwt" + +// idJAGClaims is the JWT payload for an ID-JAG token. +type idJAGClaims struct { + Issuer string `json:"iss"` + Subject string `json:"sub"` + Audience string `json:"aud"` + ClientID string `json:"client_id"` + JTI string `json:"jti"` + Expiry int64 `json:"exp"` + IssuedAt int64 `json:"iat"` + + // Optional claims. + Resource string `json:"resource,omitempty"` + Scope string `json:"scope,omitempty"` + AuthTime int64 `json:"auth_time,omitempty"` + ACR string `json:"acr,omitempty"` + AMR []string `json:"amr,omitempty"` +} + +// newIDJAG creates an ID-JAG token with the given subject and audience. +func (s *Server) newIDJAG( + ctx context.Context, + clientID string, + subject string, + audience string, + resource string, + scopes []string, +) (token string, expiry time.Time, err error) { + issuedAt := s.now() + expiry = issuedAt.Add(s.idJAGTokensValidFor) + + claims := idJAGClaims{ + Issuer: s.issuerURL.String(), + Subject: subject, + Audience: audience, + ClientID: clientID, + JTI: storage.NewID(), + Expiry: expiry.Unix(), + IssuedAt: issuedAt.Unix(), + Resource: resource, + } + + if len(scopes) > 0 { + claims.Scope = strings.Join(scopes, " ") + } + + payload, err := json.Marshal(claims) + if err != nil { + return "", expiry, fmt.Errorf("could not serialize ID-JAG claims: %v", err) + } + + if token, err = s.signer.SignWithType(ctx, payload, idJAGTyp); err != nil { + return "", expiry, fmt.Errorf("failed to sign ID-JAG payload: %v", err) + } + return token, expiry, nil +} + func getClientID(aud audience, azp string) (string, error) { switch len(aud) { case 0: @@ -488,8 +549,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques } if codeChallenge != "" && !slices.Contains(s.pkce.CodeChallengeMethodsSupported, codeChallengeMethod) { - description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod) - return nil, newRedirectedErr(errInvalidRequest, description) + return nil, newRedirectedErr(errInvalidRequest, "Unsupported PKCE challenge method (%q).", codeChallengeMethod) } // Enforce PKCE if configured. @@ -578,8 +638,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques } if rt.token { if redirectURI == redirectURIOOB { - err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB) - return nil, newRedirectedErr(errInvalidRequest, err) + return nil, newRedirectedErr(errInvalidRequest, "Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB) } } diff --git a/server/policy.go b/server/policy.go new file mode 100644 index 00000000..793bf699 --- /dev/null +++ b/server/policy.go @@ -0,0 +1,70 @@ +package server + +import "fmt" + +// TokenExchangePolicy defines per-client access control for ID-JAG token exchange. +type TokenExchangePolicy struct { + // ClientID is the client this policy applies to. Use "*" for a default policy. + ClientID string `json:"clientID"` + AllowedAudiences []string `json:"allowedAudiences"` + AllowedScopes []string `json:"allowedScopes"` +} + +// evaluateIDJAGPolicy checks whether the client is permitted to obtain an ID-JAG +// for the given audience and scopes. No policies configured means allow all. +func evaluateIDJAGPolicy(policies []TokenExchangePolicy, clientID, audience string, scopes []string) error { + if len(policies) == 0 { + return nil + } + + // Find the most-specific policy for this client: exact match first, then wildcard. + var matched *TokenExchangePolicy + for i := range policies { + p := &policies[i] + if p.ClientID == clientID { + matched = p + break + } + if p.ClientID == "*" && matched == nil { + matched = p + } + } + + if matched == nil { + return fmt.Errorf("no policy found for client %q: access_denied", clientID) + } + + // Check audience. + if !audienceAllowed(matched.AllowedAudiences, audience) { + return fmt.Errorf("audience %q is not allowed for client %q: access_denied", audience, clientID) + } + + // Check scopes (only if policy restricts them). + if len(matched.AllowedScopes) > 0 { + for _, scope := range scopes { + if !scopeAllowed(matched.AllowedScopes, scope) { + return fmt.Errorf("scope %q is not allowed for client %q: access_denied", scope, clientID) + } + } + } + + return nil +} + +func audienceAllowed(allowedAudiences []string, audience string) bool { + for _, a := range allowedAudiences { + if a == audience { + return true + } + } + return false +} + +func scopeAllowed(allowedScopes []string, scope string) bool { + for _, s := range allowedScopes { + if s == scope { + return true + } + } + return false +} diff --git a/server/policy_test.go b/server/policy_test.go new file mode 100644 index 00000000..41a8fe1e --- /dev/null +++ b/server/policy_test.go @@ -0,0 +1,102 @@ +package server + +import ( + "testing" +) + +func TestEvaluateIDJAGPolicy(t *testing.T) { + tests := []struct { + name string + policies []TokenExchangePolicy + clientID string + audience string + scopes []string + wantErr bool + }{ + { + name: "no policies: allow all", + policies: nil, + clientID: "any-client", + audience: "https://resource.example.com", + wantErr: false, + }, + { + name: "exact match allowed", + policies: []TokenExchangePolicy{ + {ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}}, + }, + clientID: "client-a", + audience: "https://resource.example.com", + wantErr: false, + }, + { + name: "audience not allowed", + policies: []TokenExchangePolicy{ + {ClientID: "client-a", AllowedAudiences: []string{"https://other.example.com"}}, + }, + clientID: "client-a", + audience: "https://resource.example.com", + wantErr: true, + }, + { + name: "client not found: denied", + policies: []TokenExchangePolicy{ + {ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}}, + }, + clientID: "unknown-client", + audience: "https://resource.example.com", + wantErr: true, + }, + { + name: "wildcard client matches", + policies: []TokenExchangePolicy{ + {ClientID: "*", AllowedAudiences: []string{"https://resource.example.com"}}, + }, + clientID: "any-client", + audience: "https://resource.example.com", + wantErr: false, + }, + { + name: "exact match takes priority over wildcard", + policies: []TokenExchangePolicy{ + {ClientID: "*", AllowedAudiences: []string{"https://other.example.com"}}, + {ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}}, + }, + clientID: "client-a", + audience: "https://resource.example.com", + wantErr: false, + }, + { + name: "scope denied by policy", + policies: []TokenExchangePolicy{ + {ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}, AllowedScopes: []string{"read"}}, + }, + clientID: "client-a", + audience: "https://resource.example.com", + scopes: []string{"admin"}, + wantErr: true, + }, + { + name: "allowed scope passes", + policies: []TokenExchangePolicy{ + {ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}, AllowedScopes: []string{"read", "write"}}, + }, + clientID: "client-a", + audience: "https://resource.example.com", + scopes: []string{"read"}, + wantErr: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := evaluateIDJAGPolicy(tc.policies, tc.clientID, tc.audience, tc.scopes) + if tc.wantErr && err == nil { + t.Error("expected error but got none") + } + if !tc.wantErr && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} diff --git a/server/server.go b/server/server.go index e6945c72..e01573dd 100644 --- a/server/server.go +++ b/server/server.go @@ -105,6 +105,7 @@ type Config struct { AlwaysShowLoginScreen bool IDTokensValidFor time.Duration // Defaults to 24 hours + IDJAGTokensValidFor time.Duration // Defaults to 5 minutes AuthRequestsValidFor time.Duration // Defaults to 24 hours DeviceRequestsValidFor time.Duration // Defaults to 5 minutes @@ -136,6 +137,26 @@ type Config struct { // If enabled, the server will continue starting even if some connectors fail to initialize. // This allows the server to operate with a subset of connectors if some are misconfigured. ContinueOnConnectorFailure bool + + // TokenExchange configures Token Exchange support. + TokenExchange TokenExchangeConfig + + IDJAGPolicies []TokenExchangePolicy +} + +// TokenExchangeConfig holds configuration for Token Exchange support. +type TokenExchangeConfig struct { + TokenTypes []string `json:"tokenTypes"` +} + +// IDJAGEnabled reports whether the ID-JAG token type is enabled. +func (c TokenExchangeConfig) IDJAGEnabled() bool { + for _, t := range c.TokenTypes { + if t == "urn:ietf:params:oauth:token-type:id-jag" { + return true + } + } + return false } // WebConfig holds the server's frontend templates and asset configuration. @@ -225,6 +246,10 @@ type Server struct { logger *slog.Logger signer signer.Signer + + enableIDJAG bool + idJAGTokensValidFor time.Duration + tokenExchangePolicies []TokenExchangePolicy } // NewServer constructs a server from the provided config. @@ -330,6 +355,8 @@ func newServer(ctx context.Context, c Config) (*Server, error) { now = time.Now } + idJAGTokensValidFor := value(c.IDJAGTokensValidFor, 5*time.Minute) + s := &Server{ issuerURL: *issuerURL, connectors: make(map[string]Connector), @@ -348,6 +375,9 @@ func newServer(ctx context.Context, c Config) (*Server, error) { passwordConnector: c.PasswordConnector, logger: c.Logger, signer: c.Signer, + enableIDJAG: c.TokenExchange.IDJAGEnabled(), + idJAGTokensValidFor: idJAGTokensValidFor, + tokenExchangePolicies: c.IDJAGPolicies, } // Retrieves connector objects in backend storage. This list includes the static connectors diff --git a/server/signer/local.go b/server/signer/local.go index a210aaa0..b4e45e94 100644 --- a/server/signer/local.go +++ b/server/signer/local.go @@ -87,6 +87,24 @@ func (l *localSigner) Sign(ctx context.Context, payload []byte) (string, error) return signPayload(signingKey, signingAlg, payload) } +func (l *localSigner) SignWithType(ctx context.Context, payload []byte, tokenType string) (string, error) { + keys, err := l.storage.GetKeys(ctx) + if err != nil { + return "", fmt.Errorf("failed to get keys: %v", err) + } + + signingKey := keys.SigningKey + if signingKey == nil { + return "", fmt.Errorf("no key to sign payload with") + } + signingAlg, err := signatureAlgorithm(signingKey) + if err != nil { + return "", err + } + + return signPayloadWithType(signingKey, signingAlg, payload, tokenType) +} + func (l *localSigner) ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) { keys, err := l.storage.GetKeys(ctx) if err != nil { diff --git a/server/signer/mock.go b/server/signer/mock.go index 832a9be8..cfdf0047 100644 --- a/server/signer/mock.go +++ b/server/signer/mock.go @@ -59,6 +59,10 @@ func (m *mockSigner) Sign(_ context.Context, payload []byte) (string, error) { return signPayload(m.key, jose.RS256, payload) } +func (m *mockSigner) SignWithType(_ context.Context, payload []byte, tokenType string) (string, error) { + return signPayloadWithType(m.key, jose.RS256, payload, tokenType) +} + func (m *mockSigner) ValidationKeys(_ context.Context) ([]*jose.JSONWebKey, error) { return []*jose.JSONWebKey{m.pubKey}, nil } diff --git a/server/signer/signer.go b/server/signer/signer.go index 1e15bbd1..801aab9f 100644 --- a/server/signer/signer.go +++ b/server/signer/signer.go @@ -10,6 +10,8 @@ import ( type Signer interface { // Sign signs the provided payload. Sign(ctx context.Context, payload []byte) (string, error) + // SignWithType signs the provided payload with a custom JWT "typ" header. + SignWithType(ctx context.Context, payload []byte, tokenType string) (string, error) // ValidationKeys returns the current public keys used for signature validation. ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) // Algorithm returns the signing algorithm used by this signer. diff --git a/server/signer/utils.go b/server/signer/utils.go index 6d607a10..11a5f60e 100644 --- a/server/signer/utils.go +++ b/server/signer/utils.go @@ -56,3 +56,20 @@ func signPayload(key *jose.JSONWebKey, alg jose.SignatureAlgorithm, payload []by } return signature.CompactSerialize() } + +func signPayloadWithType(key *jose.JSONWebKey, alg jose.SignatureAlgorithm, payload []byte, tokenType string) (jws string, err error) { + signingKey := jose.SigningKey{Key: key, Algorithm: alg} + + opts := &jose.SignerOptions{} + opts.WithType(jose.ContentType(tokenType)) + + signer, err := jose.NewSigner(signingKey, opts) + if err != nil { + return "", fmt.Errorf("new signer: %v", err) + } + signature, err := signer.Sign(payload) + if err != nil { + return "", fmt.Errorf("signing payload: %v", err) + } + return signature.CompactSerialize() +} diff --git a/server/signer/vault.go b/server/signer/vault.go index ba175f27..a073e863 100644 --- a/server/signer/vault.go +++ b/server/signer/vault.go @@ -179,6 +179,88 @@ func (v *vaultSigner) Sign(ctx context.Context, payload []byte) (string, error) return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signatureB64), nil } +func (v *vaultSigner) SignWithType(ctx context.Context, payload []byte, tokenType string) (string, error) { + // 1. Fetch keys to determine the key to use (latest version) and its ID. + keysMap, latestVersion, err := v.getTransitKeysMap(ctx) + if err != nil { + return "", fmt.Errorf("failed to get keys for signing context: %v", err) + } + + // Determine the key version and ID to use + signingJWK, ok := keysMap[latestVersion] + if !ok { + return "", fmt.Errorf("latest key version %d not found in public keys", latestVersion) + } + + // 2. Construct JWS Header with custom typ and Payload first (Signing Input) + header := map[string]interface{}{ + "alg": signingJWK.Algorithm, + "kid": signingJWK.KeyID, + "typ": tokenType, + } + + headerBytes, err := json.Marshal(header) + if err != nil { + return "", fmt.Errorf("failed to marshal header: %v", err) + } + + headerB64 := base64.RawURLEncoding.EncodeToString(headerBytes) + payloadB64 := base64.RawURLEncoding.EncodeToString(payload) + + // The input to the signature is "header.payload" + signingInput := fmt.Sprintf("%s.%s", headerB64, payloadB64) + + // 3. Sign the signingInput using Vault + var vaultInput string + data := map[string]interface{}{} + + // Determine Vault params based on JWS algorithm + params, err := getVaultParams(signingJWK.Algorithm) + if err != nil { + return "", err + } + + // Apply params to data map + for k, v := range params.extraParams { + data[k] = v + } + + // Hash input if needed + if params.hasher != nil { + params.hasher.Write([]byte(signingInput)) + hash := params.hasher.Sum(nil) + vaultInput = base64.StdEncoding.EncodeToString(hash) + } else { + // No pre-hashing (EdDSA) + vaultInput = base64.StdEncoding.EncodeToString([]byte(signingInput)) + } + data["input"] = vaultInput + + signPath := fmt.Sprintf("transit/sign/%s", v.keyName) + signSecret, err := v.client.Logical().WriteWithContext(ctx, signPath, data) + if err != nil { + return "", fmt.Errorf("vault sign: %v", err) + } + + signatureString, ok := signSecret.Data["signature"].(string) + if !ok { + return "", fmt.Errorf("vault response missing signature") + } + + // Parse vault signature: "vault:v1:base64sig" + var signatureB64 []byte + if len(signatureString) > 8 && signatureString[:6] == "vault:" { + parts := splitVaultSignature(signatureString) + if len(parts) == 3 { + signatureB64 = []byte(parts[2]) + } + } else { + return "", fmt.Errorf("unexpected signature format: %s", signatureString) + } + + return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signatureB64), nil +} + func (v *vaultSigner) ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) { keysMap, _, err := v.getTransitKeysMap(ctx) if err != nil {