Browse Source

Merge 2cf82eec53 into 93985dedff

pull/4611/merge
kt 20 hours ago committed by GitHub
parent
commit
12977b17ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 19
      cmd/dex/config.go
  2. 12
      cmd/dex/config_test.go
  3. 36
      cmd/dex/serve.go
  4. 24
      config.yaml.dist
  5. 178
      server/handlers.go
  6. 164
      server/handlers_test.go
  7. 67
      server/oauth2.go
  8. 70
      server/policy.go
  9. 102
      server/policy_test.go
  10. 30
      server/server.go
  11. 18
      server/signer/local.go
  12. 4
      server/signer/mock.go
  13. 2
      server/signer/signer.go
  14. 17
      server/signer/utils.go
  15. 82
      server/signer/vault.go

19
cmd/dex/config.go

@ -55,7 +55,7 @@ type Config struct {
// StaticClients cause the server to use this list of clients rather than
// querying the storage. Write operations, like creating a client, will fail.
StaticClients []storage.Client `json:"staticClients"`
StaticClients []staticClient `json:"staticClients"`
// If enabled, the server will maintain a list of passwords which can be used
// to identify a user.
@ -158,6 +158,18 @@ func (p *password) UnmarshalJSON(b []byte) error {
return nil
}
// staticClient wraps storage.Client with optional per-client ID-JAG policy.
type staticClient struct {
storage.Client
IDJAGPolicies *IDJAGClientPolicy `json:"idJAGPolicies,omitempty"`
}
// IDJAGClientPolicy configures allowed audiences and scopes for ID-JAG exchange.
type IDJAGClientPolicy struct {
AllowedAudiences []string `json:"allowedAudiences"`
AllowedScopes []string `json:"allowedScopes"`
}
// OAuth2 describes enabled OAuth2 extensions.
type OAuth2 struct {
// list of allowed grant types,
@ -174,6 +186,8 @@ type OAuth2 struct {
PasswordConnector string `json:"passwordConnector"`
// PKCE configuration
PKCE PKCE `json:"pkce"`
// TokenExchange configures Token Exchange support.
TokenExchange server.TokenExchangeConfig `json:"tokenExchange"`
}
// PKCE holds the PKCE (Proof Key for Code Exchange) configuration.
@ -554,6 +568,9 @@ type Expiry struct {
// IdTokens defines the duration of time for which the IdTokens will be valid.
IDTokens string `json:"idTokens"`
// IDJAGTokens defines the duration of time for which ID-JAG tokens will be valid.
IDJAGTokens string `json:"idJAGTokens"`
// AuthRequests defines the duration of time for which the AuthRequests will be valid.
AuthRequests string `json:"authRequests"`

12
cmd/dex/config_test.go

@ -182,15 +182,15 @@ additionalFeatures: [
"foo": "bar",
},
},
StaticClients: []storage.Client{
{
StaticClients: []staticClient{
{Client: storage.Client{
ID: "example-app",
Secret: "ZXhhbXBsZS1hcHAtc2VjcmV0",
Name: "Example App",
RedirectURIs: []string{
"http://127.0.0.1:5555/callback",
},
},
}},
},
OAuth2: OAuth2{
AlwaysShowLoginScreen: true,
@ -411,15 +411,15 @@ logger:
"foo": "bar",
},
},
StaticClients: []storage.Client{
{
StaticClients: []staticClient{
{Client: storage.Client{
ID: "example-app",
Secret: "ZXhhbXBsZS1hcHAtc2VjcmV0",
Name: "Example App",
RedirectURIs: []string{
"http://127.0.0.1:5555/callback",
},
},
}},
},
OAuth2: OAuth2{
AlwaysShowLoginScreen: true,

36
cmd/dex/serve.go

@ -213,7 +213,9 @@ func runServe(options serveOptions) error {
logger.Info("config storage", "storage_type", c.Storage.Type)
if len(c.StaticClients) > 0 {
for i, client := range c.StaticClients {
storageClients := make([]storage.Client, len(c.StaticClients))
for i, sc := range c.StaticClients {
client := sc.Client
if client.Name == "" {
return fmt.Errorf("invalid config: Name field is required for a client")
}
@ -224,7 +226,7 @@ func runServe(options serveOptions) error {
if client.ID != "" {
return fmt.Errorf("invalid config: ID and IDEnv fields are exclusive for client %q", client.ID)
}
c.StaticClients[i].ID = os.Getenv(client.IDEnv)
client.ID = os.Getenv(client.IDEnv)
}
if client.Secret == "" && client.SecretEnv == "" && !client.Public {
return fmt.Errorf("invalid config: Secret or SecretEnv field is required for client %q", client.ID)
@ -233,11 +235,12 @@ func runServe(options serveOptions) error {
if client.Secret != "" {
return fmt.Errorf("invalid config: Secret and SecretEnv fields are exclusive for client %q", client.ID)
}
c.StaticClients[i].Secret = os.Getenv(client.SecretEnv)
client.Secret = os.Getenv(client.SecretEnv)
}
logger.Info("config static client", "client_name", client.Name)
storageClients[i] = client
}
s = storage.WithStaticClients(s, c.StaticClients)
s = storage.WithStaticClients(s, storageClients)
}
if len(c.StaticPasswords) > 0 {
passwords := make([]storage.Password, len(c.StaticPasswords))
@ -384,6 +387,7 @@ func runServe(options serveOptions) error {
ContinueOnConnectorFailure: featureflags.ContinueOnConnectorFailure.Enabled(),
Signer: signerInstance,
IDTokensValidFor: idTokensValidFor,
TokenExchange: c.OAuth2.TokenExchange,
}
if c.Expiry.AuthRequests != "" {
@ -402,6 +406,30 @@ func runServe(options serveOptions) error {
logger.Info("config device requests", "valid_for", deviceRequests)
serverConfig.DeviceRequestsValidFor = deviceRequests
}
if c.Expiry.IDJAGTokens != "" {
idJAGTokens, err := time.ParseDuration(c.Expiry.IDJAGTokens)
if err != nil {
return fmt.Errorf("invalid config value %q for ID-JAG token expiry: %v", c.Expiry.IDJAGTokens, err)
}
logger.Info("config ID-JAG tokens", "valid_for", idJAGTokens)
serverConfig.IDJAGTokensValidFor = idJAGTokens
}
// Build per-client ID-JAG policies from static client config.
for _, sc := range c.StaticClients {
if sc.IDJAGPolicies != nil {
clientID := sc.Client.ID
if clientID == "" && sc.Client.IDEnv != "" {
clientID = os.Getenv(sc.Client.IDEnv)
}
serverConfig.IDJAGPolicies = append(serverConfig.IDJAGPolicies, server.TokenExchangePolicy{
ClientID: clientID,
AllowedAudiences: sc.IDJAGPolicies.AllowedAudiences,
AllowedScopes: sc.IDJAGPolicies.AllowedScopes,
})
}
}
refreshTokenPolicy, err := server.NewRefreshTokenPolicy(
logger,
c.Expiry.RefreshTokens.DisableRotation,

24
config.yaml.dist

@ -89,6 +89,7 @@ web:
# deviceRequests: "5m"
# signingKeys: "6h"
# idTokens: "24h"
# idJAGTokens: "5m" # default: 5m; independent of idTokens
# refreshTokens:
# disableRotation: false
# reuseInterval: "3s"
@ -118,6 +119,14 @@ web:
# enforce: false
# # Supported code challenge methods. Defaults to ["S256", "plain"].
# codeChallengeMethodsSupported: ["S256", "plain"]
#
# # Token Exchange configuration
# tokenExchange:
# # List of token types enabled for exchange. Adding id-jag enables ID-JAG support.
# # Omitting it (default) disables ID-JAG without affecting other token exchange flows.
# tokenTypes:
# - urn:ietf:params:oauth:token-type:id_token
# - urn:ietf:params:oauth:token-type:id-jag
# Static clients registered in Dex by default.
#
@ -153,6 +162,21 @@ web:
# allowedConnectors:
# - github
# - google
#
# # Example of a client with ID-JAG token exchange policy
# - id: wiki-app
# secret: wiki-secret
# redirectURIs:
# - 'https://wiki.example/callback'
# name: 'Wiki Application'
# # Per-client ID-JAG policy. Clients without this section cannot obtain ID-JAG tokens.
# idJAGPolicies:
# allowedAudiences:
# - "https://chat.example/"
# - "https://calendar.example/"
# allowedScopes:
# - "chat.read"
# - "calendar.read"
# Connectors are used to authenticate users against upstream identity providers.
#

178
server/handlers.go

@ -73,21 +73,23 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
}
type discovery struct {
Issuer string `json:"issuer"`
Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"`
Keys string `json:"jwks_uri"`
UserInfo string `json:"userinfo_endpoint"`
DeviceEndpoint string `json:"device_authorization_endpoint"`
Introspect string `json:"introspection_endpoint"`
GrantTypes []string `json:"grant_types_supported"`
ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"`
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
CodeChallengeAlgs []string `json:"code_challenge_methods_supported"`
Scopes []string `json:"scopes_supported"`
AuthMethods []string `json:"token_endpoint_auth_methods_supported"`
Claims []string `json:"claims_supported"`
Issuer string `json:"issuer"`
Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"`
Keys string `json:"jwks_uri"`
UserInfo string `json:"userinfo_endpoint"`
DeviceEndpoint string `json:"device_authorization_endpoint"`
Introspect string `json:"introspection_endpoint"`
GrantTypes []string `json:"grant_types_supported"`
ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"`
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
CodeChallengeAlgs []string `json:"code_challenge_methods_supported"`
Scopes []string `json:"scopes_supported"`
AuthMethods []string `json:"token_endpoint_auth_methods_supported"`
Claims []string `json:"claims_supported"`
IDJAGSigningAlgs []string `json:"id_jag_signing_alg_values_supported,omitempty"`
IdentityChainingTokenTypes []string `json:"identity_chaining_requested_token_types_supported,omitempty"`
}
func (s *Server) discoveryHandler(ctx context.Context) (http.HandlerFunc, error) {
@ -133,6 +135,11 @@ func (s *Server) constructDiscovery(ctx context.Context) discovery {
d.IDTokenAlgs = []string{string(signingAlg)}
}
if s.enableIDJAG {
d.IDJAGSigningAlgs = d.IDTokenAlgs
d.IdentityChainingTokenTypes = []string{tokenTypeIDJAG}
}
for responseType := range s.supportedResponseTypes {
d.ResponseTypes = append(d.ResponseTypes, responseType)
}
@ -1548,6 +1555,15 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli
return
}
if requestedTokenType == tokenTypeIDJAG {
if !s.enableIDJAG {
s.tokenErrHelper(w, errRequestNotSupported, "ID-JAG token exchange is not enabled on this server.", http.StatusBadRequest)
return
}
s.handleIDJAGExchange(w, r, client, subjectToken, subjectTokenType, connID, scopes)
return
}
conn, err := s.getConnector(ctx, connID)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get connector", "err", err)
@ -1608,6 +1624,138 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli
json.NewEncoder(w).Encode(resp)
}
// handleIDJAGExchange handles a Token Exchange request with requested_token_type=ID-JAG.
// See: https://datatracker.ietf.org/doc/draft-ietf-oauth-identity-assertion-authz-grant/
func (s *Server) handleIDJAGExchange(w http.ResponseWriter, r *http.Request, client storage.Client, subjectToken, subjectTokenType string, connectorID string, scopes []string) {
ctx := r.Context()
q := r.Form
audience := q.Get("audience")
resource := q.Get("resource")
// Reject public clients (Section 7.1).
if client.Public {
s.tokenErrHelper(w, errUnauthorizedClient, "Public clients cannot use ID-JAG token exchange.", http.StatusBadRequest)
return
}
// connector_id is required for identifying the upstream connector.
if connectorID == "" {
s.tokenErrHelper(w, errInvalidRequest, "Missing required parameter connector_id for ID-JAG token exchange.", http.StatusBadRequest)
return
}
if _, err := s.getConnector(ctx, connectorID); err != nil {
s.logger.ErrorContext(ctx, "connector not found for ID-JAG exchange", "connector_id", connectorID, "err", err)
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return
}
// audience is required.
if audience == "" {
s.tokenErrHelper(w, errInvalidRequest, "Missing required parameter audience for ID-JAG token exchange.", http.StatusBadRequest)
return
}
// subject_token_type must be id_token.
if subjectTokenType != tokenTypeID {
s.tokenErrHelper(w, errRequestNotSupported, "ID-JAG token exchange requires subject_token_type=urn:ietf:params:oauth:token-type:id_token.", http.StatusBadRequest)
return
}
// Extract sub and aud from the subject_token.
sub, tokenAud, err := extractJWTSubAndAud(subjectToken)
if err != nil {
s.logger.ErrorContext(ctx, "failed to extract claims from subject_token", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "Invalid subject_token: could not parse JWT claims.", http.StatusBadRequest)
return
}
if sub == "" {
s.tokenErrHelper(w, errInvalidRequest, "subject_token missing required sub claim.", http.StatusBadRequest)
return
}
// Validate that the subject_token audience matches the requesting client (Section 4.3).
if !audContains(tokenAud, client.ID) {
s.logger.InfoContext(ctx, "subject_token audience does not match client_id",
"token_aud", tokenAud, "client_id", client.ID)
s.tokenErrHelper(w, errInvalidRequest, "subject_token audience does not match client_id.", http.StatusBadRequest)
return
}
// Evaluate access policy.
if err := evaluateIDJAGPolicy(s.tokenExchangePolicies, client.ID, audience, scopes); err != nil {
s.logger.InfoContext(ctx, "ID-JAG policy denied", "client_id", client.ID, "audience", audience, "err", err)
s.tokenErrHelper(w, errAccessDenied, "", http.StatusForbidden)
return
}
idJAGToken, expiry, err := s.newIDJAG(ctx, client.ID, sub, audience, resource, scopes)
if err != nil {
s.logger.ErrorContext(ctx, "failed to create ID-JAG token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
// RFC 8693 §2.2.1: token_type is "N_A" for non-access tokens.
resp := accessTokenResponse{
AccessToken: idJAGToken,
IssuedTokenType: tokenTypeIDJAG,
TokenType: "N_A",
ExpiresIn: int(time.Until(expiry).Seconds()),
}
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
// extractJWTSubAndAud extracts the "sub" and "aud" claims from a JWT without
// verifying the signature. The aud claim may be a string or []string.
func extractJWTSubAndAud(token string) (sub string, aud []string, err error) {
parts := strings.SplitN(token, ".", 3)
if len(parts) != 3 {
return "", nil, fmt.Errorf("malformed JWT: expected 3 parts, got %d", len(parts))
}
payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", nil, fmt.Errorf("failed to decode JWT payload: %v", err)
}
var claims struct {
Sub string `json:"sub"`
Aud json.RawMessage `json:"aud"`
}
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
return "", nil, fmt.Errorf("failed to unmarshal JWT payload: %v", err)
}
if len(claims.Aud) > 0 {
var single string
if err := json.Unmarshal(claims.Aud, &single); err == nil {
aud = []string{single}
} else {
var multi []string
if err := json.Unmarshal(claims.Aud, &multi); err == nil {
aud = multi
}
}
}
return claims.Sub, aud, nil
}
// audContains reports whether target is in aud.
func audContains(aud []string, target string) bool {
for _, a := range aud {
if a == target {
return true
}
}
return false
}
func (s *Server) handleClientCredentialsGrant(w http.ResponseWriter, r *http.Request, client storage.Client) {
ctx := r.Context()

164
server/handlers_test.go

@ -3,6 +3,7 @@ package server
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@ -1343,6 +1344,169 @@ func (m *mockSAMLRefreshConnector) Refresh(ctx context.Context, s connector.Scop
return m.refreshIdentity, nil
}
// makeTestJWT builds a minimal JWT with the given sub for testing.
func makeTestJWT(sub string) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
payload := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"` + sub + `","iss":"https://issuer.example","aud":"client_1","exp":9999999999}`))
return header + "." + payload + ".fakesig"
}
// TestExtractJWTSubAndAud tests extractJWTSubAndAud.
func TestExtractJWTSubAndAud(t *testing.T) {
tests := []struct {
name string
token string
wantSub string
wantAud []string
wantErr bool
}{
{
name: "valid JWT returns sub and aud",
token: makeTestJWT("user-abc-123"),
wantSub: "user-abc-123",
wantAud: []string{"client_1"},
},
{
name: "not a JWT (no dots)",
token: "notajwt",
wantErr: true,
},
{
name: "invalid base64 payload",
token: "aGVhZGVy.!!!.c2ln",
wantErr: true,
},
{
name: "valid JWT without sub returns empty string",
token: func() string {
h := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`))
p := base64.RawURLEncoding.EncodeToString([]byte(`{"iss":"https://issuer.example"}`))
return h + "." + p + ".sig"
}(),
wantSub: "",
wantAud: nil,
wantErr: false,
},
{
name: "aud as array",
token: func() string {
h := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256"}`))
p := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"u1","aud":["a","b"]}`))
return h + "." + p + ".sig"
}(),
wantSub: "u1",
wantAud: []string{"a", "b"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
sub, aud, err := extractJWTSubAndAud(tc.token)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tc.wantSub, sub)
require.Equal(t, tc.wantAud, aud)
})
}
}
// TestHandleIDJAGExchange tests the ID-JAG token exchange handler.
func TestHandleIDJAGExchange(t *testing.T) {
subjectToken := makeTestJWT("user-123")
tests := []struct {
name string
audience string
subjectTokenType string
subjectToken string
policies []TokenExchangePolicy
wantCode int
wantTokenTypeNA bool
}{
{
name: "happy path: valid ID-JAG issued",
audience: "https://resource-as.example.com",
subjectTokenType: tokenTypeID,
subjectToken: subjectToken,
wantCode: http.StatusOK,
wantTokenTypeNA: true,
},
{
name: "missing audience returns 400",
audience: "",
subjectTokenType: tokenTypeID,
subjectToken: subjectToken,
wantCode: http.StatusBadRequest,
},
{
name: "wrong subject_token_type returns 400",
audience: "https://resource-as.example.com",
subjectTokenType: tokenTypeAccess, // must be id_token for ID-JAG
subjectToken: subjectToken,
wantCode: http.StatusBadRequest,
},
{
name: "policy denies the audience: 403",
audience: "https://resource-as.example.com",
subjectTokenType: tokenTypeID,
subjectToken: subjectToken,
policies: []TokenExchangePolicy{
// client_1 may only reach other.example.com, not resource-as.example.com
{ClientID: "client_1", AllowedAudiences: []string{"https://other.example.com"}},
},
wantCode: http.StatusForbidden,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx := t.Context()
httpServer, s := newTestServer(t, func(c *Config) {
require.NoError(t, c.Storage.CreateClient(ctx, storage.Client{
ID: "client_1",
Secret: "secret_1",
}))
c.TokenExchange = TokenExchangeConfig{
TokenTypes: []string{tokenTypeIDJAG},
}
c.IDJAGPolicies = tc.policies
})
defer httpServer.Close()
vals := url.Values{}
vals.Set("grant_type", grantTypeTokenExchange)
vals.Set("requested_token_type", tokenTypeIDJAG)
vals.Set("subject_token_type", tc.subjectTokenType)
vals.Set("subject_token", tc.subjectToken)
vals.Set("connector_id", "mock")
if tc.audience != "" {
vals.Set("audience", tc.audience)
}
vals.Set("client_id", "client_1")
vals.Set("client_secret", "secret_1")
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode()))
req.Header.Set("content-type", "application/x-www-form-urlencoded")
s.handleToken(rr, req)
require.Equal(t, tc.wantCode, rr.Code, "body: %s", rr.Body.String())
if tc.wantTokenTypeNA {
var res accessTokenResponse
require.NoError(t, json.NewDecoder(rr.Result().Body).Decode(&res))
require.Equal(t, "N_A", res.TokenType)
require.Equal(t, tokenTypeIDJAG, res.IssuedTokenType)
require.NotEmpty(t, res.AccessToken)
require.Equal(t, 3, len(strings.Split(res.AccessToken, ".")), "expected compact JWT")
require.Greater(t, res.ExpiresIn, 0)
}
})
}
}
func TestFilterConnectors(t *testing.T) {
connectors := []storage.Connector{
{ID: "github", Type: "github", Name: "GitHub"},

67
server/oauth2.go

@ -165,6 +165,8 @@ const (
tokenTypeSAML1 = "urn:ietf:params:oauth:token-type:saml1"
tokenTypeSAML2 = "urn:ietf:params:oauth:token-type:saml2"
tokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt"
// https://datatracker.ietf.org/doc/draft-ietf-oauth-identity-assertion-authz-grant/
tokenTypeIDJAG = "urn:ietf:params:oauth:token-type:id-jag"
)
const (
@ -281,6 +283,65 @@ func (s *Server) newAccessToken(ctx context.Context, clientID string, claims sto
return s.newIDToken(ctx, clientID, claims, scopes, nonce, storage.NewID(), "", connID)
}
// idJAGTyp is the JWT "typ" header value for ID-JAG tokens.
const idJAGTyp = "oauth-id-jag+jwt"
// idJAGClaims is the JWT payload for an ID-JAG token.
type idJAGClaims struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience string `json:"aud"`
ClientID string `json:"client_id"`
JTI string `json:"jti"`
Expiry int64 `json:"exp"`
IssuedAt int64 `json:"iat"`
// Optional claims.
Resource string `json:"resource,omitempty"`
Scope string `json:"scope,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
ACR string `json:"acr,omitempty"`
AMR []string `json:"amr,omitempty"`
}
// newIDJAG creates an ID-JAG token with the given subject and audience.
func (s *Server) newIDJAG(
ctx context.Context,
clientID string,
subject string,
audience string,
resource string,
scopes []string,
) (token string, expiry time.Time, err error) {
issuedAt := s.now()
expiry = issuedAt.Add(s.idJAGTokensValidFor)
claims := idJAGClaims{
Issuer: s.issuerURL.String(),
Subject: subject,
Audience: audience,
ClientID: clientID,
JTI: storage.NewID(),
Expiry: expiry.Unix(),
IssuedAt: issuedAt.Unix(),
Resource: resource,
}
if len(scopes) > 0 {
claims.Scope = strings.Join(scopes, " ")
}
payload, err := json.Marshal(claims)
if err != nil {
return "", expiry, fmt.Errorf("could not serialize ID-JAG claims: %v", err)
}
if token, err = s.signer.SignWithType(ctx, payload, idJAGTyp); err != nil {
return "", expiry, fmt.Errorf("failed to sign ID-JAG payload: %v", err)
}
return token, expiry, nil
}
func getClientID(aud audience, azp string) (string, error) {
switch len(aud) {
case 0:
@ -488,8 +549,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
}
if codeChallenge != "" && !slices.Contains(s.pkce.CodeChallengeMethodsSupported, codeChallengeMethod) {
description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod)
return nil, newRedirectedErr(errInvalidRequest, description)
return nil, newRedirectedErr(errInvalidRequest, "Unsupported PKCE challenge method (%q).", codeChallengeMethod)
}
// Enforce PKCE if configured.
@ -578,8 +638,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
}
if rt.token {
if redirectURI == redirectURIOOB {
err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB)
return nil, newRedirectedErr(errInvalidRequest, err)
return nil, newRedirectedErr(errInvalidRequest, "Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB)
}
}

70
server/policy.go

@ -0,0 +1,70 @@
package server
import "fmt"
// TokenExchangePolicy defines per-client access control for ID-JAG token exchange.
type TokenExchangePolicy struct {
// ClientID is the client this policy applies to. Use "*" for a default policy.
ClientID string `json:"clientID"`
AllowedAudiences []string `json:"allowedAudiences"`
AllowedScopes []string `json:"allowedScopes"`
}
// evaluateIDJAGPolicy checks whether the client is permitted to obtain an ID-JAG
// for the given audience and scopes. No policies configured means allow all.
func evaluateIDJAGPolicy(policies []TokenExchangePolicy, clientID, audience string, scopes []string) error {
if len(policies) == 0 {
return nil
}
// Find the most-specific policy for this client: exact match first, then wildcard.
var matched *TokenExchangePolicy
for i := range policies {
p := &policies[i]
if p.ClientID == clientID {
matched = p
break
}
if p.ClientID == "*" && matched == nil {
matched = p
}
}
if matched == nil {
return fmt.Errorf("no policy found for client %q: access_denied", clientID)
}
// Check audience.
if !audienceAllowed(matched.AllowedAudiences, audience) {
return fmt.Errorf("audience %q is not allowed for client %q: access_denied", audience, clientID)
}
// Check scopes (only if policy restricts them).
if len(matched.AllowedScopes) > 0 {
for _, scope := range scopes {
if !scopeAllowed(matched.AllowedScopes, scope) {
return fmt.Errorf("scope %q is not allowed for client %q: access_denied", scope, clientID)
}
}
}
return nil
}
func audienceAllowed(allowedAudiences []string, audience string) bool {
for _, a := range allowedAudiences {
if a == audience {
return true
}
}
return false
}
func scopeAllowed(allowedScopes []string, scope string) bool {
for _, s := range allowedScopes {
if s == scope {
return true
}
}
return false
}

102
server/policy_test.go

@ -0,0 +1,102 @@
package server
import (
"testing"
)
func TestEvaluateIDJAGPolicy(t *testing.T) {
tests := []struct {
name string
policies []TokenExchangePolicy
clientID string
audience string
scopes []string
wantErr bool
}{
{
name: "no policies: allow all",
policies: nil,
clientID: "any-client",
audience: "https://resource.example.com",
wantErr: false,
},
{
name: "exact match allowed",
policies: []TokenExchangePolicy{
{ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}},
},
clientID: "client-a",
audience: "https://resource.example.com",
wantErr: false,
},
{
name: "audience not allowed",
policies: []TokenExchangePolicy{
{ClientID: "client-a", AllowedAudiences: []string{"https://other.example.com"}},
},
clientID: "client-a",
audience: "https://resource.example.com",
wantErr: true,
},
{
name: "client not found: denied",
policies: []TokenExchangePolicy{
{ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}},
},
clientID: "unknown-client",
audience: "https://resource.example.com",
wantErr: true,
},
{
name: "wildcard client matches",
policies: []TokenExchangePolicy{
{ClientID: "*", AllowedAudiences: []string{"https://resource.example.com"}},
},
clientID: "any-client",
audience: "https://resource.example.com",
wantErr: false,
},
{
name: "exact match takes priority over wildcard",
policies: []TokenExchangePolicy{
{ClientID: "*", AllowedAudiences: []string{"https://other.example.com"}},
{ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}},
},
clientID: "client-a",
audience: "https://resource.example.com",
wantErr: false,
},
{
name: "scope denied by policy",
policies: []TokenExchangePolicy{
{ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}, AllowedScopes: []string{"read"}},
},
clientID: "client-a",
audience: "https://resource.example.com",
scopes: []string{"admin"},
wantErr: true,
},
{
name: "allowed scope passes",
policies: []TokenExchangePolicy{
{ClientID: "client-a", AllowedAudiences: []string{"https://resource.example.com"}, AllowedScopes: []string{"read", "write"}},
},
clientID: "client-a",
audience: "https://resource.example.com",
scopes: []string{"read"},
wantErr: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := evaluateIDJAGPolicy(tc.policies, tc.clientID, tc.audience, tc.scopes)
if tc.wantErr && err == nil {
t.Error("expected error but got none")
}
if !tc.wantErr && err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}

30
server/server.go

@ -105,6 +105,7 @@ type Config struct {
AlwaysShowLoginScreen bool
IDTokensValidFor time.Duration // Defaults to 24 hours
IDJAGTokensValidFor time.Duration // Defaults to 5 minutes
AuthRequestsValidFor time.Duration // Defaults to 24 hours
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
@ -136,6 +137,26 @@ type Config struct {
// If enabled, the server will continue starting even if some connectors fail to initialize.
// This allows the server to operate with a subset of connectors if some are misconfigured.
ContinueOnConnectorFailure bool
// TokenExchange configures Token Exchange support.
TokenExchange TokenExchangeConfig
IDJAGPolicies []TokenExchangePolicy
}
// TokenExchangeConfig holds configuration for Token Exchange support.
type TokenExchangeConfig struct {
TokenTypes []string `json:"tokenTypes"`
}
// IDJAGEnabled reports whether the ID-JAG token type is enabled.
func (c TokenExchangeConfig) IDJAGEnabled() bool {
for _, t := range c.TokenTypes {
if t == "urn:ietf:params:oauth:token-type:id-jag" {
return true
}
}
return false
}
// WebConfig holds the server's frontend templates and asset configuration.
@ -225,6 +246,10 @@ type Server struct {
logger *slog.Logger
signer signer.Signer
enableIDJAG bool
idJAGTokensValidFor time.Duration
tokenExchangePolicies []TokenExchangePolicy
}
// NewServer constructs a server from the provided config.
@ -330,6 +355,8 @@ func newServer(ctx context.Context, c Config) (*Server, error) {
now = time.Now
}
idJAGTokensValidFor := value(c.IDJAGTokensValidFor, 5*time.Minute)
s := &Server{
issuerURL: *issuerURL,
connectors: make(map[string]Connector),
@ -348,6 +375,9 @@ func newServer(ctx context.Context, c Config) (*Server, error) {
passwordConnector: c.PasswordConnector,
logger: c.Logger,
signer: c.Signer,
enableIDJAG: c.TokenExchange.IDJAGEnabled(),
idJAGTokensValidFor: idJAGTokensValidFor,
tokenExchangePolicies: c.IDJAGPolicies,
}
// Retrieves connector objects in backend storage. This list includes the static connectors

18
server/signer/local.go

@ -87,6 +87,24 @@ func (l *localSigner) Sign(ctx context.Context, payload []byte) (string, error)
return signPayload(signingKey, signingAlg, payload)
}
func (l *localSigner) SignWithType(ctx context.Context, payload []byte, tokenType string) (string, error) {
keys, err := l.storage.GetKeys(ctx)
if err != nil {
return "", fmt.Errorf("failed to get keys: %v", err)
}
signingKey := keys.SigningKey
if signingKey == nil {
return "", fmt.Errorf("no key to sign payload with")
}
signingAlg, err := signatureAlgorithm(signingKey)
if err != nil {
return "", err
}
return signPayloadWithType(signingKey, signingAlg, payload, tokenType)
}
func (l *localSigner) ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) {
keys, err := l.storage.GetKeys(ctx)
if err != nil {

4
server/signer/mock.go

@ -59,6 +59,10 @@ func (m *mockSigner) Sign(_ context.Context, payload []byte) (string, error) {
return signPayload(m.key, jose.RS256, payload)
}
func (m *mockSigner) SignWithType(_ context.Context, payload []byte, tokenType string) (string, error) {
return signPayloadWithType(m.key, jose.RS256, payload, tokenType)
}
func (m *mockSigner) ValidationKeys(_ context.Context) ([]*jose.JSONWebKey, error) {
return []*jose.JSONWebKey{m.pubKey}, nil
}

2
server/signer/signer.go

@ -10,6 +10,8 @@ import (
type Signer interface {
// Sign signs the provided payload.
Sign(ctx context.Context, payload []byte) (string, error)
// SignWithType signs the provided payload with a custom JWT "typ" header.
SignWithType(ctx context.Context, payload []byte, tokenType string) (string, error)
// ValidationKeys returns the current public keys used for signature validation.
ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error)
// Algorithm returns the signing algorithm used by this signer.

17
server/signer/utils.go

@ -56,3 +56,20 @@ func signPayload(key *jose.JSONWebKey, alg jose.SignatureAlgorithm, payload []by
}
return signature.CompactSerialize()
}
func signPayloadWithType(key *jose.JSONWebKey, alg jose.SignatureAlgorithm, payload []byte, tokenType string) (jws string, err error) {
signingKey := jose.SigningKey{Key: key, Algorithm: alg}
opts := &jose.SignerOptions{}
opts.WithType(jose.ContentType(tokenType))
signer, err := jose.NewSigner(signingKey, opts)
if err != nil {
return "", fmt.Errorf("new signer: %v", err)
}
signature, err := signer.Sign(payload)
if err != nil {
return "", fmt.Errorf("signing payload: %v", err)
}
return signature.CompactSerialize()
}

82
server/signer/vault.go

@ -179,6 +179,88 @@ func (v *vaultSigner) Sign(ctx context.Context, payload []byte) (string, error)
return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signatureB64), nil
}
func (v *vaultSigner) SignWithType(ctx context.Context, payload []byte, tokenType string) (string, error) {
// 1. Fetch keys to determine the key to use (latest version) and its ID.
keysMap, latestVersion, err := v.getTransitKeysMap(ctx)
if err != nil {
return "", fmt.Errorf("failed to get keys for signing context: %v", err)
}
// Determine the key version and ID to use
signingJWK, ok := keysMap[latestVersion]
if !ok {
return "", fmt.Errorf("latest key version %d not found in public keys", latestVersion)
}
// 2. Construct JWS Header with custom typ and Payload first (Signing Input)
header := map[string]interface{}{
"alg": signingJWK.Algorithm,
"kid": signingJWK.KeyID,
"typ": tokenType,
}
headerBytes, err := json.Marshal(header)
if err != nil {
return "", fmt.Errorf("failed to marshal header: %v", err)
}
headerB64 := base64.RawURLEncoding.EncodeToString(headerBytes)
payloadB64 := base64.RawURLEncoding.EncodeToString(payload)
// The input to the signature is "header.payload"
signingInput := fmt.Sprintf("%s.%s", headerB64, payloadB64)
// 3. Sign the signingInput using Vault
var vaultInput string
data := map[string]interface{}{}
// Determine Vault params based on JWS algorithm
params, err := getVaultParams(signingJWK.Algorithm)
if err != nil {
return "", err
}
// Apply params to data map
for k, v := range params.extraParams {
data[k] = v
}
// Hash input if needed
if params.hasher != nil {
params.hasher.Write([]byte(signingInput))
hash := params.hasher.Sum(nil)
vaultInput = base64.StdEncoding.EncodeToString(hash)
} else {
// No pre-hashing (EdDSA)
vaultInput = base64.StdEncoding.EncodeToString([]byte(signingInput))
}
data["input"] = vaultInput
signPath := fmt.Sprintf("transit/sign/%s", v.keyName)
signSecret, err := v.client.Logical().WriteWithContext(ctx, signPath, data)
if err != nil {
return "", fmt.Errorf("vault sign: %v", err)
}
signatureString, ok := signSecret.Data["signature"].(string)
if !ok {
return "", fmt.Errorf("vault response missing signature")
}
// Parse vault signature: "vault:v1:base64sig"
var signatureB64 []byte
if len(signatureString) > 8 && signatureString[:6] == "vault:" {
parts := splitVaultSignature(signatureString)
if len(parts) == 3 {
signatureB64 = []byte(parts[2])
}
} else {
return "", fmt.Errorf("unexpected signature format: %s", signatureString)
}
return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signatureB64), nil
}
func (v *vaultSigner) ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) {
keysMap, _, err := v.getTransitKeysMap(ctx)
if err != nil {

Loading…
Cancel
Save