Browse Source

feat(connector): connectors for grants (#4619)

Signed-off-by: maksim.nabokikh <max.nabokih@gmail.com>
pull/4634/head
Maksim Nabokikh 6 days ago committed by GitHub
parent
commit
7777773067
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 701
      api/v2/api.pb.go
  2. 11
      api/v2/api.proto
  3. 24
      cmd/dex/config.go
  4. 7
      cmd/dex/config_test.go
  5. 5
      cmd/dex/serve.go
  6. 14
      examples/config-dev.yaml
  7. 36
      server/api.go
  8. 99
      server/api_test.go
  9. 87
      server/handlers.go
  10. 157
      server/handlers_test.go
  11. 10
      server/oauth2.go
  12. 4
      server/refreshhandlers.go
  13. 9
      server/server.go
  14. 11
      storage/conformance/conformance.go
  15. 2
      storage/ent/client/connector.go
  16. 9
      storage/ent/client/types.go
  17. 18
      storage/ent/db/connector.go
  18. 3
      storage/ent/db/connector/connector.go
  19. 10
      storage/ent/db/connector/where.go
  20. 10
      storage/ent/db/connector_create.go
  21. 59
      storage/ent/db/connector_update.go
  22. 1
      storage/ent/db/migrate/schema.go
  23. 119
      storage/ent/db/mutation.go
  24. 2
      storage/ent/schema/connector.go
  25. 12
      storage/kubernetes/types.go
  26. 35
      storage/sql/crud.go
  27. 7
      storage/sql/migrate.go
  28. 4
      storage/storage.go

701
api/v2/api.pb.go

File diff suppressed because it is too large Load Diff

11
api/v2/api.proto

@ -140,6 +140,7 @@ message Connector {
string type = 2;
string name = 3;
bytes config = 4;
repeated string grant_types = 5;
}
// CreateConnectorReq is a request to make a connector.
@ -152,6 +153,12 @@ message CreateConnectorResp {
bool already_exists = 1;
}
// GrantTypes wraps a list of grant types to distinguish between
// "not specified" (no update) and "empty list" (unrestricted).
message GrantTypes {
repeated string grant_types = 1;
}
// UpdateConnectorReq is a request to modify an existing connector.
message UpdateConnectorReq {
// The id used to lookup the connector. This field cannot be modified
@ -159,6 +166,10 @@ message UpdateConnectorReq {
string new_type = 2;
string new_name = 3;
bytes new_config = 4;
// If set, updates the connector's allowed grant types.
// An empty grant_types list means unrestricted (all grant types allowed).
// If not set (null), grant types are not modified.
GrantTypes new_grant_types = 5;
}
// UpdateConnectorResp returns the response from modifying an existing connector.

24
cmd/dex/config.go

@ -459,7 +459,8 @@ type Connector struct {
Name string `json:"name"`
ID string `json:"id"`
Config server.ConnectorConfig `json:"config"`
Config server.ConnectorConfig `json:"config"`
GrantTypes []string `json:"grantTypes"`
}
// UnmarshalJSON allows Connector to implement the unmarshaler interface to
@ -470,7 +471,8 @@ func (c *Connector) UnmarshalJSON(b []byte) error {
Name string `json:"name"`
ID string `json:"id"`
Config json.RawMessage `json:"config"`
Config json.RawMessage `json:"config"`
GrantTypes []string `json:"grantTypes"`
}
if err := configUnmarshaller(b, &conn); err != nil {
return fmt.Errorf("parse connector: %v", err)
@ -508,10 +510,11 @@ func (c *Connector) UnmarshalJSON(b []byte) error {
}
*c = Connector{
Type: conn.Type,
Name: conn.Name,
ID: conn.ID,
Config: connConfig,
Type: conn.Type,
Name: conn.Name,
ID: conn.ID,
Config: connConfig,
GrantTypes: conn.GrantTypes,
}
return nil
}
@ -524,10 +527,11 @@ func ToStorageConnector(c Connector) (storage.Connector, error) {
}
return storage.Connector{
ID: c.ID,
Type: c.Type,
Name: c.Name,
Config: data,
ID: c.ID,
Type: c.Type,
Name: c.Name,
Config: data,
GrantTypes: c.GrantTypes,
}, nil
}

7
cmd/dex/config_test.go

@ -107,6 +107,9 @@ connectors:
- type: mockCallback
id: mock
name: Example
grantTypes:
- authorization_code
- "urn:ietf:params:oauth:grant-type:token-exchange"
- type: oidc
id: google
name: Google
@ -202,6 +205,10 @@ additionalFeatures: [
ID: "mock",
Name: "Example",
Config: &mock.CallbackConfig{},
GrantTypes: []string{
"authorization_code",
"urn:ietf:params:oauth:grant-type:token-exchange",
},
},
{
Type: "oidc",

5
cmd/dex/serve.go

@ -255,6 +255,11 @@ func runServe(options serveOptions) error {
if c.Config == nil {
return fmt.Errorf("invalid config: no config field for connector %q", c.ID)
}
for _, gt := range c.GrantTypes {
if !server.ConnectorGrantTypes[gt] {
return fmt.Errorf("invalid config: unknown grant type %q for connector %q", gt, c.ID)
}
}
logger.Info("config connector", "connector_id", c.ID)
// convert to a storage connector object

14
examples/config-dev.yaml

@ -100,7 +100,7 @@ telemetry:
# format: "text" # can also be "json"
# Default values shown below
#oauth2:
# oauth2:
# grantTypes determines the allowed set of authorization flows.
# grantTypes:
# - "authorization_code"
@ -151,6 +151,18 @@ connectors:
- type: mockCallback
id: mock
name: Example
# grantTypes restricts which grant types can use this connector.
# If not specified, all grant types are allowed.
# Supported values:
# - "authorization_code"
# - "implicit"
# - "refresh_token"
# - "password"
# - "urn:ietf:params:oauth:grant-type:device_code"
# - "urn:ietf:params:oauth:grant-type:token-exchange"
# grantTypes:
# - "authorization_code"
# - "refresh_token"
# - type: google
# id: google
# name: Google

36
server/api.go

@ -455,12 +455,19 @@ func (d dexAPI) CreateConnector(ctx context.Context, req *api.CreateConnectorReq
return nil, errors.New("invalid config supplied")
}
for _, gt := range req.Connector.GrantTypes {
if !ConnectorGrantTypes[gt] {
return nil, fmt.Errorf("unknown grant type %q", gt)
}
}
c := storage.Connector{
ID: req.Connector.Id,
Name: req.Connector.Name,
Type: req.Connector.Type,
ResourceVersion: "1",
Config: req.Connector.Config,
GrantTypes: req.Connector.GrantTypes,
}
if err := d.s.CreateConnector(ctx, c); err != nil {
if err == storage.ErrAlreadyExists {
@ -487,14 +494,26 @@ func (d dexAPI) UpdateConnector(ctx context.Context, req *api.UpdateConnectorReq
return nil, errors.New("no email supplied")
}
if len(req.NewConfig) == 0 && req.NewName == "" && req.NewType == "" {
hasUpdate := len(req.NewConfig) != 0 ||
req.NewName != "" ||
req.NewType != "" ||
req.NewGrantTypes != nil
if !hasUpdate {
return nil, errors.New("nothing to update")
}
if !json.Valid(req.NewConfig) {
if len(req.NewConfig) != 0 && !json.Valid(req.NewConfig) {
return nil, errors.New("invalid config supplied")
}
if req.NewGrantTypes != nil {
for _, gt := range req.NewGrantTypes.GrantTypes {
if !ConnectorGrantTypes[gt] {
return nil, fmt.Errorf("unknown grant type %q", gt)
}
}
}
updater := func(old storage.Connector) (storage.Connector, error) {
if req.NewType != "" {
old.Type = req.NewType
@ -508,6 +527,10 @@ func (d dexAPI) UpdateConnector(ctx context.Context, req *api.UpdateConnectorReq
old.Config = req.NewConfig
}
if req.NewGrantTypes != nil {
old.GrantTypes = req.NewGrantTypes.GrantTypes
}
if rev, err := strconv.Atoi(defaultTo(old.ResourceVersion, "0")); err == nil {
old.ResourceVersion = strconv.Itoa(rev + 1)
}
@ -561,10 +584,11 @@ func (d dexAPI) ListConnectors(ctx context.Context, req *api.ListConnectorReq) (
connectors := make([]*api.Connector, 0, len(connectorList))
for _, connector := range connectorList {
c := api.Connector{
Id: connector.ID,
Name: connector.Name,
Type: connector.Type,
Config: connector.Config,
Id: connector.ID,
Name: connector.Name,
Type: connector.Type,
Config: connector.Config,
GrantTypes: connector.GrantTypes,
}
connectors = append(connectors, &c)
}

99
server/api_test.go

@ -606,6 +606,105 @@ func TestUpdateConnector(t *testing.T) {
}
}
func TestUpdateConnectorGrantTypes(t *testing.T) {
t.Setenv("DEX_API_CONNECTORS_CRUD", "true")
logger := newLogger(t)
s := memory.New(logger)
client := newAPI(t, s, logger)
defer client.Close()
ctx := t.Context()
connectorID := "connector-gt"
// Create a connector without grant types
createReq := api.CreateConnectorReq{
Connector: &api.Connector{
Id: connectorID,
Name: "TestConnector",
Type: "TestType",
Config: []byte(`{"key": "value"}`),
},
}
_, err := client.CreateConnector(ctx, &createReq)
if err != nil {
t.Fatalf("failed to create connector: %v", err)
}
// Set grant types
_, err = client.UpdateConnector(ctx, &api.UpdateConnectorReq{
Id: connectorID,
NewGrantTypes: &api.GrantTypes{GrantTypes: []string{"authorization_code", "refresh_token"}},
})
if err != nil {
t.Fatalf("failed to update connector grant types: %v", err)
}
resp, err := client.ListConnectors(ctx, &api.ListConnectorReq{})
if err != nil {
t.Fatalf("failed to list connectors: %v", err)
}
for _, c := range resp.Connectors {
if c.Id == connectorID {
if !slices.Equal(c.GrantTypes, []string{"authorization_code", "refresh_token"}) {
t.Fatalf("expected grant types [authorization_code refresh_token], got %v", c.GrantTypes)
}
}
}
// Clear grant types by passing empty GrantTypes message
_, err = client.UpdateConnector(ctx, &api.UpdateConnectorReq{
Id: connectorID,
NewGrantTypes: &api.GrantTypes{},
})
if err != nil {
t.Fatalf("failed to clear connector grant types: %v", err)
}
resp, err = client.ListConnectors(ctx, &api.ListConnectorReq{})
if err != nil {
t.Fatalf("failed to list connectors: %v", err)
}
for _, c := range resp.Connectors {
if c.Id == connectorID {
if len(c.GrantTypes) != 0 {
t.Fatalf("expected empty grant types after clear, got %v", c.GrantTypes)
}
}
}
// Reject invalid grant type on update
_, err = client.UpdateConnector(ctx, &api.UpdateConnectorReq{
Id: connectorID,
NewGrantTypes: &api.GrantTypes{GrantTypes: []string{"bogus"}},
})
if err == nil {
t.Fatal("expected error for invalid grant type, got nil")
}
if !strings.Contains(err.Error(), `unknown grant type "bogus"`) {
t.Fatalf("unexpected error: %v", err)
}
// Reject invalid grant type on create
_, err = client.CreateConnector(ctx, &api.CreateConnectorReq{
Connector: &api.Connector{
Id: "bad-gt",
Name: "Bad",
Type: "TestType",
Config: []byte(`{}`),
GrantTypes: []string{"invalid_type"},
},
})
if err == nil {
t.Fatal("expected error for invalid grant type on create, got nil")
}
if !strings.Contains(err.Error(), `unknown grant type "invalid_type"`) {
t.Fatalf("unexpected error: %v", err)
}
}
func TestDeleteConnector(t *testing.T) {
t.Setenv("DEX_API_CONNECTORS_CRUD", "true")

87
server/handlers.go

@ -142,6 +142,21 @@ func (s *Server) constructDiscovery(ctx context.Context) discovery {
return d
}
// grantTypeFromAuthRequest determines the grant type from the authorization request parameters.
func (s *Server) grantTypeFromAuthRequest(r *http.Request) string {
redirectURI := r.Form.Get("redirect_uri")
if redirectURI == deviceCallbackURI || strings.HasSuffix(redirectURI, deviceCallbackURI) {
return grantTypeDeviceCode
}
responseType := r.Form.Get("response_type")
for _, rt := range strings.Fields(responseType) {
if rt == "token" || rt == "id_token" {
return grantTypeImplicit
}
}
return grantTypeAuthorizationCode
}
// handleAuthorization handles the OAuth2 auth endpoint.
func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@ -154,13 +169,27 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
}
connectorID := r.Form.Get("connector_id")
connectors, err := s.storage.ListConnectors(ctx)
allConnectors, err := s.storage.ListConnectors(ctx)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get list of connectors", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.")
return
}
// Determine the grant type from the authorization request to filter connectors.
grantType := s.grantTypeFromAuthRequest(r)
connectors := make([]storage.Connector, 0, len(allConnectors))
for _, c := range allConnectors {
if GrantTypeAllowed(c.GrantTypes, grantType) {
connectors = append(connectors, c)
}
}
if len(connectors) == 0 {
s.renderError(r, w, http.StatusBadRequest, "No connectors available for the requested grant type.")
return
}
// We don't need connector_id any more
r.Form.Del("connector_id")
@ -187,15 +216,15 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, connURL.String(), http.StatusFound)
}
connectorInfos := make([]connectorInfo, len(connectors))
for index, conn := range connectors {
connectorInfos := make([]connectorInfo, 0, len(connectors))
for _, conn := range connectors {
connURL.Path = s.absPath("/auth", url.PathEscape(conn.ID))
connectorInfos[index] = connectorInfo{
connectorInfos = append(connectorInfos, connectorInfo{
ID: conn.ID,
Name: conn.Name,
Type: conn.Type,
URL: template.URL(connURL.String()),
}
})
}
if err := s.templates.login(r, w, connectorInfos); err != nil {
@ -235,6 +264,15 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
return
}
// Check if the connector allows the requested grant type.
grantType := s.grantTypeFromAuthRequest(r)
if !GrantTypeAllowed(conn.GrantTypes, grantType) {
s.logger.ErrorContext(r.Context(), "connector does not allow requested grant type",
"connector_id", connID, "grant_type", grantType)
s.renderError(r, w, http.StatusBadRequest, "Requested connector does not support this grant type.")
return
}
// Set the connector being used for the login.
if authReq.ConnectorID != "" && authReq.ConnectorID != connID {
s.logger.ErrorContext(r.Context(), "mismatched connector ID in auth request",
@ -995,9 +1033,16 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
}
reqRefresh := func() bool {
// Ensure the connector supports refresh tokens.
// Determine whether to issue a refresh token. A refresh token is only
// issued when all of the following are true:
// 1. The connector implements RefreshConnector.
// 2. The connector's grantTypes config allows refresh_token.
// 3. The client requested the offline_access scope.
//
// Connectors like `saml` do not implement RefreshConnector.
// When any condition is not met, the refresh token is silently omitted
// rather than returning an error. This matches the OAuth2 spec: the
// server is never required to issue a refresh token (RFC 6749 §1.5).
// https://datatracker.ietf.org/doc/html/rfc6749#section-1.5
conn, err := s.getConnector(ctx, authCode.ConnectorID)
if err != nil {
s.logger.ErrorContext(ctx, "connector not found", "connector_id", authCode.ConnectorID, "err", err)
@ -1010,6 +1055,10 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
return false
}
if !GrantTypeAllowed(conn.GrantTypes, grantTypeRefreshToken) {
return false
}
for _, scope := range authCode.Scopes {
if scope == scopeOfflineAccess {
return true
@ -1215,6 +1264,11 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return
}
if !GrantTypeAllowed(conn.GrantTypes, grantTypePassword) {
s.logger.ErrorContext(r.Context(), "connector does not allow password grant", "connector_id", connID)
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not support password grant.", http.StatusBadRequest)
return
}
passwordConnector, ok := conn.Connector.(connector.PasswordConnector)
if !ok {
@ -1261,11 +1315,15 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
}
reqRefresh := func() bool {
// Ensure the connector supports refresh tokens.
//
// Connectors like `saml` do not implement RefreshConnector.
_, ok := conn.Connector.(connector.RefreshConnector)
if !ok {
// Same logic as in exchangeAuthCode: silently omit refresh token
// when the connector doesn't support it or grantTypes forbids it.
// See RFC 6749 §1.5 — refresh tokens are never mandatory.
// https://datatracker.ietf.org/doc/html/rfc6749#section-1.5
if _, ok := conn.Connector.(connector.RefreshConnector); !ok {
return false
}
if !GrantTypeAllowed(conn.GrantTypes, grantTypeRefreshToken) {
return false
}
@ -1422,6 +1480,11 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return
}
if !GrantTypeAllowed(conn.GrantTypes, grantTypeTokenExchange) {
s.logger.ErrorContext(r.Context(), "connector does not allow token exchange", "connector_id", connID)
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not support token exchange.", http.StatusBadRequest)
return
}
teConn, ok := conn.Connector.(connector.TokenIdentityConnector)
if !ok {
s.logger.ErrorContext(r.Context(), "connector doesn't implement token exchange", "connector_id", connID)

157
server/handlers_test.go

@ -1060,6 +1060,163 @@ func TestHandleTokenExchange(t *testing.T) {
}
}
func TestHandleTokenExchangeConnectorGrantTypeRestriction(t *testing.T) {
ctx := t.Context()
httpServer, s := newTestServer(t, func(c *Config) {
c.Storage.CreateClient(ctx, storage.Client{
ID: "client_1",
Secret: "secret_1",
})
})
defer httpServer.Close()
// Restrict mock connector to authorization_code only
err := s.storage.UpdateConnector(ctx, "mock", func(c storage.Connector) (storage.Connector, error) {
c.GrantTypes = []string{grantTypeAuthorizationCode}
return c, nil
})
require.NoError(t, err)
// Clear cached connector to pick up new grant types
s.mu.Lock()
delete(s.connectors, "mock")
s.mu.Unlock()
vals := make(url.Values)
vals.Set("grant_type", grantTypeTokenExchange)
vals.Set("connector_id", "mock")
vals.Set("scope", "openid")
vals.Set("requested_token_type", tokenTypeAccess)
vals.Set("subject_token_type", tokenTypeID)
vals.Set("subject_token", "foobar")
vals.Set("client_id", "client_1")
vals.Set("client_secret", "secret_1")
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode()))
req.Header.Set("content-type", "application/x-www-form-urlencoded")
s.handleToken(rr, req)
require.Equal(t, http.StatusBadRequest, rr.Code, rr.Body.String())
}
func TestHandleAuthorizationConnectorGrantTypeFiltering(t *testing.T) {
tests := []struct {
name string
// grantTypes per connector ID; nil means unrestricted
connectorGrantTypes map[string][]string
responseType string
wantCode int
// wantRedirectContains is checked when wantCode == 302
wantRedirectContains string
// wantBodyContains is checked when wantCode != 302
wantBodyContains string
}{
{
name: "one connector filtered, redirect to remaining",
connectorGrantTypes: map[string][]string{
"mock": {grantTypeDeviceCode},
"mock2": nil,
},
responseType: "code",
wantCode: http.StatusFound,
wantRedirectContains: "/auth/mock2",
},
{
name: "all connectors filtered",
connectorGrantTypes: map[string][]string{
"mock": {grantTypeDeviceCode},
"mock2": {grantTypeDeviceCode},
},
responseType: "code",
wantCode: http.StatusBadRequest,
wantBodyContains: "No connectors available",
},
{
name: "no restrictions, both available",
connectorGrantTypes: map[string][]string{
"mock": nil,
"mock2": nil,
},
responseType: "code",
wantCode: http.StatusOK,
},
{
name: "implicit flow filters auth_code-only connector",
connectorGrantTypes: map[string][]string{
"mock": {grantTypeAuthorizationCode},
"mock2": nil,
},
responseType: "token",
wantCode: http.StatusFound,
wantRedirectContains: "/auth/mock2",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx := t.Context()
httpServer, s := newTestServerMultipleConnectors(t, nil)
defer httpServer.Close()
for id, gts := range tc.connectorGrantTypes {
err := s.storage.UpdateConnector(ctx, id, func(c storage.Connector) (storage.Connector, error) {
c.GrantTypes = gts
return c, nil
})
require.NoError(t, err)
s.mu.Lock()
delete(s.connectors, id)
s.mu.Unlock()
}
rr := httptest.NewRecorder()
reqURL := fmt.Sprintf("%s/auth?response_type=%s&client_id=test&redirect_uri=http://example.com/callback&scope=openid", httpServer.URL, tc.responseType)
req := httptest.NewRequest(http.MethodGet, reqURL, nil)
s.handleAuthorization(rr, req)
require.Equal(t, tc.wantCode, rr.Code)
if tc.wantRedirectContains != "" {
require.Contains(t, rr.Header().Get("Location"), tc.wantRedirectContains)
}
if tc.wantBodyContains != "" {
require.Contains(t, rr.Body.String(), tc.wantBodyContains)
}
})
}
}
func TestHandleConnectorLoginGrantTypeRejection(t *testing.T) {
ctx := t.Context()
httpServer, s := newTestServer(t, func(c *Config) {
c.Storage.CreateClient(ctx, storage.Client{
ID: "test-client",
Secret: "secret",
RedirectURIs: []string{"http://example.com/callback"},
})
})
defer httpServer.Close()
// Restrict mock connector to device_code only
err := s.storage.UpdateConnector(ctx, "mock", func(c storage.Connector) (storage.Connector, error) {
c.GrantTypes = []string{grantTypeDeviceCode}
return c, nil
})
require.NoError(t, err)
s.mu.Lock()
delete(s.connectors, "mock")
s.mu.Unlock()
// Try to use mock connector for auth code flow via the full server router
rr := httptest.NewRecorder()
reqURL := httpServer.URL + "/auth/mock?response_type=code&client_id=test-client&redirect_uri=http://example.com/callback&scope=openid"
req := httptest.NewRequest(http.MethodGet, reqURL, nil)
s.ServeHTTP(rr, req)
require.Equal(t, http.StatusBadRequest, rr.Code)
require.Contains(t, rr.Body.String(), "does not support this grant type")
}
func setNonEmpty(vals url.Values, key, value string) {
if value != "" {
vals.Set(key, value)

10
server/oauth2.go

@ -146,6 +146,16 @@ const (
grantTypeClientCredentials = "client_credentials"
)
// ConnectorGrantTypes is the set of grant types that can be restricted per connector.
var ConnectorGrantTypes = map[string]bool{
grantTypeAuthorizationCode: true,
grantTypeRefreshToken: true,
grantTypeImplicit: true,
grantTypePassword: true,
grantTypeDeviceCode: true,
grantTypeTokenExchange: true,
}
const (
// https://www.rfc-editor.org/rfc/rfc8693.html#section-3
tokenTypeAccess = "urn:ietf:params:oauth:token-type:access_token"

4
server/refreshhandlers.go

@ -202,6 +202,10 @@ func (s *Server) getRefreshTokenFromStorage(ctx context.Context, clientID *strin
s.logger.ErrorContext(ctx, "connector not found", "connector_id", refresh.ConnectorID, "err", err)
return nil, newInternalServerError()
}
if !GrantTypeAllowed(refreshCtx.connector.GrantTypes, grantTypeRefreshToken) {
s.logger.ErrorContext(ctx, "connector does not allow refresh token grant", "connector_id", refresh.ConnectorID)
return nil, &refreshError{msg: errInvalidRequest, desc: "Connector does not support refresh tokens.", code: http.StatusBadRequest}
}
// Get Connector Data
session, err := s.storage.GetOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID)

9
server/server.go

@ -13,6 +13,7 @@ import (
"net/url"
"os"
"path"
"slices"
"sort"
"strings"
"sync"
@ -57,6 +58,13 @@ const LocalConnector = "local"
type Connector struct {
ResourceVersion string
Connector connector.Connector
GrantTypes []string
}
// GrantTypeAllowed checks if the given grant type is allowed for this connector.
// If no grant types are configured, all are allowed.
func GrantTypeAllowed(configuredTypes []string, grantType string) bool {
return len(configuredTypes) == 0 || slices.Contains(configuredTypes, grantType)
}
// Config holds the server's configuration options.
@ -739,6 +747,7 @@ func (s *Server) OpenConnector(conn storage.Connector) (Connector, error) {
connector := Connector{
ResourceVersion: conn.ResourceVersion,
Connector: c,
GrantTypes: conn.GrantTypes,
}
s.mu.Lock()
s.connectors[conn.ID] = connector

11
storage/conformance/conformance.go

@ -630,10 +630,11 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) {
id1 := storage.NewID()
config1 := []byte(`{"issuer": "https://accounts.google.com"}`)
c1 := storage.Connector{
ID: id1,
Type: "Default",
Name: "Default",
Config: config1,
ID: id1,
Type: "Default",
Name: "Default",
Config: config1,
GrantTypes: []string{"authorization_code", "refresh_token"},
}
if err := s.CreateConnector(ctx, c1); err != nil {
@ -674,12 +675,14 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) {
if err := s.UpdateConnector(ctx, c1.ID, func(old storage.Connector) (storage.Connector, error) {
old.Type = "oidc"
old.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:token-exchange"}
return old, nil
}); err != nil {
t.Fatalf("failed to update Connector: %v", err)
}
c1.Type = "oidc"
c1.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:token-exchange"}
getAndCompare(id1, c1)
connectorList := []storage.Connector{c1, c2}

2
storage/ent/client/connector.go

@ -14,6 +14,7 @@ func (d *Database) CreateConnector(ctx context.Context, connector storage.Connec
SetType(connector.Type).
SetResourceVersion(connector.ResourceVersion).
SetConfig(connector.Config).
SetGrantTypes(connector.GrantTypes).
Save(ctx)
if err != nil {
return convertDBError("create connector: %w", err)
@ -75,6 +76,7 @@ func (d *Database) UpdateConnector(ctx context.Context, id string, updater func(
SetType(newConnector.Type).
SetResourceVersion(newConnector.ResourceVersion).
SetConfig(newConnector.Config).
SetGrantTypes(newConnector.GrantTypes).
Save(ctx)
if err != nil {
return rollback(tx, "update connector uploading: %w", err)

9
storage/ent/client/types.go

@ -88,10 +88,11 @@ func toStorageClient(c *db.OAuth2Client) storage.Client {
func toStorageConnector(c *db.Connector) storage.Connector {
return storage.Connector{
ID: c.ID,
Type: c.Type,
Name: c.Name,
Config: c.Config,
ID: c.ID,
Type: c.Type,
Name: c.Name,
Config: c.Config,
GrantTypes: c.GrantTypes,
}
}

18
storage/ent/db/connector.go

@ -3,6 +3,7 @@
package db
import (
"encoding/json"
"fmt"
"strings"
@ -23,7 +24,9 @@ type Connector struct {
// ResourceVersion holds the value of the "resource_version" field.
ResourceVersion string `json:"resource_version,omitempty"`
// Config holds the value of the "config" field.
Config []byte `json:"config,omitempty"`
Config []byte `json:"config,omitempty"`
// GrantTypes holds the value of the "grant_types" field.
GrantTypes []string `json:"grant_types,omitempty"`
selectValues sql.SelectValues
}
@ -32,7 +35,7 @@ func (*Connector) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case connector.FieldConfig:
case connector.FieldConfig, connector.FieldGrantTypes:
values[i] = new([]byte)
case connector.FieldID, connector.FieldType, connector.FieldName, connector.FieldResourceVersion:
values[i] = new(sql.NullString)
@ -81,6 +84,14 @@ func (_m *Connector) assignValues(columns []string, values []any) error {
} else if value != nil {
_m.Config = *value
}
case connector.FieldGrantTypes:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field grant_types", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.GrantTypes); err != nil {
return fmt.Errorf("unmarshal field grant_types: %w", err)
}
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@ -128,6 +139,9 @@ func (_m *Connector) String() string {
builder.WriteString(", ")
builder.WriteString("config=")
builder.WriteString(fmt.Sprintf("%v", _m.Config))
builder.WriteString(", ")
builder.WriteString("grant_types=")
builder.WriteString(fmt.Sprintf("%v", _m.GrantTypes))
builder.WriteByte(')')
return builder.String()
}

3
storage/ent/db/connector/connector.go

@ -19,6 +19,8 @@ const (
FieldResourceVersion = "resource_version"
// FieldConfig holds the string denoting the config field in the database.
FieldConfig = "config"
// FieldGrantTypes holds the string denoting the grant_types field in the database.
FieldGrantTypes = "grant_types"
// Table holds the table name of the connector in the database.
Table = "connectors"
)
@ -30,6 +32,7 @@ var Columns = []string{
FieldName,
FieldResourceVersion,
FieldConfig,
FieldGrantTypes,
}
// ValidColumn reports if the column name is valid (part of the table columns).

10
storage/ent/db/connector/where.go

@ -317,6 +317,16 @@ func ConfigLTE(v []byte) predicate.Connector {
return predicate.Connector(sql.FieldLTE(FieldConfig, v))
}
// GrantTypesIsNil applies the IsNil predicate on the "grant_types" field.
func GrantTypesIsNil() predicate.Connector {
return predicate.Connector(sql.FieldIsNull(FieldGrantTypes))
}
// GrantTypesNotNil applies the NotNil predicate on the "grant_types" field.
func GrantTypesNotNil() predicate.Connector {
return predicate.Connector(sql.FieldNotNull(FieldGrantTypes))
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.Connector) predicate.Connector {
return predicate.Connector(sql.AndPredicates(predicates...))

10
storage/ent/db/connector_create.go

@ -43,6 +43,12 @@ func (_c *ConnectorCreate) SetConfig(v []byte) *ConnectorCreate {
return _c
}
// SetGrantTypes sets the "grant_types" field.
func (_c *ConnectorCreate) SetGrantTypes(v []string) *ConnectorCreate {
_c.mutation.SetGrantTypes(v)
return _c
}
// SetID sets the "id" field.
func (_c *ConnectorCreate) SetID(v string) *ConnectorCreate {
_c.mutation.SetID(v)
@ -161,6 +167,10 @@ func (_c *ConnectorCreate) createSpec() (*Connector, *sqlgraph.CreateSpec) {
_spec.SetField(connector.FieldConfig, field.TypeBytes, value)
_node.Config = value
}
if value, ok := _c.mutation.GrantTypes(); ok {
_spec.SetField(connector.FieldGrantTypes, field.TypeJSON, value)
_node.GrantTypes = value
}
return _node, _spec
}

59
storage/ent/db/connector_update.go

@ -9,6 +9,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/schema/field"
"github.com/dexidp/dex/storage/ent/db/connector"
"github.com/dexidp/dex/storage/ent/db/predicate"
@ -75,6 +76,24 @@ func (_u *ConnectorUpdate) SetConfig(v []byte) *ConnectorUpdate {
return _u
}
// SetGrantTypes sets the "grant_types" field.
func (_u *ConnectorUpdate) SetGrantTypes(v []string) *ConnectorUpdate {
_u.mutation.SetGrantTypes(v)
return _u
}
// AppendGrantTypes appends value to the "grant_types" field.
func (_u *ConnectorUpdate) AppendGrantTypes(v []string) *ConnectorUpdate {
_u.mutation.AppendGrantTypes(v)
return _u
}
// ClearGrantTypes clears the value of the "grant_types" field.
func (_u *ConnectorUpdate) ClearGrantTypes() *ConnectorUpdate {
_u.mutation.ClearGrantTypes()
return _u
}
// Mutation returns the ConnectorMutation object of the builder.
func (_u *ConnectorUpdate) Mutation() *ConnectorMutation {
return _u.mutation
@ -146,6 +165,17 @@ func (_u *ConnectorUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Config(); ok {
_spec.SetField(connector.FieldConfig, field.TypeBytes, value)
}
if value, ok := _u.mutation.GrantTypes(); ok {
_spec.SetField(connector.FieldGrantTypes, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedGrantTypes(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, connector.FieldGrantTypes, value)
})
}
if _u.mutation.GrantTypesCleared() {
_spec.ClearField(connector.FieldGrantTypes, field.TypeJSON)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{connector.Label}
@ -214,6 +244,24 @@ func (_u *ConnectorUpdateOne) SetConfig(v []byte) *ConnectorUpdateOne {
return _u
}
// SetGrantTypes sets the "grant_types" field.
func (_u *ConnectorUpdateOne) SetGrantTypes(v []string) *ConnectorUpdateOne {
_u.mutation.SetGrantTypes(v)
return _u
}
// AppendGrantTypes appends value to the "grant_types" field.
func (_u *ConnectorUpdateOne) AppendGrantTypes(v []string) *ConnectorUpdateOne {
_u.mutation.AppendGrantTypes(v)
return _u
}
// ClearGrantTypes clears the value of the "grant_types" field.
func (_u *ConnectorUpdateOne) ClearGrantTypes() *ConnectorUpdateOne {
_u.mutation.ClearGrantTypes()
return _u
}
// Mutation returns the ConnectorMutation object of the builder.
func (_u *ConnectorUpdateOne) Mutation() *ConnectorMutation {
return _u.mutation
@ -315,6 +363,17 @@ func (_u *ConnectorUpdateOne) sqlSave(ctx context.Context) (_node *Connector, er
if value, ok := _u.mutation.Config(); ok {
_spec.SetField(connector.FieldConfig, field.TypeBytes, value)
}
if value, ok := _u.mutation.GrantTypes(); ok {
_spec.SetField(connector.FieldGrantTypes, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedGrantTypes(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, connector.FieldGrantTypes, value)
})
}
if _u.mutation.GrantTypesCleared() {
_spec.ClearField(connector.FieldGrantTypes, field.TypeJSON)
}
_node = &Connector{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues

1
storage/ent/db/migrate/schema.go

@ -70,6 +70,7 @@ var (
{Name: "name", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
{Name: "resource_version", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
{Name: "config", Type: field.TypeBytes},
{Name: "grant_types", Type: field.TypeJSON, Nullable: true},
}
// ConnectorsTable holds the schema information for the "connectors" table.
ConnectorsTable = &schema.Table{

119
storage/ent/db/mutation.go

@ -2720,17 +2720,19 @@ func (m *AuthRequestMutation) ResetEdge(name string) error {
// ConnectorMutation represents an operation that mutates the Connector nodes in the graph.
type ConnectorMutation struct {
config
op Op
typ string
id *string
_type *string
name *string
resource_version *string
_config *[]byte
clearedFields map[string]struct{}
done bool
oldValue func(context.Context) (*Connector, error)
predicates []predicate.Connector
op Op
typ string
id *string
_type *string
name *string
resource_version *string
_config *[]byte
grant_types *[]string
appendgrant_types []string
clearedFields map[string]struct{}
done bool
oldValue func(context.Context) (*Connector, error)
predicates []predicate.Connector
}
var _ ent.Mutation = (*ConnectorMutation)(nil)
@ -2981,6 +2983,71 @@ func (m *ConnectorMutation) ResetConfig() {
m._config = nil
}
// SetGrantTypes sets the "grant_types" field.
func (m *ConnectorMutation) SetGrantTypes(s []string) {
m.grant_types = &s
m.appendgrant_types = nil
}
// GrantTypes returns the value of the "grant_types" field in the mutation.
func (m *ConnectorMutation) GrantTypes() (r []string, exists bool) {
v := m.grant_types
if v == nil {
return
}
return *v, true
}
// OldGrantTypes returns the old "grant_types" field's value of the Connector entity.
// If the Connector object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ConnectorMutation) OldGrantTypes(ctx context.Context) (v []string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldGrantTypes is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldGrantTypes requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldGrantTypes: %w", err)
}
return oldValue.GrantTypes, nil
}
// AppendGrantTypes adds s to the "grant_types" field.
func (m *ConnectorMutation) AppendGrantTypes(s []string) {
m.appendgrant_types = append(m.appendgrant_types, s...)
}
// AppendedGrantTypes returns the list of values that were appended to the "grant_types" field in this mutation.
func (m *ConnectorMutation) AppendedGrantTypes() ([]string, bool) {
if len(m.appendgrant_types) == 0 {
return nil, false
}
return m.appendgrant_types, true
}
// ClearGrantTypes clears the value of the "grant_types" field.
func (m *ConnectorMutation) ClearGrantTypes() {
m.grant_types = nil
m.appendgrant_types = nil
m.clearedFields[connector.FieldGrantTypes] = struct{}{}
}
// GrantTypesCleared returns if the "grant_types" field was cleared in this mutation.
func (m *ConnectorMutation) GrantTypesCleared() bool {
_, ok := m.clearedFields[connector.FieldGrantTypes]
return ok
}
// ResetGrantTypes resets all changes to the "grant_types" field.
func (m *ConnectorMutation) ResetGrantTypes() {
m.grant_types = nil
m.appendgrant_types = nil
delete(m.clearedFields, connector.FieldGrantTypes)
}
// Where appends a list predicates to the ConnectorMutation builder.
func (m *ConnectorMutation) Where(ps ...predicate.Connector) {
m.predicates = append(m.predicates, ps...)
@ -3015,7 +3082,7 @@ func (m *ConnectorMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ConnectorMutation) Fields() []string {
fields := make([]string, 0, 4)
fields := make([]string, 0, 5)
if m._type != nil {
fields = append(fields, connector.FieldType)
}
@ -3028,6 +3095,9 @@ func (m *ConnectorMutation) Fields() []string {
if m._config != nil {
fields = append(fields, connector.FieldConfig)
}
if m.grant_types != nil {
fields = append(fields, connector.FieldGrantTypes)
}
return fields
}
@ -3044,6 +3114,8 @@ func (m *ConnectorMutation) Field(name string) (ent.Value, bool) {
return m.ResourceVersion()
case connector.FieldConfig:
return m.Config()
case connector.FieldGrantTypes:
return m.GrantTypes()
}
return nil, false
}
@ -3061,6 +3133,8 @@ func (m *ConnectorMutation) OldField(ctx context.Context, name string) (ent.Valu
return m.OldResourceVersion(ctx)
case connector.FieldConfig:
return m.OldConfig(ctx)
case connector.FieldGrantTypes:
return m.OldGrantTypes(ctx)
}
return nil, fmt.Errorf("unknown Connector field %s", name)
}
@ -3098,6 +3172,13 @@ func (m *ConnectorMutation) SetField(name string, value ent.Value) error {
}
m.SetConfig(v)
return nil
case connector.FieldGrantTypes:
v, ok := value.([]string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetGrantTypes(v)
return nil
}
return fmt.Errorf("unknown Connector field %s", name)
}
@ -3127,7 +3208,11 @@ func (m *ConnectorMutation) AddField(name string, value ent.Value) error {
// ClearedFields returns all nullable fields that were cleared during this
// mutation.
func (m *ConnectorMutation) ClearedFields() []string {
return nil
var fields []string
if m.FieldCleared(connector.FieldGrantTypes) {
fields = append(fields, connector.FieldGrantTypes)
}
return fields
}
// FieldCleared returns a boolean indicating if a field with the given name was
@ -3140,6 +3225,11 @@ func (m *ConnectorMutation) FieldCleared(name string) bool {
// ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema.
func (m *ConnectorMutation) ClearField(name string) error {
switch name {
case connector.FieldGrantTypes:
m.ClearGrantTypes()
return nil
}
return fmt.Errorf("unknown Connector nullable field %s", name)
}
@ -3159,6 +3249,9 @@ func (m *ConnectorMutation) ResetField(name string) error {
case connector.FieldConfig:
m.ResetConfig()
return nil
case connector.FieldGrantTypes:
m.ResetGrantTypes()
return nil
}
return fmt.Errorf("unknown Connector field %s", name)
}

2
storage/ent/schema/connector.go

@ -38,6 +38,8 @@ func (Connector) Fields() []ent.Field {
field.Text("resource_version").
SchemaType(textSchema),
field.Bytes("config"),
field.JSON("grant_types", []string{}).
Optional(),
}
}

12
storage/kubernetes/types.go

@ -721,6 +721,8 @@ type Connector struct {
Name string `json:"name,omitempty"`
// Config holds connector specific configuration information
Config []byte `json:"config,omitempty"`
// GrantTypes is a list of grant types that this connector is allowed to be used with.
GrantTypes []string `json:"grantTypes,omitempty"`
}
func (cli *client) fromStorageConnector(c storage.Connector) Connector {
@ -733,10 +735,11 @@ func (cli *client) fromStorageConnector(c storage.Connector) Connector {
Name: c.ID,
Namespace: cli.namespace,
},
ID: c.ID,
Type: c.Type,
Name: c.Name,
Config: c.Config,
ID: c.ID,
Type: c.Type,
Name: c.Name,
Config: c.Config,
GrantTypes: c.GrantTypes,
}
}
@ -747,6 +750,7 @@ func toStorageConnector(c Connector) storage.Connector {
Name: c.Name,
ResourceVersion: c.ObjectMeta.ResourceVersion,
Config: c.Config,
GrantTypes: c.GrantTypes,
}
}

35
storage/sql/crud.go

@ -769,15 +769,19 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
}
func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) error {
_, err := c.Exec(`
grantTypes, err := json.Marshal(connector.GrantTypes)
if err != nil {
return fmt.Errorf("marshal connector grant types: %v", err)
}
_, err = c.Exec(`
insert into connector (
id, type, name, resource_version, config
id, type, name, resource_version, config, grant_types
)
values (
$1, $2, $3, $4, $5
$1, $2, $3, $4, $5, $6
);
`,
connector.ID, connector.Type, connector.Name, connector.ResourceVersion, connector.Config,
connector.ID, connector.Type, connector.Name, connector.ResourceVersion, connector.Config, grantTypes,
)
if err != nil {
if c.alreadyExistsCheck(err) {
@ -799,16 +803,21 @@ func (c *conn) UpdateConnector(ctx context.Context, id string, updater func(s st
if err != nil {
return err
}
grantTypes, err := json.Marshal(newConn.GrantTypes)
if err != nil {
return fmt.Errorf("marshal connector grant types: %v", err)
}
_, err = tx.Exec(`
update connector
set
type = $1,
name = $2,
resource_version = $3,
config = $4
where id = $5;
config = $4,
grant_types = $5
where id = $6;
`,
newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID,
newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, grantTypes, connector.ID,
)
if err != nil {
return fmt.Errorf("update connector: %v", err)
@ -824,15 +833,16 @@ func (c *conn) GetConnector(ctx context.Context, id string) (storage.Connector,
func getConnector(ctx context.Context, q querier, id string) (storage.Connector, error) {
return scanConnector(q.QueryRow(`
select
id, type, name, resource_version, config
id, type, name, resource_version, config, grant_types
from connector
where id = $1;
`, id))
}
func scanConnector(s scanner) (c storage.Connector, err error) {
var grantTypes []byte
err = s.Scan(
&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config,
&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config, &grantTypes,
)
if err != nil {
if err == sql.ErrNoRows {
@ -840,13 +850,18 @@ func scanConnector(s scanner) (c storage.Connector, err error) {
}
return c, fmt.Errorf("select connector: %v", err)
}
if len(grantTypes) > 0 {
if err := json.Unmarshal(grantTypes, &c.GrantTypes); err != nil {
return c, fmt.Errorf("unmarshal connector grant types: %v", err)
}
}
return c, nil
}
func (c *conn) ListConnectors(ctx context.Context) ([]storage.Connector, error) {
rows, err := c.Query(`
select
id, type, name, resource_version, config
id, type, name, resource_version, config, grant_types
from connector;
`)
if err != nil {

7
storage/sql/migrate.go

@ -374,4 +374,11 @@ var migrations = []migration{
},
flavor: &flavorMySQL,
},
{
stmts: []string{
`
alter table connector
add column grant_types bytea;`,
},
},
}

4
storage/storage.go

@ -388,6 +388,10 @@ type Connector struct {
// However, fixing this requires migrating Kubernetes objects for all previously created connectors,
// or making Dex reading both tags and act accordingly.
Config []byte `json:"email"`
// GrantTypes is a list of grant types that this connector is allowed to be used with.
// If empty, all grant types are allowed.
GrantTypes []string `json:"grantTypes,omitempty"`
}
// VerificationKey is a rotated signing key which can still be used to verify

Loading…
Cancel
Save