diff --git a/connector/connector.go b/connector/connector.go index 4b9131d4..d812390f 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -92,16 +92,6 @@ type SAMLConnector interface { HandlePOST(s Scopes, samlResponse, inResponseTo string) (identity Identity, err error) } -// SAMLSLOConnector represents a SAML connector that supports Single Logout (SLO). -// When an IdP sends a LogoutRequest, the connector parses it and returns the -// NameID of the user being logged out. -type SAMLSLOConnector interface { - // HandleSLO processes a SAML LogoutRequest from the IdP. - // It validates the request signature and returns the NameID (user identifier) - // that should be used to invalidate the user's sessions. - HandleSLO(w http.ResponseWriter, r *http.Request) (nameID string, err error) -} - // RefreshConnector is a connector that can update the client claims. type RefreshConnector interface { // Refresh is called when a client attempts to claim a refresh token. The diff --git a/connector/saml/saml.go b/connector/saml/saml.go index 65c6fec4..b2e7d9b4 100644 --- a/connector/saml/saml.go +++ b/connector/saml/saml.go @@ -11,7 +11,6 @@ import ( "encoding/xml" "fmt" "log/slog" - "net/http" "os" "strings" "sync" @@ -85,11 +84,6 @@ type Config struct { InsecureSkipSignatureValidation bool `json:"insecureSkipSignatureValidation"` - // InsecureSkipSLOSignatureValidation skips signature validation on SLO requests. - // This is insecure and should only be used for testing or when the IdP - // does not sign LogoutRequests. - InsecureSkipSLOSignatureValidation bool `json:"insecureSkipSLOSignatureValidation"` - // Assertion attribute names to lookup various claims with. UsernameAttr string `json:"usernameAttr"` EmailAttr string `json:"emailAttr"` @@ -170,8 +164,6 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) { logger: logger, nameIDPolicyFormat: c.NameIDPolicyFormat, - - insecureSkipSLOSignatureValidation: c.InsecureSkipSLOSignatureValidation, } if p.nameIDPolicyFormat == "" { @@ -197,8 +189,7 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) { } } - needValidator := !c.InsecureSkipSignatureValidation || !c.InsecureSkipSLOSignatureValidation - if needValidator { + if !c.InsecureSkipSignatureValidation { if (c.CA == "") == (c.CAData == nil) { return nil, errors.New("must provide either 'ca' or 'caData'") } @@ -263,17 +254,12 @@ type provider struct { nameIDPolicyFormat string - insecureSkipSLOSignatureValidation bool - logger *slog.Logger } // Compile-time check that provider implements RefreshConnector var _ connector.RefreshConnector = (*provider)(nil) -// Compile-time check that provider implements SAMLSLOConnector -var _ connector.SAMLSLOConnector = (*provider)(nil) - // cachedIdentity stores the identity from SAML assertion for refresh token support. // Since SAML has no native refresh mechanism, we cache the identity obtained during // the initial authentication and return it on subsequent refresh requests. @@ -723,73 +709,3 @@ func before(now, notBefore time.Time) bool { func after(now, notOnOrAfter time.Time) bool { return now.After(notOnOrAfter.Add(allowedClockDrift)) } - -// validateSignature validates the XML digital signature of the given raw XML data. -func (p *provider) validateSignature(rawXML []byte) ([]byte, error) { - if p.validator == nil { - return nil, fmt.Errorf("signature validation unavailable (no validator configured)") - } - - doc := etree.NewDocument() - if err := doc.ReadFromBytes(rawXML); err != nil { - return nil, fmt.Errorf("failed to parse XML: %v", err) - } - - // Find the Signature element - root := doc.Root() - if root == nil { - return nil, fmt.Errorf("empty XML document") - } - - _, err := p.validator.Validate(root) - if err != nil { - return nil, fmt.Errorf("signature validation failed: %v", err) - } - - return rawXML, nil -} - -// HandleSLO processes a SAML LogoutRequest from the IdP. -// It validates the request, extracts the NameID, and returns it for session invalidation. -func (p *provider) HandleSLO(w http.ResponseWriter, r *http.Request) (string, error) { - if r.Method != http.MethodPost { - return "", fmt.Errorf("saml slo: expected POST method, got %s", r.Method) - } - - if err := r.ParseForm(); err != nil { - return "", fmt.Errorf("saml slo: failed to parse form: %v", err) - } - - samlRequest := r.FormValue("SAMLRequest") - if samlRequest == "" { - return "", fmt.Errorf("saml slo: missing SAMLRequest parameter") - } - - rawRequest, err := base64.StdEncoding.DecodeString(samlRequest) - if err != nil { - return "", fmt.Errorf("saml slo: failed to decode SAMLRequest: %v", err) - } - - byteReader := bytes.NewReader(rawRequest) - if xrvErr := xrv.Validate(byteReader); xrvErr != nil { - return "", errors.Wrap(xrvErr, "validating XML logout request") - } - - // Validate signature unless explicitly skipped - if !p.insecureSkipSLOSignatureValidation { - if _, err := p.validateSignature(rawRequest); err != nil { - return "", fmt.Errorf("saml slo: signature validation failed: %v", err) - } - } - - var req logoutRequest - if err := xml.Unmarshal(rawRequest, &req); err != nil { - return "", fmt.Errorf("saml slo: failed to unmarshal LogoutRequest: %v", err) - } - - if req.NameID.Value == "" { - return "", fmt.Errorf("saml slo: LogoutRequest missing NameID") - } - - return req.NameID.Value, nil -} diff --git a/connector/saml/saml_test.go b/connector/saml/saml_test.go index ecea35f5..3eba5cf8 100644 --- a/connector/saml/saml_test.go +++ b/connector/saml/saml_test.go @@ -7,14 +7,9 @@ import ( "encoding/json" "encoding/pem" "errors" - "fmt" "log/slog" - "net/http" - "net/http/httptest" - "net/url" "os" "sort" - "strings" "testing" "time" @@ -921,151 +916,3 @@ func TestSAMLRefresh(t *testing.T) { } }) } - -func TestSAMLHandleSLO(t *testing.T) { - c := Config{ - CA: "testdata/ca.crt", - UsernameAttr: "Name", - EmailAttr: "email", - RedirectURI: "http://127.0.0.1:5556/dex/callback", - SSOURL: "http://foo.bar/", - InsecureSkipSLOSignatureValidation: true, - } - - conn, err := c.openConnector(slog.New(slog.DiscardHandler)) - if err != nil { - t.Fatal(err) - } - - // Helper to create a LogoutRequest XML - makeLogoutRequest := func(nameID string) string { - return fmt.Sprintf(` - https://idp.example.com - %s -`, nameID) - } - - t.Run("ValidLogoutRequest", func(t *testing.T) { - logoutXML := makeLogoutRequest("user@example.com") - encoded := base64.StdEncoding.EncodeToString([]byte(logoutXML)) - - form := url.Values{} - form.Set("SAMLRequest", encoded) - - req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - - nameID, err := conn.HandleSLO(w, req) - if err != nil { - t.Fatalf("HandleSLO failed: %v", err) - } - if nameID != "user@example.com" { - t.Errorf("expected nameID %q, got %q", "user@example.com", nameID) - } - }) - - t.Run("MissingSAMLRequest", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader("")) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - - _, err := conn.HandleSLO(w, req) - if err == nil { - t.Error("expected error for missing SAMLRequest") - } - }) - - t.Run("InvalidBase64", func(t *testing.T) { - form := url.Values{} - form.Set("SAMLRequest", "not-valid-base64!!!") - - req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - - _, err := conn.HandleSLO(w, req) - if err == nil { - t.Error("expected error for invalid base64") - } - }) - - t.Run("InvalidXML", func(t *testing.T) { - encoded := base64.StdEncoding.EncodeToString([]byte("not xml at all")) - form := url.Values{} - form.Set("SAMLRequest", encoded) - - req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - - _, err := conn.HandleSLO(w, req) - if err == nil { - t.Error("expected error for invalid XML") - } - }) - - t.Run("MissingNameID", func(t *testing.T) { - logoutXML := ` - https://idp.example.com - -` - encoded := base64.StdEncoding.EncodeToString([]byte(logoutXML)) - - form := url.Values{} - form.Set("SAMLRequest", encoded) - - req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - - _, err := conn.HandleSLO(w, req) - if err == nil { - t.Error("expected error for missing NameID") - } - }) - - t.Run("WrongHTTPMethod", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/saml/slo/test", nil) - w := httptest.NewRecorder() - - _, err := conn.HandleSLO(w, req) - if err == nil { - t.Error("expected error for GET method") - } - }) - - t.Run("DifferentNameIDValues", func(t *testing.T) { - testCases := []struct { - name string - nameIDVal string - wantNameID string - }{ - {"email format", "admin@corp.example.com", "admin@corp.example.com"}, - {"persistent ID", "AQIC5w...", "AQIC5w..."}, - {"transient ID", "_ce3d2948b4cf20146dee0a0b3dd6f69b6cf86f62d7", "_ce3d2948b4cf20146dee0a0b3dd6f69b6cf86f62d7"}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - logoutXML := makeLogoutRequest(tc.nameIDVal) - encoded := base64.StdEncoding.EncodeToString([]byte(logoutXML)) - - form := url.Values{} - form.Set("SAMLRequest", encoded) - - req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - - nameID, err := conn.HandleSLO(w, req) - if err != nil { - t.Fatalf("HandleSLO failed: %v", err) - } - if nameID != tc.wantNameID { - t.Errorf("expected nameID %q, got %q", tc.wantNameID, nameID) - } - }) - } - }) -} diff --git a/connector/saml/types.go b/connector/saml/types.go index d63bcbe0..c8d7e7f3 100644 --- a/connector/saml/types.go +++ b/connector/saml/types.go @@ -275,21 +275,3 @@ func (a attribute) String() string { // "groups" = ["engineering", "docs"] return fmt.Sprintf("%q = %q", a.Name, values) } - -// logoutRequest represents a SAML 2.0 LogoutRequest message sent by the IdP -// to initiate Single Logout. -type logoutRequest struct { - XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol LogoutRequest"` - ID string `xml:"ID,attr"` - Version string `xml:"Version,attr"` - IssueInstant xmlTime `xml:"IssueInstant,attr"` - Destination string `xml:"Destination,attr,omitempty"` - Issuer string `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"` - NameID nameID `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"` - SessionIndex []sessionIndex `xml:"SessionIndex,omitempty"` -} - -// sessionIndex represents a SAML SessionIndex element in a LogoutRequest. -type sessionIndex struct { - Value string `xml:",chardata"` -} diff --git a/server/handlers.go b/server/handlers.go index 0bd73de1..e0133966 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1504,84 +1504,3 @@ func usernamePrompt(conn connector.PasswordConnector) string { } return "Username" } - -func (s *Server) handleSAMLSLO(w http.ResponseWriter, r *http.Request) { - connID, err := url.PathUnescape(mux.Vars(r)["connector"]) - if err != nil { - s.logger.ErrorContext(r.Context(), "SAML SLO: failed to parse connector ID", "err", err) - http.Error(w, "Missing connector ID", http.StatusBadRequest) - return - } - - ctx := r.Context() - - conn, err := s.getConnector(ctx, connID) - if err != nil { - s.logger.ErrorContext(ctx, "SAML SLO: failed to get connector", "connector_id", connID, "err", err) - http.Error(w, "Connector not found", http.StatusNotFound) - return - } - - sloConnector, ok := conn.Connector.(connector.SAMLSLOConnector) - if !ok { - s.logger.ErrorContext(ctx, "SAML SLO: connector does not support SLO", "connector_id", connID) - http.Error(w, "Connector does not support SLO", http.StatusBadRequest) - return - } - - nameID, err := sloConnector.HandleSLO(w, r) - if err != nil { - s.logger.ErrorContext(ctx, "SAML SLO: failed to process LogoutRequest", "connector_id", connID, "err", err) - http.Error(w, "Failed to process LogoutRequest", http.StatusBadRequest) - return - } - - s.logger.InfoContext(ctx, "SAML SLO: processing logout", "connector_id", connID, "name_id", nameID) - - if err := s.invalidateUserSessions(ctx, connID, nameID); err != nil { - s.logger.ErrorContext(ctx, "SAML SLO: failed to invalidate sessions", "connector_id", connID, "name_id", nameID, "err", err) - http.Error(w, "Failed to invalidate sessions", http.StatusInternalServerError) - return - } - - s.logger.InfoContext(ctx, "SAML SLO: successfully invalidated sessions", "connector_id", connID, "name_id", nameID) - w.WriteHeader(http.StatusOK) -} - -// invalidateUserSessions removes all refresh tokens for a user identified by -// their NameID from a specific connector. It uses OfflineSessions to find -// the user's refresh tokens and deletes them. -func (s *Server) invalidateUserSessions(ctx context.Context, connectorID, nameID string) error { - // NameID from SAML is used as the user ID in the connector. - // OfflineSessions are keyed by (userID, connectorID). - // The userID in OfflineSessions is the connector-specific user ID (NameID for SAML). - - offlineSessions, err := s.storage.GetOfflineSessions(ctx, nameID, connectorID) - if err != nil { - if err == storage.ErrNotFound { - s.logger.InfoContext(ctx, "SAML SLO: no offline sessions found", "name_id", nameID, "connector_id", connectorID) - return nil // No sessions to invalidate - } - return fmt.Errorf("failed to get offline sessions: %v", err) - } - - // Delete all refresh tokens associated with this offline session - for _, tokenRef := range offlineSessions.Refresh { - if err := s.storage.DeleteRefresh(ctx, tokenRef.ID); err != nil { - if err == storage.ErrNotFound { - continue // Token already deleted - } - s.logger.ErrorContext(ctx, "SAML SLO: failed to delete refresh token", "token_id", tokenRef.ID, "err", err) - // Continue deleting other tokens even if one fails - } - } - - // Delete the offline session itself - if err := s.storage.DeleteOfflineSessions(ctx, nameID, connectorID); err != nil { - if err != storage.ErrNotFound { - return fmt.Errorf("failed to delete offline sessions: %v", err) - } - } - - return nil -} diff --git a/server/handlers_test.go b/server/handlers_test.go index b50e2fe8..6ec9b962 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -22,7 +22,6 @@ import ( "golang.org/x/oauth2" "github.com/dexidp/dex/connector" - "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" ) @@ -895,35 +894,6 @@ func setNonEmpty(vals url.Values, key, value string) { } } -// mockSAMLSLOConnector implements connector.SAMLConnector and connector.SAMLSLOConnector for testing. -type mockSAMLSLOConnector struct { - sloNameID string - sloErr error -} - -func (m *mockSAMLSLOConnector) POSTData(s connector.Scopes, requestID string) (ssoURL, samlRequest string, err error) { - return "", "", nil -} - -func (m *mockSAMLSLOConnector) HandlePOST(s connector.Scopes, samlResponse, inResponseTo string) (connector.Identity, error) { - return connector.Identity{}, nil -} - -func (m *mockSAMLSLOConnector) HandleSLO(w http.ResponseWriter, r *http.Request) (string, error) { - return m.sloNameID, m.sloErr -} - -// mockNonSLOConnector implements connector.CallbackConnector but NOT SAMLSLOConnector. -type mockNonSLOConnector struct{} - -func (m *mockNonSLOConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { - return "", nil -} - -func (m *mockNonSLOConnector) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) { - return connector.Identity{}, nil -} - // registerTestConnector creates a connector in storage and registers it in the server's connectors map. func registerTestConnector(t *testing.T, s *Server, connID string, c connector.Connector) { t.Helper() @@ -947,393 +917,79 @@ func registerTestConnector(t *testing.T, s *Server, connID string, c connector.C s.mu.Unlock() } -func TestHandleSAMLSLO(t *testing.T) { - t.Run("SuccessfulSLO", func(t *testing.T) { - httpServer, server := newTestServer(t, nil) - defer httpServer.Close() - - ctx := t.Context() - connID := "saml-slo-test" - nameID := "user@example.com" - - mockConn := &mockSAMLSLOConnector{ - sloNameID: nameID, - } - registerTestConnector(t, server, connID, mockConn) - - // Create refresh tokens and offline session - refreshToken1 := storage.RefreshToken{ - ID: "refresh-token-1", - Token: "token-1", - CreatedAt: time.Now(), - LastUsed: time.Now(), - ClientID: "test-client", - ConnectorID: connID, - Claims: storage.Claims{ - UserID: nameID, - Username: "testuser", - Email: nameID, - }, - Nonce: "nonce-1", - } - refreshToken2 := storage.RefreshToken{ - ID: "refresh-token-2", - Token: "token-2", - CreatedAt: time.Now(), - LastUsed: time.Now(), - ClientID: "test-client-2", - ConnectorID: connID, - Claims: storage.Claims{ - UserID: nameID, - Username: "testuser", - Email: nameID, - }, - Nonce: "nonce-2", - } - require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken1)) - require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken2)) - - offlineSession := storage.OfflineSessions{ - UserID: nameID, - ConnID: connID, - Refresh: map[string]*storage.RefreshTokenRef{ - refreshToken1.ClientID: { - ID: refreshToken1.ID, - ClientID: refreshToken1.ClientID, - CreatedAt: refreshToken1.CreatedAt, - LastUsed: refreshToken1.LastUsed, - }, - refreshToken2.ClientID: { - ID: refreshToken2.ID, - ClientID: refreshToken2.ClientID, - CreatedAt: refreshToken2.CreatedAt, - LastUsed: refreshToken2.LastUsed, - }, - }, - } - require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession)) - - // Send POST to /saml/slo/{connector} - rr := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) - server.ServeHTTP(rr, req) - - require.Equal(t, http.StatusOK, rr.Code, "expected HTTP 200, got %d: %s", rr.Code, rr.Body.String()) - - // Verify refresh tokens are deleted - _, err := server.storage.GetRefresh(ctx, refreshToken1.ID) - require.ErrorIs(t, err, storage.ErrNotFound, "refresh token 1 should be deleted") - - _, err = server.storage.GetRefresh(ctx, refreshToken2.ID) - require.ErrorIs(t, err, storage.ErrNotFound, "refresh token 2 should be deleted") - - // Verify offline session is deleted - _, err = server.storage.GetOfflineSessions(ctx, nameID, connID) - require.ErrorIs(t, err, storage.ErrNotFound, "offline session should be deleted") - }) - - t.Run("NoExistingSessions", func(t *testing.T) { - httpServer, server := newTestServer(t, nil) - defer httpServer.Close() - - connID := "saml-slo-nosession" - nameID := "nouser@example.com" - - mockConn := &mockSAMLSLOConnector{ - sloNameID: nameID, - } - registerTestConnector(t, server, connID, mockConn) - - // No refresh tokens or offline sessions created — should handle gracefully - rr := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) - server.ServeHTTP(rr, req) - - require.Equal(t, http.StatusOK, rr.Code, "expected HTTP 200 for no sessions, got %d: %s", rr.Code, rr.Body.String()) - }) - - t.Run("ConnectorNotFound", func(t *testing.T) { - httpServer, server := newTestServer(t, nil) - defer httpServer.Close() - - rr := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/saml/slo/nonexistent-connector", nil) - server.ServeHTTP(rr, req) - - require.Equal(t, http.StatusNotFound, rr.Code, "expected HTTP 404 for missing connector") - }) - - t.Run("ConnectorDoesNotSupportSLO", func(t *testing.T) { - httpServer, server := newTestServer(t, nil) - defer httpServer.Close() - - connID := "saml-no-slo" - mockConn := &mockNonSLOConnector{} - registerTestConnector(t, server, connID, mockConn) - - rr := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) - server.ServeHTTP(rr, req) - - require.Equal(t, http.StatusBadRequest, rr.Code, "expected HTTP 400 for connector without SLO support") - }) - - t.Run("HandleSLOError", func(t *testing.T) { - httpServer, server := newTestServer(t, nil) - defer httpServer.Close() - - connID := "saml-slo-error" - mockConn := &mockSAMLSLOConnector{ - sloNameID: "", - sloErr: errors.New("invalid SAMLRequest"), - } - registerTestConnector(t, server, connID, mockConn) - - rr := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) - server.ServeHTTP(rr, req) - - require.Equal(t, http.StatusBadRequest, rr.Code, "expected HTTP 400 for invalid SAMLRequest") - }) - - t.Run("MultipleTokensPartialDeletion", func(t *testing.T) { - httpServer, server := newTestServer(t, nil) - defer httpServer.Close() - - ctx := t.Context() - connID := "saml-slo-partial" - nameID := "partial@example.com" - - mockConn := &mockSAMLSLOConnector{ - sloNameID: nameID, - } - registerTestConnector(t, server, connID, mockConn) - - // Create only one refresh token but reference two in offline session - // (simulating a token that was already deleted) - refreshToken := storage.RefreshToken{ - ID: "existing-token", - Token: "token-existing", - CreatedAt: time.Now(), - LastUsed: time.Now(), - ClientID: "client-existing", - ConnectorID: connID, - Claims: storage.Claims{ - UserID: nameID, - Username: "partialuser", - Email: nameID, - }, - Nonce: "nonce-existing", - } - require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken)) - - offlineSession := storage.OfflineSessions{ - UserID: nameID, - ConnID: connID, - Refresh: map[string]*storage.RefreshTokenRef{ - "client-existing": { - ID: "existing-token", - ClientID: "client-existing", - CreatedAt: time.Now(), - LastUsed: time.Now(), - }, - "client-missing": { - ID: "already-deleted-token", - ClientID: "client-missing", - CreatedAt: time.Now(), - LastUsed: time.Now(), - }, - }, - } - require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession)) - - rr := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) - server.ServeHTTP(rr, req) - - require.Equal(t, http.StatusOK, rr.Code, "expected HTTP 200 even with partial deletion") - - // Verify existing token is deleted - _, err := server.storage.GetRefresh(ctx, "existing-token") - require.ErrorIs(t, err, storage.ErrNotFound, "existing refresh token should be deleted") - - // Verify offline session is deleted - _, err = server.storage.GetOfflineSessions(ctx, nameID, connID) - require.ErrorIs(t, err, storage.ErrNotFound, "offline session should be deleted") +func TestConnectorDataPersistence(t *testing.T) { + // Test that ConnectorData is correctly stored in refresh token + // and can be used for subsequent refresh operations. + httpServer, server := newTestServer(t, func(c *Config) { + c.RefreshTokenPolicy = &RefreshTokenPolicy{rotateRefreshTokens: true} }) + defer httpServer.Close() - t.Run("RefreshAfterSLO", func(t *testing.T) { - // Test that refresh token is rejected after SLO invalidation. - // This is the key integration test: SLO → refresh → error. - httpServer, server := newTestServer(t, func(c *Config) { - c.RefreshTokenPolicy = &RefreshTokenPolicy{rotateRefreshTokens: true} - }) - defer httpServer.Close() - - ctx := t.Context() - connID := "saml-slo-refresh" - nameID := "slo-refresh@example.com" - - mockConn := &mockSAMLSLOConnector{ - sloNameID: nameID, - } - registerTestConnector(t, server, connID, mockConn) - - // Create client for refresh token request - client := storage.Client{ - ID: "slo-test-client", - Secret: "slo-test-secret", - RedirectURIs: []string{"https://example.com/callback"}, - Name: "SLO Test Client", - } - require.NoError(t, server.storage.CreateClient(ctx, client)) - - // Create refresh token - refreshToken := storage.RefreshToken{ - ID: "slo-refresh-token", - Token: "slo-token-value", - CreatedAt: time.Now(), - LastUsed: time.Now(), - ClientID: client.ID, - ConnectorID: connID, - Scopes: []string{"openid", "email", "offline_access"}, - Claims: storage.Claims{ - UserID: nameID, - Username: "slouser", - Email: nameID, - EmailVerified: true, - }, - ConnectorData: []byte(`{"userID":"` + nameID + `","username":"slouser","email":"` + nameID + `","emailVerified":true}`), - Nonce: "slo-nonce", - } - require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken)) - - offlineSession := storage.OfflineSessions{ - UserID: nameID, - ConnID: connID, - Refresh: map[string]*storage.RefreshTokenRef{ - client.ID: { - ID: refreshToken.ID, - ClientID: client.ID, - CreatedAt: refreshToken.CreatedAt, - LastUsed: refreshToken.LastUsed, - }, - }, - } - require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession)) - - // Step 1: Verify refresh token exists before SLO - _, err := server.storage.GetRefresh(ctx, refreshToken.ID) - require.NoError(t, err, "refresh token should exist before SLO") - - // Step 2: Send SLO request - sloRR := httptest.NewRecorder() - sloReq := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) - server.ServeHTTP(sloRR, sloReq) - require.Equal(t, http.StatusOK, sloRR.Code, "SLO should succeed") - - // Step 3: Verify refresh token is deleted - _, err = server.storage.GetRefresh(ctx, refreshToken.ID) - require.ErrorIs(t, err, storage.ErrNotFound, "refresh token should be deleted after SLO") - - // Step 4: Attempt to use the refresh token — should fail - tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: refreshToken.ID, Token: refreshToken.Token}) - require.NoError(t, err) - - u, err := url.Parse(server.issuerURL.String()) - require.NoError(t, err) - u.Path = path.Join(u.Path, "/token") - - v := url.Values{} - v.Add("grant_type", "refresh_token") - v.Add("refresh_token", tokenData) - - refreshReq, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode())) - refreshReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - refreshReq.SetBasicAuth(client.ID, client.Secret) + ctx := t.Context() + connID := "saml-conndata" + + // Create a mock SAML connector that also implements RefreshConnector + mockConn := &mockSAMLRefreshConnector{ + refreshIdentity: connector.Identity{ + UserID: "refreshed-user", + Username: "refreshed-name", + Email: "refreshed@example.com", + EmailVerified: true, + Groups: []string{"refreshed-group"}, + }, + } + registerTestConnector(t, server, connID, mockConn) - refreshRR := httptest.NewRecorder() - server.ServeHTTP(refreshRR, refreshReq) + // Create client + client := storage.Client{ + ID: "conndata-client", + Secret: "conndata-secret", + RedirectURIs: []string{"https://example.com/callback"}, + Name: "ConnData Test Client", + } + require.NoError(t, server.storage.CreateClient(ctx, client)) + + // Create refresh token with ConnectorData (simulating what HandlePOST would store) + connectorData := []byte(`{"userID":"user-123","username":"testuser","email":"test@example.com","emailVerified":true,"groups":["admin","dev"]}`) + refreshToken := storage.RefreshToken{ + ID: "conndata-refresh", + Token: "conndata-token", + CreatedAt: time.Now(), + LastUsed: time.Now(), + ClientID: client.ID, + ConnectorID: connID, + Scopes: []string{"openid", "email", "offline_access"}, + Claims: storage.Claims{ + UserID: "user-123", + Username: "testuser", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"admin", "dev"}, + }, + ConnectorData: connectorData, + Nonce: "conndata-nonce", + } + require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken)) - // Refresh should fail because the token was deleted by SLO - require.NotEqual(t, http.StatusOK, refreshRR.Code, - "refresh should fail after SLO, got status %d: %s", refreshRR.Code, refreshRR.Body.String()) - }) + offlineSession := storage.OfflineSessions{ + UserID: "user-123", + ConnID: connID, + Refresh: map[string]*storage.RefreshTokenRef{client.ID: {ID: refreshToken.ID, ClientID: client.ID}}, + ConnectorData: connectorData, + } + require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession)) - t.Run("ConnectorDataPersistence", func(t *testing.T) { - // Test that ConnectorData is correctly stored in refresh token - // and can be used for subsequent refresh operations. - httpServer, server := newTestServer(t, func(c *Config) { - c.RefreshTokenPolicy = &RefreshTokenPolicy{rotateRefreshTokens: true} - }) - defer httpServer.Close() - - ctx := t.Context() - connID := "saml-conndata" - - // Create a mock SAML connector that also implements RefreshConnector - mockConn := &mockSAMLRefreshConnector{ - refreshIdentity: connector.Identity{ - UserID: "refreshed-user", - Username: "refreshed-name", - Email: "refreshed@example.com", - EmailVerified: true, - Groups: []string{"refreshed-group"}, - }, - } - registerTestConnector(t, server, connID, mockConn) - - // Create client - client := storage.Client{ - ID: "conndata-client", - Secret: "conndata-secret", - RedirectURIs: []string{"https://example.com/callback"}, - Name: "ConnData Test Client", - } - require.NoError(t, server.storage.CreateClient(ctx, client)) - - // Create refresh token with ConnectorData (simulating what HandlePOST would store) - connectorData := []byte(`{"userID":"user-123","username":"testuser","email":"test@example.com","emailVerified":true,"groups":["admin","dev"]}`) - refreshToken := storage.RefreshToken{ - ID: "conndata-refresh", - Token: "conndata-token", - CreatedAt: time.Now(), - LastUsed: time.Now(), - ClientID: client.ID, - ConnectorID: connID, - Scopes: []string{"openid", "email", "offline_access"}, - Claims: storage.Claims{ - UserID: "user-123", - Username: "testuser", - Email: "test@example.com", - EmailVerified: true, - Groups: []string{"admin", "dev"}, - }, - ConnectorData: connectorData, - Nonce: "conndata-nonce", - } - require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken)) + // Verify ConnectorData is stored correctly + storedToken, err := server.storage.GetRefresh(ctx, refreshToken.ID) + require.NoError(t, err) + require.Equal(t, connectorData, storedToken.ConnectorData, + "ConnectorData should be persisted in refresh token storage") - offlineSession := storage.OfflineSessions{ - UserID: "user-123", - ConnID: connID, - Refresh: map[string]*storage.RefreshTokenRef{client.ID: {ID: refreshToken.ID, ClientID: client.ID}}, - ConnectorData: connectorData, - } - require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession)) - - // Verify ConnectorData is stored correctly - storedToken, err := server.storage.GetRefresh(ctx, refreshToken.ID) - require.NoError(t, err) - require.Equal(t, connectorData, storedToken.ConnectorData, - "ConnectorData should be persisted in refresh token storage") - - // Verify ConnectorData is stored in offline session - storedSession, err := server.storage.GetOfflineSessions(ctx, "user-123", connID) - require.NoError(t, err) - require.Equal(t, connectorData, storedSession.ConnectorData, - "ConnectorData should be persisted in offline session storage") - }) + // Verify ConnectorData is stored in offline session + storedSession, err := server.storage.GetOfflineSessions(ctx, "user-123", connID) + require.NoError(t, err) + require.Equal(t, connectorData, storedSession.ConnectorData, + "ConnectorData should be persisted in offline session storage") } // mockSAMLRefreshConnector implements SAMLConnector + RefreshConnector for testing. diff --git a/server/server.go b/server/server.go index e76611a5..e923e3e0 100644 --- a/server/server.go +++ b/server/server.go @@ -494,7 +494,6 @@ func newServer(ctx context.Context, c Config) (*Server, error) { // For easier connector-specific web server configuration, e.g. for the // "authproxy" connector. handleFunc("/callback/{connector}", s.handleConnectorCallback) - handleFunc("/saml/slo/{connector}", s.handleSAMLSLO) handleFunc("/approval", s.handleApproval) handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !c.HealthChecker.IsHealthy() {