Browse Source

OAuth 2.0 Token Exchange (#2806)

Signed-off-by: Sean Liao <sean+git@liao.dev>
Co-authored-by: Maksim Nabokikh <max.nabokih@gmail.com>
pull/3027/head
Sean Liao 3 years ago committed by GitHub
parent
commit
dcf7b18510
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      cmd/dex/config.go
  2. 7
      cmd/dex/config_test.go
  3. 12
      cmd/dex/serve.go
  4. 4
      connector/connector.go
  5. 4
      connector/mock/connectortest.go
  6. 21
      connector/oidc/oidc.go
  7. 76
      connector/oidc/oidc_test.go
  8. 123
      server/handlers.go
  9. 116
      server/handlers_test.go
  10. 17
      server/oauth2.go
  11. 4
      server/oauth2_test.go
  12. 2
      server/refreshhandlers.go
  13. 29
      server/server.go
  14. 23
      server/server_test.go

4
cmd/dex/config.go

@ -129,6 +129,10 @@ func (p *password) UnmarshalJSON(b []byte) error {
// OAuth2 describes enabled OAuth2 extensions.
type OAuth2 struct {
// list of allowed grant types,
// defaults to all supported types
GrantTypes []string `json:"grantTypes"`
ResponseTypes []string `json:"responseTypes"`
// If specified, do not prompt the user to approve client authorization. The
// act of logging in implies authorization.

7
cmd/dex/config_test.go

@ -87,6 +87,9 @@ staticClients:
oauth2:
alwaysShowLoginScreen: true
grantTypes:
- refresh_token
- "urn:ietf:params:oauth:grant-type:token-exchange"
connectors:
- type: mockCallback
@ -161,6 +164,10 @@ logger:
},
OAuth2: OAuth2{
AlwaysShowLoginScreen: true,
GrantTypes: []string{
"refresh_token",
"urn:ietf:params:oauth:grant-type:token-exchange",
},
},
StaticConnectors: []Connector{
{

12
cmd/dex/serve.go

@ -259,6 +259,7 @@ func runServe(options serveOptions) error {
healthChecker := gosundheit.New()
serverConfig := server.Config{
AllowedGrantTypes: c.OAuth2.GrantTypes,
SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen,
@ -550,6 +551,17 @@ func applyConfigOverrides(options serveOptions, config *Config) {
if config.Frontend.Dir == "" {
config.Frontend.Dir = os.Getenv("DEX_FRONTEND_DIR")
}
if len(config.OAuth2.GrantTypes) == 0 {
config.OAuth2.GrantTypes = []string{
"authorization_code",
"implicit",
"password",
"refresh_token",
"urn:ietf:params:oauth:grant-type:device_code",
"urn:ietf:params:oauth:grant-type:token-exchange",
}
}
}
func pprofHandler(router *http.ServeMux) {

4
connector/connector.go

@ -99,3 +99,7 @@ type RefreshConnector interface {
// changes since the token was last refreshed.
Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error)
}
type TokenIdentityConnector interface {
TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (Identity, error)
}

4
connector/mock/connectortest.go

@ -66,6 +66,10 @@ func (m *Callback) Refresh(ctx context.Context, s connector.Scopes, identity con
return m.Identity, nil
}
func (m *Callback) TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (connector.Identity, error) {
return m.Identity, nil
}
// CallbackConfig holds the configuration parameters for a connector which requires no interaction.
type CallbackConfig struct{}

21
connector/oidc/oidc.go

@ -258,6 +258,7 @@ type caller uint
const (
createCaller caller = iota
refreshCaller
exchangeCaller
)
func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
@ -296,16 +297,32 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit
return c.createIdentity(ctx, identity, token, refreshCaller)
}
func (c *oidcConnector) TokenIdentity(ctx context.Context, subjectTokenType, subjectToken string) (connector.Identity, error) {
var identity connector.Identity
token := &oauth2.Token{
AccessToken: subjectToken,
}
return c.createIdentity(ctx, identity, token, exchangeCaller)
}
func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, caller caller) (connector.Identity, error) {
var claims map[string]interface{}
rawIDToken, ok := token.Extra("id_token").(string)
if ok {
if rawIDToken, ok := token.Extra("id_token").(string); ok {
idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil {
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
}
if err := idToken.Claims(&claims); err != nil {
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
}
} else if caller == exchangeCaller {
// AccessToken here could be either an id token or an access token
idToken, err := c.provider.Verifier(&oidc.Config{SkipClientIDCheck: true}).Verify(ctx, token.AccessToken)
if err != nil {
return identity, fmt.Errorf("oidc: failed to verify token: %v", err)
}
if err := idToken.Claims(&claims); err != nil {
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
}

76
connector/oidc/oidc_test.go

@ -2,6 +2,7 @@ package oidc
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
@ -428,6 +429,81 @@ func TestRefresh(t *testing.T) {
}
}
func TestTokenIdentity(t *testing.T) {
tokenTypeAccess := "urn:ietf:params:oauth:token-type:access_token"
tokenTypeID := "urn:ietf:params:oauth:token-type:id_token"
long2short := map[string]string{
tokenTypeAccess: "access_token",
tokenTypeID: "id_token",
}
tests := []struct {
name string
subjectType string
userInfo bool
}{
{
name: "id_token",
subjectType: tokenTypeID,
}, {
name: "access_token",
subjectType: tokenTypeAccess,
}, {
name: "id_token with user info",
subjectType: tokenTypeID,
userInfo: true,
}, {
name: "access_token with user info",
subjectType: tokenTypeAccess,
userInfo: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
testServer, err := setupServer(map[string]any{
"sub": "subvalue",
"name": "namevalue",
}, true)
if err != nil {
t.Fatal("failed to setup test server", err)
}
conn, err := newConnector(Config{
Issuer: testServer.URL,
Scopes: []string{"openid", "groups"},
GetUserInfo: tc.userInfo,
})
if err != nil {
t.Fatal("failed to create new connector", err)
}
res, err := http.Get(testServer.URL + "/token")
if err != nil {
t.Fatal("failed to get initial token", err)
}
defer res.Body.Close()
var tokenResponse map[string]any
err = json.NewDecoder(res.Body).Decode(&tokenResponse)
if err != nil {
t.Fatal("failed to decode initial token", err)
}
origToken := tokenResponse[long2short[tc.subjectType]].(string)
identity, err := conn.TokenIdentity(ctx, tc.subjectType, origToken)
if err != nil {
t.Fatal("failed to get token identity", err)
}
// assert identity
expectEquals(t, identity.UserID, "subvalue")
expectEquals(t, identity.Username, "namevalue")
})
}
}
func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {

123
server/handlers.go

@ -710,7 +710,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
implicitOrHybrid = true
var err error
accessToken, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID)
accessToken, _, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
@ -830,6 +830,11 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
}
grantType := r.PostFormValue("grant_type")
if !contains(s.supportedGrantTypes, grantType) {
s.logger.Errorf("unsupported grant type: %v", grantType)
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
return
}
switch grantType {
case grantTypeDeviceCode:
s.handleDeviceToken(w, r)
@ -839,6 +844,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
s.withClientFromStorage(w, r, s.handleRefreshToken)
case grantTypePassword:
s.withClientFromStorage(w, r, s.handlePasswordGrant)
case grantTypeTokenExchange:
s.withClientFromStorage(w, r, s.handleTokenExchange)
default:
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
}
@ -917,7 +924,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
}
func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
@ -1180,7 +1187,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
Groups: identity.Groups,
}
accessToken, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID)
accessToken, _, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID)
if err != nil {
s.logger.Errorf("password grant failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
@ -1319,21 +1326,109 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
s.writeAccessToken(w, resp)
}
func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, client storage.Client) {
ctx := r.Context()
if err := r.ParseForm(); err != nil {
s.logger.Errorf("could not parse request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return
}
q := r.Form
scopes := strings.Fields(q.Get("scope")) // OPTIONAL, map to issued token scope
requestedTokenType := q.Get("requested_token_type") // OPTIONAL, default to access token
if requestedTokenType == "" {
requestedTokenType = tokenTypeAccess
}
subjectToken := q.Get("subject_token") // REQUIRED
subjectTokenType := q.Get("subject_token_type") // REQUIRED
connID := q.Get("connector_id") // REQUIRED, not in RFC
switch subjectTokenType {
case tokenTypeID, tokenTypeAccess: // ok, continue
default:
s.tokenErrHelper(w, errRequestNotSupported, "Invalid subject_token_type.", http.StatusBadRequest)
return
}
if subjectToken == "" {
s.tokenErrHelper(w, errInvalidRequest, "Missing subject_token", http.StatusBadRequest)
return
}
conn, err := s.getConnector(connID)
if err != nil {
s.logger.Errorf("failed to get connector: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return
}
teConn, ok := conn.Connector.(connector.TokenIdentityConnector)
if !ok {
s.logger.Errorf("connector doesn't implement token exchange: %v", connID)
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return
}
identity, err := teConn.TokenIdentity(ctx, subjectTokenType, subjectToken)
if err != nil {
s.logger.Errorf("failed to verify subject token: %v", err)
s.tokenErrHelper(w, errAccessDenied, "", http.StatusUnauthorized)
return
}
claims := storage.Claims{
UserID: identity.UserID,
Username: identity.Username,
PreferredUsername: identity.PreferredUsername,
Email: identity.Email,
EmailVerified: identity.EmailVerified,
Groups: identity.Groups,
}
resp := accessTokenResponse{
IssuedTokenType: requestedTokenType,
TokenType: "bearer",
}
var expiry time.Time
switch requestedTokenType {
case tokenTypeID:
resp.AccessToken, expiry, err = s.newIDToken(client.ID, claims, scopes, "", "", "", connID)
case tokenTypeAccess:
resp.AccessToken, expiry, err = s.newAccessToken(client.ID, claims, scopes, "", connID)
default:
s.tokenErrHelper(w, errRequestNotSupported, "Invalid requested_token_type.", http.StatusBadRequest)
return
}
if err != nil {
s.logger.Errorf("token exchange failed to create new %v token: %v", requestedTokenType, err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
resp.ExpiresIn = int(time.Until(expiry).Seconds())
// Token response must include cache headers https://tools.ietf.org/html/rfc6749#section-5.1
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
type accessTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type,omitempty"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token,omitempty"`
Scope string `json:"scope,omitempty"`
}
func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenResponse {
return &accessTokenResponse{
accessToken,
"bearer",
int(expiry.Sub(s.now()).Seconds()),
refreshToken,
idToken,
AccessToken: accessToken,
TokenType: "bearer",
ExpiresIn: int(expiry.Sub(s.now()).Seconds()),
RefreshToken: refreshToken,
IDToken: idToken,
}
}
@ -1355,7 +1450,7 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenRespon
func (s *Server) renderError(r *http.Request, w http.ResponseWriter, status int, description string) {
if err := s.templates.err(r, w, status, description); err != nil {
s.logger.Errorf("Server template error: %v", err)
s.logger.Errorf("server template error: %v", err)
}
}

116
server/handlers_test.go

@ -10,6 +10,7 @@ import (
"net/http/httptest"
"net/url"
"path"
"strings"
"testing"
"time"
@ -332,7 +333,7 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) {
connID := "mockPw"
authReqID := "test"
expiry := time.Now().Add(100 * time.Second)
resTypes := []string{"code"}
resTypes := []string{responseTypeCode}
tests := []struct {
name string
@ -441,7 +442,7 @@ func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) {
connID := "mock"
authReqID := "test"
expiry := time.Now().Add(100 * time.Second)
resTypes := []string{"code"}
resTypes := []string{responseTypeCode}
tests := []struct {
name string
@ -527,3 +528,114 @@ func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) {
require.Equal(t, tc.expectedRes, cb.Path)
}
}
func TestHandleTokenExchange(t *testing.T) {
tests := []struct {
name string
scope string
requestedTokenType string
subjectTokenType string
subjectToken string
expectedCode int
expectedTokenType string
}{
{
"id-for-acccess",
"openid",
tokenTypeAccess,
tokenTypeID,
"foobar",
http.StatusOK,
tokenTypeAccess,
},
{
"id-for-id",
"openid",
tokenTypeID,
tokenTypeID,
"foobar",
http.StatusOK,
tokenTypeID,
},
{
"id-for-default",
"openid",
"",
tokenTypeID,
"foobar",
http.StatusOK,
tokenTypeAccess,
},
{
"access-for-access",
"openid",
tokenTypeAccess,
tokenTypeAccess,
"foobar",
http.StatusOK,
tokenTypeAccess,
},
{
"missing-subject_token_type",
"openid",
tokenTypeAccess,
"",
"foobar",
http.StatusBadRequest,
"",
},
{
"missing-subject_token",
"openid",
tokenTypeAccess,
tokenTypeAccess,
"",
http.StatusBadRequest,
"",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.Storage.CreateClient(storage.Client{
ID: "client_1",
Secret: "secret_1",
})
})
defer httpServer.Close()
vals := make(url.Values)
vals.Set("grant_type", grantTypeTokenExchange)
setNonEmpty(vals, "connector_id", "mock")
setNonEmpty(vals, "scope", tc.scope)
setNonEmpty(vals, "requested_token_type", tc.requestedTokenType)
setNonEmpty(vals, "subject_token_type", tc.subjectTokenType)
setNonEmpty(vals, "subject_token", tc.subjectToken)
setNonEmpty(vals, "client_id", "client_1")
setNonEmpty(vals, "client_secret", "secret_1")
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode()))
req.Header.Set("content-type", "application/x-www-form-urlencoded")
s.handleToken(rr, req)
require.Equal(t, tc.expectedCode, rr.Code, rr.Body.String())
require.Equal(t, "application/json", rr.Result().Header.Get("content-type"))
if tc.expectedCode == http.StatusOK {
var res accessTokenResponse
err := json.NewDecoder(rr.Result().Body).Decode(&res)
require.NoError(t, err)
require.Equal(t, tc.expectedTokenType, res.IssuedTokenType)
}
})
}
}
func setNonEmpty(vals url.Values, key, value string) {
if value != "" {
vals.Set(key, value)
}
}

17
server/oauth2.go

@ -93,7 +93,6 @@ func tokenErr(w http.ResponseWriter, typ, description string, statusCode int) er
return nil
}
//nolint
const (
errInvalidRequest = "invalid_request"
errUnauthorizedClient = "unauthorized_client"
@ -132,6 +131,17 @@ const (
grantTypeImplicit = "implicit"
grantTypePassword = "password"
grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
grantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange"
)
const (
// https://www.rfc-editor.org/rfc/rfc8693.html#section-3
tokenTypeAccess = "urn:ietf:params:oauth:token-type:access_token"
tokenTypeRefresh = "urn:ietf:params:oauth:token-type:refresh_token"
tokenTypeID = "urn:ietf:params:oauth:token-type:id_token"
tokenTypeSAML1 = "urn:ietf:params:oauth:token-type:saml1"
tokenTypeSAML2 = "urn:ietf:params:oauth:token-type:saml2"
tokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt"
)
const (
@ -288,9 +298,8 @@ type federatedIDClaims struct {
UserID string `json:"user_id,omitempty"`
}
func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, err error) {
idToken, _, err := s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), "", connID)
return idToken, err
func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, expiry time.Time, err error) {
return s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), "", connID)
}
func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) {

4
server/oauth2_test.go

@ -290,7 +290,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
}
for _, tc := range tests {
func() {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -343,7 +343,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
t.Fatalf("%s: unsupported error type", tc.name)
}
}
}()
})
}
}

2
server/refreshhandlers.go

@ -361,7 +361,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Groups: ident.Groups,
}
accessToken, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID)
accessToken, _, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.refreshTokenErrHelper(w, newInternalServerError())

29
server/server.go

@ -66,6 +66,8 @@ type Config struct {
// The backing persistence layer.
Storage storage.Storage
AllowedGrantTypes []string
// Valid values are "code" to enable the code flow and "token" to enable the implicit
// flow. If no response types are supplied this value defaults to "code".
SupportedResponseTypes []string
@ -213,7 +215,12 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
c.SupportedResponseTypes = []string{responseTypeCode}
}
supportedGrant := []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode} // default
allSupportedGrants := map[string]bool{
grantTypeAuthorizationCode: true,
grantTypeRefreshToken: true,
grantTypeDeviceCode: true,
grantTypeTokenExchange: true,
}
supportedRes := make(map[string]bool)
for _, respType := range c.SupportedResponseTypes {
@ -223,7 +230,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
case responseTypeToken:
// response_type=token is an implicit flow, let's add it to the discovery info
// https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.1
supportedGrant = append(supportedGrant, grantTypeImplicit)
allSupportedGrants[grantTypeImplicit] = true
default:
return nil, fmt.Errorf("unsupported response_type %q", respType)
}
@ -231,10 +238,22 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
}
if c.PasswordConnector != "" {
supportedGrant = append(supportedGrant, grantTypePassword)
allSupportedGrants[grantTypePassword] = true
}
sort.Strings(supportedGrant)
var supportedGrants []string
if len(c.AllowedGrantTypes) > 0 {
for _, grant := range c.AllowedGrantTypes {
if allSupportedGrants[grant] {
supportedGrants = append(supportedGrants, grant)
}
}
} else {
for grant := range allSupportedGrants {
supportedGrants = append(supportedGrants, grant)
}
}
sort.Strings(supportedGrants)
webFS := web.FS()
if c.Web.Dir != "" {
@ -267,7 +286,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
connectors: make(map[string]Connector),
storage: newKeyCacher(c.Storage, now),
supportedResponseTypes: supportedRes,
supportedGrantTypes: supportedGrant,
supportedGrantTypes: supportedGrants,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),

23
server/server_test.go

@ -99,6 +99,14 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
PrometheusRegistry: prometheus.NewRegistry(),
HealthChecker: gosundheit.New(),
SkipApprovalScreen: true, // Don't prompt for approval, just immediately redirect with code.
AllowedGrantTypes: []string{ // all implemented types
grantTypeDeviceCode,
grantTypeAuthorizationCode,
grantTypeRefreshToken,
grantTypeTokenExchange,
grantTypeImplicit,
grantTypePassword,
},
}
if updateConfig != nil {
updateConfig(&config)
@ -1756,17 +1764,22 @@ func TestServerSupportedGrants(t *testing.T) {
{
name: "Simple",
config: func(c *Config) {},
resGrants: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode},
resGrants: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
},
{
name: "Minimal",
config: func(c *Config) { c.AllowedGrantTypes = []string{grantTypeTokenExchange} },
resGrants: []string{grantTypeTokenExchange},
},
{
name: "With password connector",
config: func(c *Config) { c.PasswordConnector = "local" },
resGrants: []string{grantTypeAuthorizationCode, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode},
resGrants: []string{grantTypeAuthorizationCode, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
},
{
name: "With token response",
config: func(c *Config) { c.SupportedResponseTypes = append(c.SupportedResponseTypes, responseTypeToken) },
resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypeRefreshToken, grantTypeDeviceCode},
resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
},
{
name: "All",
@ -1774,14 +1787,14 @@ func TestServerSupportedGrants(t *testing.T) {
c.PasswordConnector = "local"
c.SupportedResponseTypes = append(c.SupportedResponseTypes, responseTypeToken)
},
resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode},
resGrants: []string{grantTypeAuthorizationCode, grantTypeImplicit, grantTypePassword, grantTypeRefreshToken, grantTypeDeviceCode, grantTypeTokenExchange},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, srv := newTestServer(context.TODO(), t, tc.config)
require.Equal(t, srv.supportedGrantTypes, tc.resGrants)
require.Equal(t, tc.resGrants, srv.supportedGrantTypes)
})
}
}

Loading…
Cancel
Save