Browse Source

feat: saml support refresh tokens (#4565)

Signed-off-by: Ivan Zvyagintsev <ivan.zvyagintsev@flant.com>
pull/4583/head^2
Ivan Zviagintsev 3 weeks ago committed by GitHub
parent
commit
4311931881
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 68
      connector/saml/saml.go
  2. 327
      connector/saml/saml_test.go
  3. 10
      server/handlers.go
  4. 116
      server/handlers_test.go

68
connector/saml/saml.go

@ -3,8 +3,10 @@ package saml
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"encoding/xml"
"fmt"
@ -255,6 +257,39 @@ type provider struct {
logger *slog.Logger
}
// Compile-time check that provider implements RefreshConnector
var _ connector.RefreshConnector = (*provider)(nil)
// cachedIdentity stores the identity from SAML assertion for refresh token support.
// Since SAML has no native refresh mechanism, we cache the identity obtained during
// the initial authentication and return it on subsequent refresh requests.
type cachedIdentity struct {
UserID string `json:"userId"`
Username string `json:"username"`
PreferredUsername string `json:"preferredUsername"`
Email string `json:"email"`
EmailVerified bool `json:"emailVerified"`
Groups []string `json:"groups,omitempty"`
}
// marshalCachedIdentity serializes the identity into ConnectorData for refresh token support.
func marshalCachedIdentity(ident connector.Identity) (connector.Identity, error) {
ci := cachedIdentity{
UserID: ident.UserID,
Username: ident.Username,
PreferredUsername: ident.PreferredUsername,
Email: ident.Email,
EmailVerified: ident.EmailVerified,
Groups: ident.Groups,
}
connectorData, err := json.Marshal(ci)
if err != nil {
return ident, fmt.Errorf("saml: failed to marshal cached identity: %v", err)
}
ident.ConnectorData = connectorData
return ident, nil
}
func (p *provider) POSTData(s connector.Scopes, id string) (action, value string, err error) {
r := &authnRequest{
ProtocolBinding: bindingPOST,
@ -405,7 +440,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
if len(p.allowedGroups) == 0 && (!s.Groups || p.groupsAttr == "") {
// Groups not requested or not configured. We're done.
return ident, nil
return marshalCachedIdentity(ident)
}
if len(p.allowedGroups) > 0 && (!s.Groups || p.groupsAttr == "") {
@ -431,7 +466,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
if len(p.allowedGroups) == 0 {
// No allowed groups set, just return the ident
return ident, nil
return marshalCachedIdentity(ident)
}
// Look for membership in one of the allowed groups
@ -447,6 +482,35 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
}
// Otherwise, we're good
return marshalCachedIdentity(ident)
}
// Refresh implements connector.RefreshConnector.
// Since SAML has no native refresh mechanism, this method returns the cached
// identity from the initial SAML assertion stored in ConnectorData.
func (p *provider) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
if len(ident.ConnectorData) == 0 {
return ident, fmt.Errorf("saml: no connector data available for refresh")
}
var ci cachedIdentity
if err := json.Unmarshal(ident.ConnectorData, &ci); err != nil {
return ident, fmt.Errorf("saml: failed to unmarshal cached identity: %v", err)
}
ident.UserID = ci.UserID
ident.Username = ci.Username
ident.PreferredUsername = ci.PreferredUsername
ident.Email = ci.Email
ident.EmailVerified = ci.EmailVerified
// Only populate groups if the client requested the groups scope.
if s.Groups {
ident.Groups = ci.Groups
} else {
ident.Groups = nil
}
return ident, nil
}

327
connector/saml/saml_test.go

@ -1,8 +1,10 @@
package saml
import (
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"log/slog"
@ -448,6 +450,24 @@ func (r responseTest) run(t *testing.T) {
}
sort.Strings(ident.Groups)
sort.Strings(r.wantIdent.Groups)
// Verify ConnectorData contains valid cached identity, then clear it
// for the main identity comparison (ConnectorData is an implementation
// detail of refresh token support).
if len(ident.ConnectorData) > 0 {
var ci cachedIdentity
if err := json.Unmarshal(ident.ConnectorData, &ci); err != nil {
t.Fatalf("failed to unmarshal ConnectorData: %v", err)
}
if ci.UserID != ident.UserID {
t.Errorf("cached identity UserID mismatch: got %q, want %q", ci.UserID, ident.UserID)
}
if ci.Email != ident.Email {
t.Errorf("cached identity Email mismatch: got %q, want %q", ci.Email, ident.Email)
}
}
ident.ConnectorData = nil
if diff := pretty.Compare(ident, r.wantIdent); diff != "" {
t.Error(diff)
}
@ -589,3 +609,310 @@ func TestVerifySignedMessageAndSignedAssertion(t *testing.T) {
func TestVerifyUnsignedMessageAndUnsignedAssertion(t *testing.T) {
runVerify(t, "testdata/idp-cert.pem", "testdata/idp-resp.xml", false)
}
func TestSAMLRefresh(t *testing.T) {
// Create a provider using the same pattern as existing tests.
c := Config{
CA: "testdata/ca.crt",
UsernameAttr: "Name",
EmailAttr: "email",
GroupsAttr: "groups",
RedirectURI: "http://127.0.0.1:5556/dex/callback",
SSOURL: "http://foo.bar/",
}
conn, err := c.openConnector(slog.New(slog.DiscardHandler))
if err != nil {
t.Fatal(err)
}
t.Run("SuccessfulRefresh", func(t *testing.T) {
ci := cachedIdentity{
UserID: "test-user-id",
Username: "testuser",
PreferredUsername: "testuser",
Email: "test@example.com",
EmailVerified: true,
Groups: []string{"group1", "group2"},
}
connectorData, err := json.Marshal(ci)
if err != nil {
t.Fatal(err)
}
ident := connector.Identity{
UserID: "old-id",
Username: "old-name",
ConnectorData: connectorData,
}
refreshed, err := conn.Refresh(context.Background(), connector.Scopes{Groups: true}, ident)
if err != nil {
t.Fatalf("Refresh failed: %v", err)
}
if refreshed.UserID != "test-user-id" {
t.Errorf("expected UserID %q, got %q", "test-user-id", refreshed.UserID)
}
if refreshed.Username != "testuser" {
t.Errorf("expected Username %q, got %q", "testuser", refreshed.Username)
}
if refreshed.PreferredUsername != "testuser" {
t.Errorf("expected PreferredUsername %q, got %q", "testuser", refreshed.PreferredUsername)
}
if refreshed.Email != "test@example.com" {
t.Errorf("expected Email %q, got %q", "test@example.com", refreshed.Email)
}
if !refreshed.EmailVerified {
t.Error("expected EmailVerified to be true")
}
if len(refreshed.Groups) != 2 || refreshed.Groups[0] != "group1" || refreshed.Groups[1] != "group2" {
t.Errorf("expected groups [group1, group2], got %v", refreshed.Groups)
}
// ConnectorData should be preserved through refresh
if len(refreshed.ConnectorData) == 0 {
t.Error("expected ConnectorData to be preserved")
}
})
t.Run("RefreshPreservesConnectorData", func(t *testing.T) {
ci := cachedIdentity{
UserID: "user-123",
Username: "alice",
Email: "alice@example.com",
EmailVerified: true,
}
connectorData, err := json.Marshal(ci)
if err != nil {
t.Fatal(err)
}
ident := connector.Identity{
UserID: "old-id",
ConnectorData: connectorData,
}
refreshed, err := conn.Refresh(context.Background(), connector.Scopes{}, ident)
if err != nil {
t.Fatalf("Refresh failed: %v", err)
}
// Verify the refreshed identity can be refreshed again (round-trip)
var roundTrip cachedIdentity
if err := json.Unmarshal(refreshed.ConnectorData, &roundTrip); err != nil {
t.Fatalf("failed to unmarshal ConnectorData after refresh: %v", err)
}
if roundTrip.UserID != "user-123" {
t.Errorf("round-trip UserID mismatch: got %q, want %q", roundTrip.UserID, "user-123")
}
})
t.Run("EmptyConnectorData", func(t *testing.T) {
ident := connector.Identity{
UserID: "test-id",
ConnectorData: nil,
}
_, err := conn.Refresh(context.Background(), connector.Scopes{}, ident)
if err == nil {
t.Error("expected error for empty ConnectorData")
}
})
t.Run("InvalidJSON", func(t *testing.T) {
ident := connector.Identity{
UserID: "test-id",
ConnectorData: []byte("not-json"),
}
_, err := conn.Refresh(context.Background(), connector.Scopes{}, ident)
if err == nil {
t.Error("expected error for invalid JSON")
}
})
t.Run("HandlePOSTThenRefresh", func(t *testing.T) {
// Full integration: HandlePOST → get ConnectorData → Refresh → verify identity
now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z")
if err != nil {
t.Fatal(err)
}
conn.now = func() time.Time { return now }
resp, err := os.ReadFile("testdata/good-resp.xml")
if err != nil {
t.Fatal(err)
}
samlResp := base64.StdEncoding.EncodeToString(resp)
scopes := connector.Scopes{
OfflineAccess: true,
Groups: true,
}
ident, err := conn.HandlePOST(scopes, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m")
if err != nil {
t.Fatalf("HandlePOST failed: %v", err)
}
if len(ident.ConnectorData) == 0 {
t.Fatal("expected ConnectorData to be set after HandlePOST")
}
// Now refresh using the ConnectorData from HandlePOST
refreshed, err := conn.Refresh(context.Background(), scopes, ident)
if err != nil {
t.Fatalf("Refresh failed: %v", err)
}
if refreshed.UserID != ident.UserID {
t.Errorf("UserID mismatch: got %q, want %q", refreshed.UserID, ident.UserID)
}
if refreshed.Username != ident.Username {
t.Errorf("Username mismatch: got %q, want %q", refreshed.Username, ident.Username)
}
if refreshed.Email != ident.Email {
t.Errorf("Email mismatch: got %q, want %q", refreshed.Email, ident.Email)
}
if refreshed.EmailVerified != ident.EmailVerified {
t.Errorf("EmailVerified mismatch: got %v, want %v", refreshed.EmailVerified, ident.EmailVerified)
}
sort.Strings(refreshed.Groups)
sort.Strings(ident.Groups)
if len(refreshed.Groups) != len(ident.Groups) {
t.Errorf("Groups length mismatch: got %d, want %d", len(refreshed.Groups), len(ident.Groups))
}
for i := range ident.Groups {
if i < len(refreshed.Groups) && refreshed.Groups[i] != ident.Groups[i] {
t.Errorf("Groups[%d] mismatch: got %q, want %q", i, refreshed.Groups[i], ident.Groups[i])
}
}
})
t.Run("HandlePOSTThenDoubleRefresh", func(t *testing.T) {
// Verify that refresh tokens can be chained: HandlePOST → Refresh → Refresh
now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z")
if err != nil {
t.Fatal(err)
}
conn.now = func() time.Time { return now }
resp, err := os.ReadFile("testdata/good-resp.xml")
if err != nil {
t.Fatal(err)
}
samlResp := base64.StdEncoding.EncodeToString(resp)
scopes := connector.Scopes{OfflineAccess: true, Groups: true}
ident, err := conn.HandlePOST(scopes, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m")
if err != nil {
t.Fatalf("HandlePOST failed: %v", err)
}
// First refresh
refreshed1, err := conn.Refresh(context.Background(), scopes, ident)
if err != nil {
t.Fatalf("first Refresh failed: %v", err)
}
if len(refreshed1.ConnectorData) == 0 {
t.Fatal("expected ConnectorData after first refresh")
}
// Second refresh using output of first refresh
refreshed2, err := conn.Refresh(context.Background(), scopes, refreshed1)
if err != nil {
t.Fatalf("second Refresh failed: %v", err)
}
// All fields should match original
if refreshed2.UserID != ident.UserID {
t.Errorf("UserID mismatch after double refresh: got %q, want %q", refreshed2.UserID, ident.UserID)
}
if refreshed2.Email != ident.Email {
t.Errorf("Email mismatch after double refresh: got %q, want %q", refreshed2.Email, ident.Email)
}
if refreshed2.Username != ident.Username {
t.Errorf("Username mismatch after double refresh: got %q, want %q", refreshed2.Username, ident.Username)
}
})
t.Run("HandlePOSTWithAssertionSignedThenRefresh", func(t *testing.T) {
// Test with assertion-signed.xml (signature on assertion, not response)
now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z")
if err != nil {
t.Fatal(err)
}
conn.now = func() time.Time { return now }
resp, err := os.ReadFile("testdata/assertion-signed.xml")
if err != nil {
t.Fatal(err)
}
samlResp := base64.StdEncoding.EncodeToString(resp)
scopes := connector.Scopes{OfflineAccess: true, Groups: true}
ident, err := conn.HandlePOST(scopes, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m")
if err != nil {
t.Fatalf("HandlePOST with assertion-signed failed: %v", err)
}
if len(ident.ConnectorData) == 0 {
t.Fatal("expected ConnectorData after HandlePOST with assertion-signed")
}
refreshed, err := conn.Refresh(context.Background(), scopes, ident)
if err != nil {
t.Fatalf("Refresh after assertion-signed HandlePOST failed: %v", err)
}
if refreshed.Email != ident.Email {
t.Errorf("Email mismatch: got %q, want %q", refreshed.Email, ident.Email)
}
if refreshed.Username != ident.Username {
t.Errorf("Username mismatch: got %q, want %q", refreshed.Username, ident.Username)
}
})
t.Run("HandlePOSTRefreshWithoutGroupsScope", func(t *testing.T) {
// Verify that groups are NOT returned when groups scope is not requested during refresh
now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z")
if err != nil {
t.Fatal(err)
}
conn.now = func() time.Time { return now }
resp, err := os.ReadFile("testdata/good-resp.xml")
if err != nil {
t.Fatal(err)
}
samlResp := base64.StdEncoding.EncodeToString(resp)
// Initial auth WITH groups
scopesWithGroups := connector.Scopes{OfflineAccess: true, Groups: true}
ident, err := conn.HandlePOST(scopesWithGroups, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m")
if err != nil {
t.Fatalf("HandlePOST failed: %v", err)
}
if len(ident.Groups) == 0 {
t.Fatal("expected groups in initial identity")
}
// Refresh WITHOUT groups scope
scopesNoGroups := connector.Scopes{OfflineAccess: true, Groups: false}
refreshed, err := conn.Refresh(context.Background(), scopesNoGroups, ident)
if err != nil {
t.Fatalf("Refresh failed: %v", err)
}
if len(refreshed.Groups) != 0 {
t.Errorf("expected no groups when groups scope not requested, got %v", refreshed.Groups)
}
// Refresh WITH groups scope — groups should be back
refreshedWithGroups, err := conn.Refresh(context.Background(), scopesWithGroups, ident)
if err != nil {
t.Fatalf("Refresh with groups failed: %v", err)
}
if len(refreshedWithGroups.Groups) == 0 {
t.Error("expected groups when groups scope is requested")
}
})
}

10
server/handlers.go

@ -1072,9 +1072,10 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
return nil, err
}
offlineSessions := storage.OfflineSessions{
UserID: refresh.Claims.UserID,
ConnID: refresh.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
UserID: refresh.Claims.UserID,
ConnID: refresh.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: refresh.ConnectorData,
}
offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef
@ -1100,6 +1101,9 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
if len(refresh.ConnectorData) > 0 {
old.ConnectorData = refresh.ConnectorData
}
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to update offline session", "err", err)

116
server/handlers_test.go

@ -21,6 +21,7 @@ import (
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/storage"
)
@ -892,3 +893,118 @@ func setNonEmpty(vals url.Values, key, value string) {
vals.Set(key, value)
}
}
// registerTestConnector creates a connector in storage and registers it in the server's connectors map.
func registerTestConnector(t *testing.T, s *Server, connID string, c connector.Connector) {
t.Helper()
ctx := t.Context()
storageConn := storage.Connector{
ID: connID,
Type: "saml",
Name: "Test SAML",
ResourceVersion: "1",
}
if err := s.storage.CreateConnector(ctx, storageConn); err != nil {
t.Fatalf("failed to create connector in storage: %v", err)
}
s.mu.Lock()
s.connectors[connID] = Connector{
ResourceVersion: "1",
Connector: c,
}
s.mu.Unlock()
}
func TestConnectorDataPersistence(t *testing.T) {
// Test that ConnectorData is correctly stored in refresh token
// and can be used for subsequent refresh operations.
httpServer, server := newTestServer(t, func(c *Config) {
c.RefreshTokenPolicy = &RefreshTokenPolicy{rotateRefreshTokens: true}
})
defer httpServer.Close()
ctx := t.Context()
connID := "saml-conndata"
// Create a mock SAML connector that also implements RefreshConnector
mockConn := &mockSAMLRefreshConnector{
refreshIdentity: connector.Identity{
UserID: "refreshed-user",
Username: "refreshed-name",
Email: "refreshed@example.com",
EmailVerified: true,
Groups: []string{"refreshed-group"},
},
}
registerTestConnector(t, server, connID, mockConn)
// Create client
client := storage.Client{
ID: "conndata-client",
Secret: "conndata-secret",
RedirectURIs: []string{"https://example.com/callback"},
Name: "ConnData Test Client",
}
require.NoError(t, server.storage.CreateClient(ctx, client))
// Create refresh token with ConnectorData (simulating what HandlePOST would store)
connectorData := []byte(`{"userID":"user-123","username":"testuser","email":"test@example.com","emailVerified":true,"groups":["admin","dev"]}`)
refreshToken := storage.RefreshToken{
ID: "conndata-refresh",
Token: "conndata-token",
CreatedAt: time.Now(),
LastUsed: time.Now(),
ClientID: client.ID,
ConnectorID: connID,
Scopes: []string{"openid", "email", "offline_access"},
Claims: storage.Claims{
UserID: "user-123",
Username: "testuser",
Email: "test@example.com",
EmailVerified: true,
Groups: []string{"admin", "dev"},
},
ConnectorData: connectorData,
Nonce: "conndata-nonce",
}
require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken))
offlineSession := storage.OfflineSessions{
UserID: "user-123",
ConnID: connID,
Refresh: map[string]*storage.RefreshTokenRef{client.ID: {ID: refreshToken.ID, ClientID: client.ID}},
ConnectorData: connectorData,
}
require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession))
// Verify ConnectorData is stored correctly
storedToken, err := server.storage.GetRefresh(ctx, refreshToken.ID)
require.NoError(t, err)
require.Equal(t, connectorData, storedToken.ConnectorData,
"ConnectorData should be persisted in refresh token storage")
// Verify ConnectorData is stored in offline session
storedSession, err := server.storage.GetOfflineSessions(ctx, "user-123", connID)
require.NoError(t, err)
require.Equal(t, connectorData, storedSession.ConnectorData,
"ConnectorData should be persisted in offline session storage")
}
// mockSAMLRefreshConnector implements SAMLConnector + RefreshConnector for testing.
type mockSAMLRefreshConnector struct {
refreshIdentity connector.Identity
}
func (m *mockSAMLRefreshConnector) POSTData(s connector.Scopes, requestID string) (ssoURL, samlRequest string, err error) {
return "", "", nil
}
func (m *mockSAMLRefreshConnector) HandlePOST(s connector.Scopes, samlResponse, inResponseTo string) (connector.Identity, error) {
return connector.Identity{}, nil
}
func (m *mockSAMLRefreshConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
return m.refreshIdentity, nil
}

Loading…
Cancel
Save