diff --git a/connector/connector.go b/connector/connector.go index d812390f..4b9131d4 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -92,6 +92,16 @@ type SAMLConnector interface { HandlePOST(s Scopes, samlResponse, inResponseTo string) (identity Identity, err error) } +// SAMLSLOConnector represents a SAML connector that supports Single Logout (SLO). +// When an IdP sends a LogoutRequest, the connector parses it and returns the +// NameID of the user being logged out. +type SAMLSLOConnector interface { + // HandleSLO processes a SAML LogoutRequest from the IdP. + // It validates the request signature and returns the NameID (user identifier) + // that should be used to invalidate the user's sessions. + HandleSLO(w http.ResponseWriter, r *http.Request) (nameID string, err error) +} + // RefreshConnector is a connector that can update the client claims. type RefreshConnector interface { // Refresh is called when a client attempts to claim a refresh token. The diff --git a/connector/saml/saml.go b/connector/saml/saml.go index 3e44b477..bcfcccbe 100644 --- a/connector/saml/saml.go +++ b/connector/saml/saml.go @@ -3,12 +3,15 @@ package saml import ( "bytes" + "context" "crypto/x509" "encoding/base64" + "encoding/json" "encoding/pem" "encoding/xml" "fmt" "log/slog" + "net/http" "os" "strings" "sync" @@ -82,6 +85,11 @@ type Config struct { InsecureSkipSignatureValidation bool `json:"insecureSkipSignatureValidation"` + // InsecureSkipSLOSignatureValidation skips signature validation on SLO requests. + // This is insecure and should only be used for testing or when the IdP + // does not sign LogoutRequests. + InsecureSkipSLOSignatureValidation bool `json:"insecureSkipSLOSignatureValidation"` + // Assertion attribute names to lookup various claims with. UsernameAttr string `json:"usernameAttr"` EmailAttr string `json:"emailAttr"` @@ -162,6 +170,8 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) { logger: logger, nameIDPolicyFormat: c.NameIDPolicyFormat, + + insecureSkipSLOSignatureValidation: c.InsecureSkipSLOSignatureValidation, } if p.nameIDPolicyFormat == "" { @@ -252,9 +262,47 @@ type provider struct { nameIDPolicyFormat string + insecureSkipSLOSignatureValidation bool + logger *slog.Logger } +// Compile-time check that provider implements RefreshConnector +var _ connector.RefreshConnector = (*provider)(nil) + +// Compile-time check that provider implements SAMLSLOConnector +var _ connector.SAMLSLOConnector = (*provider)(nil) + +// cachedIdentity stores the identity from SAML assertion for refresh token support. +// Since SAML has no native refresh mechanism, we cache the identity obtained during +// the initial authentication and return it on subsequent refresh requests. +type cachedIdentity struct { + UserID string `json:"userId"` + Username string `json:"username"` + PreferredUsername string `json:"preferredUsername"` + Email string `json:"email"` + EmailVerified bool `json:"emailVerified"` + Groups []string `json:"groups,omitempty"` +} + +// marshalCachedIdentity serializes the identity into ConnectorData for refresh token support. +func marshalCachedIdentity(ident connector.Identity) (connector.Identity, error) { + ci := cachedIdentity{ + UserID: ident.UserID, + Username: ident.Username, + PreferredUsername: ident.PreferredUsername, + Email: ident.Email, + EmailVerified: ident.EmailVerified, + Groups: ident.Groups, + } + connectorData, err := json.Marshal(ci) + if err != nil { + return ident, fmt.Errorf("saml: failed to marshal cached identity: %v", err) + } + ident.ConnectorData = connectorData + return ident, nil +} + func (p *provider) POSTData(s connector.Scopes, id string) (action, value string, err error) { r := &authnRequest{ ProtocolBinding: bindingPOST, @@ -405,7 +453,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str if len(p.allowedGroups) == 0 && (!s.Groups || p.groupsAttr == "") { // Groups not requested or not configured. We're done. - return ident, nil + return marshalCachedIdentity(ident) } if len(p.allowedGroups) > 0 && (!s.Groups || p.groupsAttr == "") { @@ -431,7 +479,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str if len(p.allowedGroups) == 0 { // No allowed groups set, just return the ident - return ident, nil + return marshalCachedIdentity(ident) } // Look for membership in one of the allowed groups @@ -447,6 +495,35 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str } // Otherwise, we're good + return marshalCachedIdentity(ident) +} + +// Refresh implements connector.RefreshConnector. +// Since SAML has no native refresh mechanism, this method returns the cached +// identity from the initial SAML assertion stored in ConnectorData. +func (p *provider) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) { + if len(ident.ConnectorData) == 0 { + return ident, fmt.Errorf("saml: no connector data available for refresh") + } + + var ci cachedIdentity + if err := json.Unmarshal(ident.ConnectorData, &ci); err != nil { + return ident, fmt.Errorf("saml: failed to unmarshal cached identity: %v", err) + } + + ident.UserID = ci.UserID + ident.Username = ci.Username + ident.PreferredUsername = ci.PreferredUsername + ident.Email = ci.Email + ident.EmailVerified = ci.EmailVerified + + // Only populate groups if the client requested the groups scope. + if s.Groups { + ident.Groups = ci.Groups + } else { + ident.Groups = nil + } + return ident, nil } @@ -645,3 +722,64 @@ func before(now, notBefore time.Time) bool { func after(now, notOnOrAfter time.Time) bool { return now.After(notOnOrAfter.Add(allowedClockDrift)) } + +// validateSignature validates the XML digital signature of the given raw XML data. +func (p *provider) validateSignature(rawXML []byte) ([]byte, error) { + doc := etree.NewDocument() + if err := doc.ReadFromBytes(rawXML); err != nil { + return nil, fmt.Errorf("failed to parse XML: %v", err) + } + + // Find the Signature element + root := doc.Root() + if root == nil { + return nil, fmt.Errorf("empty XML document") + } + + _, err := p.validator.Validate(root) + if err != nil { + return nil, fmt.Errorf("signature validation failed: %v", err) + } + + return rawXML, nil +} + +// HandleSLO processes a SAML LogoutRequest from the IdP. +// It validates the request, extracts the NameID, and returns it for session invalidation. +func (p *provider) HandleSLO(w http.ResponseWriter, r *http.Request) (string, error) { + if r.Method != http.MethodPost { + return "", fmt.Errorf("saml slo: expected POST method, got %s", r.Method) + } + + if err := r.ParseForm(); err != nil { + return "", fmt.Errorf("saml slo: failed to parse form: %v", err) + } + + samlRequest := r.FormValue("SAMLRequest") + if samlRequest == "" { + return "", fmt.Errorf("saml slo: missing SAMLRequest parameter") + } + + rawRequest, err := base64.StdEncoding.DecodeString(samlRequest) + if err != nil { + return "", fmt.Errorf("saml slo: failed to decode SAMLRequest: %v", err) + } + + // Validate signature unless explicitly skipped + if !p.insecureSkipSLOSignatureValidation { + if _, err := p.validateSignature(rawRequest); err != nil { + return "", fmt.Errorf("saml slo: signature validation failed: %v", err) + } + } + + var req logoutRequest + if err := xml.Unmarshal(rawRequest, &req); err != nil { + return "", fmt.Errorf("saml slo: failed to unmarshal LogoutRequest: %v", err) + } + + if req.NameID.Value == "" { + return "", fmt.Errorf("saml slo: LogoutRequest missing NameID") + } + + return req.NameID.Value, nil +} diff --git a/connector/saml/saml_test.go b/connector/saml/saml_test.go index 03e891fe..ecea35f5 100644 --- a/connector/saml/saml_test.go +++ b/connector/saml/saml_test.go @@ -1,13 +1,20 @@ package saml import ( + "context" "crypto/x509" "encoding/base64" + "encoding/json" "encoding/pem" "errors" + "fmt" "log/slog" + "net/http" + "net/http/httptest" + "net/url" "os" "sort" + "strings" "testing" "time" @@ -448,6 +455,24 @@ func (r responseTest) run(t *testing.T) { } sort.Strings(ident.Groups) sort.Strings(r.wantIdent.Groups) + + // Verify ConnectorData contains valid cached identity, then clear it + // for the main identity comparison (ConnectorData is an implementation + // detail of refresh token support). + if len(ident.ConnectorData) > 0 { + var ci cachedIdentity + if err := json.Unmarshal(ident.ConnectorData, &ci); err != nil { + t.Fatalf("failed to unmarshal ConnectorData: %v", err) + } + if ci.UserID != ident.UserID { + t.Errorf("cached identity UserID mismatch: got %q, want %q", ci.UserID, ident.UserID) + } + if ci.Email != ident.Email { + t.Errorf("cached identity Email mismatch: got %q, want %q", ci.Email, ident.Email) + } + } + ident.ConnectorData = nil + if diff := pretty.Compare(ident, r.wantIdent); diff != "" { t.Error(diff) } @@ -589,3 +614,458 @@ func TestVerifySignedMessageAndSignedAssertion(t *testing.T) { func TestVerifyUnsignedMessageAndUnsignedAssertion(t *testing.T) { runVerify(t, "testdata/idp-cert.pem", "testdata/idp-resp.xml", false) } + +func TestSAMLRefresh(t *testing.T) { + // Create a provider using the same pattern as existing tests. + c := Config{ + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + GroupsAttr: "groups", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + } + + conn, err := c.openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + t.Run("SuccessfulRefresh", func(t *testing.T) { + ci := cachedIdentity{ + UserID: "test-user-id", + Username: "testuser", + PreferredUsername: "testuser", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"group1", "group2"}, + } + connectorData, err := json.Marshal(ci) + if err != nil { + t.Fatal(err) + } + + ident := connector.Identity{ + UserID: "old-id", + Username: "old-name", + ConnectorData: connectorData, + } + + refreshed, err := conn.Refresh(context.Background(), connector.Scopes{Groups: true}, ident) + if err != nil { + t.Fatalf("Refresh failed: %v", err) + } + + if refreshed.UserID != "test-user-id" { + t.Errorf("expected UserID %q, got %q", "test-user-id", refreshed.UserID) + } + if refreshed.Username != "testuser" { + t.Errorf("expected Username %q, got %q", "testuser", refreshed.Username) + } + if refreshed.PreferredUsername != "testuser" { + t.Errorf("expected PreferredUsername %q, got %q", "testuser", refreshed.PreferredUsername) + } + if refreshed.Email != "test@example.com" { + t.Errorf("expected Email %q, got %q", "test@example.com", refreshed.Email) + } + if !refreshed.EmailVerified { + t.Error("expected EmailVerified to be true") + } + if len(refreshed.Groups) != 2 || refreshed.Groups[0] != "group1" || refreshed.Groups[1] != "group2" { + t.Errorf("expected groups [group1, group2], got %v", refreshed.Groups) + } + // ConnectorData should be preserved through refresh + if len(refreshed.ConnectorData) == 0 { + t.Error("expected ConnectorData to be preserved") + } + }) + + t.Run("RefreshPreservesConnectorData", func(t *testing.T) { + ci := cachedIdentity{ + UserID: "user-123", + Username: "alice", + Email: "alice@example.com", + EmailVerified: true, + } + connectorData, err := json.Marshal(ci) + if err != nil { + t.Fatal(err) + } + + ident := connector.Identity{ + UserID: "old-id", + ConnectorData: connectorData, + } + + refreshed, err := conn.Refresh(context.Background(), connector.Scopes{}, ident) + if err != nil { + t.Fatalf("Refresh failed: %v", err) + } + + // Verify the refreshed identity can be refreshed again (round-trip) + var roundTrip cachedIdentity + if err := json.Unmarshal(refreshed.ConnectorData, &roundTrip); err != nil { + t.Fatalf("failed to unmarshal ConnectorData after refresh: %v", err) + } + if roundTrip.UserID != "user-123" { + t.Errorf("round-trip UserID mismatch: got %q, want %q", roundTrip.UserID, "user-123") + } + }) + + t.Run("EmptyConnectorData", func(t *testing.T) { + ident := connector.Identity{ + UserID: "test-id", + ConnectorData: nil, + } + _, err := conn.Refresh(context.Background(), connector.Scopes{}, ident) + if err == nil { + t.Error("expected error for empty ConnectorData") + } + }) + + t.Run("InvalidJSON", func(t *testing.T) { + ident := connector.Identity{ + UserID: "test-id", + ConnectorData: []byte("not-json"), + } + _, err := conn.Refresh(context.Background(), connector.Scopes{}, ident) + if err == nil { + t.Error("expected error for invalid JSON") + } + }) + + t.Run("HandlePOSTThenRefresh", func(t *testing.T) { + // Full integration: HandlePOST → get ConnectorData → Refresh → verify identity + now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z") + if err != nil { + t.Fatal(err) + } + conn.now = func() time.Time { return now } + + resp, err := os.ReadFile("testdata/good-resp.xml") + if err != nil { + t.Fatal(err) + } + samlResp := base64.StdEncoding.EncodeToString(resp) + + scopes := connector.Scopes{ + OfflineAccess: true, + Groups: true, + } + ident, err := conn.HandlePOST(scopes, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m") + if err != nil { + t.Fatalf("HandlePOST failed: %v", err) + } + + if len(ident.ConnectorData) == 0 { + t.Fatal("expected ConnectorData to be set after HandlePOST") + } + + // Now refresh using the ConnectorData from HandlePOST + refreshed, err := conn.Refresh(context.Background(), scopes, ident) + if err != nil { + t.Fatalf("Refresh failed: %v", err) + } + + if refreshed.UserID != ident.UserID { + t.Errorf("UserID mismatch: got %q, want %q", refreshed.UserID, ident.UserID) + } + if refreshed.Username != ident.Username { + t.Errorf("Username mismatch: got %q, want %q", refreshed.Username, ident.Username) + } + if refreshed.Email != ident.Email { + t.Errorf("Email mismatch: got %q, want %q", refreshed.Email, ident.Email) + } + if refreshed.EmailVerified != ident.EmailVerified { + t.Errorf("EmailVerified mismatch: got %v, want %v", refreshed.EmailVerified, ident.EmailVerified) + } + sort.Strings(refreshed.Groups) + sort.Strings(ident.Groups) + if len(refreshed.Groups) != len(ident.Groups) { + t.Errorf("Groups length mismatch: got %d, want %d", len(refreshed.Groups), len(ident.Groups)) + } + for i := range ident.Groups { + if i < len(refreshed.Groups) && refreshed.Groups[i] != ident.Groups[i] { + t.Errorf("Groups[%d] mismatch: got %q, want %q", i, refreshed.Groups[i], ident.Groups[i]) + } + } + }) + + t.Run("HandlePOSTThenDoubleRefresh", func(t *testing.T) { + // Verify that refresh tokens can be chained: HandlePOST → Refresh → Refresh + now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z") + if err != nil { + t.Fatal(err) + } + conn.now = func() time.Time { return now } + + resp, err := os.ReadFile("testdata/good-resp.xml") + if err != nil { + t.Fatal(err) + } + samlResp := base64.StdEncoding.EncodeToString(resp) + + scopes := connector.Scopes{OfflineAccess: true, Groups: true} + ident, err := conn.HandlePOST(scopes, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m") + if err != nil { + t.Fatalf("HandlePOST failed: %v", err) + } + + // First refresh + refreshed1, err := conn.Refresh(context.Background(), scopes, ident) + if err != nil { + t.Fatalf("first Refresh failed: %v", err) + } + if len(refreshed1.ConnectorData) == 0 { + t.Fatal("expected ConnectorData after first refresh") + } + + // Second refresh using output of first refresh + refreshed2, err := conn.Refresh(context.Background(), scopes, refreshed1) + if err != nil { + t.Fatalf("second Refresh failed: %v", err) + } + + // All fields should match original + if refreshed2.UserID != ident.UserID { + t.Errorf("UserID mismatch after double refresh: got %q, want %q", refreshed2.UserID, ident.UserID) + } + if refreshed2.Email != ident.Email { + t.Errorf("Email mismatch after double refresh: got %q, want %q", refreshed2.Email, ident.Email) + } + if refreshed2.Username != ident.Username { + t.Errorf("Username mismatch after double refresh: got %q, want %q", refreshed2.Username, ident.Username) + } + }) + + t.Run("HandlePOSTWithAssertionSignedThenRefresh", func(t *testing.T) { + // Test with assertion-signed.xml (signature on assertion, not response) + now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z") + if err != nil { + t.Fatal(err) + } + conn.now = func() time.Time { return now } + + resp, err := os.ReadFile("testdata/assertion-signed.xml") + if err != nil { + t.Fatal(err) + } + samlResp := base64.StdEncoding.EncodeToString(resp) + + scopes := connector.Scopes{OfflineAccess: true, Groups: true} + ident, err := conn.HandlePOST(scopes, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m") + if err != nil { + t.Fatalf("HandlePOST with assertion-signed failed: %v", err) + } + + if len(ident.ConnectorData) == 0 { + t.Fatal("expected ConnectorData after HandlePOST with assertion-signed") + } + + refreshed, err := conn.Refresh(context.Background(), scopes, ident) + if err != nil { + t.Fatalf("Refresh after assertion-signed HandlePOST failed: %v", err) + } + + if refreshed.Email != ident.Email { + t.Errorf("Email mismatch: got %q, want %q", refreshed.Email, ident.Email) + } + if refreshed.Username != ident.Username { + t.Errorf("Username mismatch: got %q, want %q", refreshed.Username, ident.Username) + } + }) + + t.Run("HandlePOSTRefreshWithoutGroupsScope", func(t *testing.T) { + // Verify that groups are NOT returned when groups scope is not requested during refresh + now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z") + if err != nil { + t.Fatal(err) + } + conn.now = func() time.Time { return now } + + resp, err := os.ReadFile("testdata/good-resp.xml") + if err != nil { + t.Fatal(err) + } + samlResp := base64.StdEncoding.EncodeToString(resp) + + // Initial auth WITH groups + scopesWithGroups := connector.Scopes{OfflineAccess: true, Groups: true} + ident, err := conn.HandlePOST(scopesWithGroups, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m") + if err != nil { + t.Fatalf("HandlePOST failed: %v", err) + } + if len(ident.Groups) == 0 { + t.Fatal("expected groups in initial identity") + } + + // Refresh WITHOUT groups scope + scopesNoGroups := connector.Scopes{OfflineAccess: true, Groups: false} + refreshed, err := conn.Refresh(context.Background(), scopesNoGroups, ident) + if err != nil { + t.Fatalf("Refresh failed: %v", err) + } + + if len(refreshed.Groups) != 0 { + t.Errorf("expected no groups when groups scope not requested, got %v", refreshed.Groups) + } + + // Refresh WITH groups scope — groups should be back + refreshedWithGroups, err := conn.Refresh(context.Background(), scopesWithGroups, ident) + if err != nil { + t.Fatalf("Refresh with groups failed: %v", err) + } + + if len(refreshedWithGroups.Groups) == 0 { + t.Error("expected groups when groups scope is requested") + } + }) +} + +func TestSAMLHandleSLO(t *testing.T) { + c := Config{ + CA: "testdata/ca.crt", + UsernameAttr: "Name", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + InsecureSkipSLOSignatureValidation: true, + } + + conn, err := c.openConnector(slog.New(slog.DiscardHandler)) + if err != nil { + t.Fatal(err) + } + + // Helper to create a LogoutRequest XML + makeLogoutRequest := func(nameID string) string { + return fmt.Sprintf(` + https://idp.example.com + %s +`, nameID) + } + + t.Run("ValidLogoutRequest", func(t *testing.T) { + logoutXML := makeLogoutRequest("user@example.com") + encoded := base64.StdEncoding.EncodeToString([]byte(logoutXML)) + + form := url.Values{} + form.Set("SAMLRequest", encoded) + + req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + + nameID, err := conn.HandleSLO(w, req) + if err != nil { + t.Fatalf("HandleSLO failed: %v", err) + } + if nameID != "user@example.com" { + t.Errorf("expected nameID %q, got %q", "user@example.com", nameID) + } + }) + + t.Run("MissingSAMLRequest", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader("")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + + _, err := conn.HandleSLO(w, req) + if err == nil { + t.Error("expected error for missing SAMLRequest") + } + }) + + t.Run("InvalidBase64", func(t *testing.T) { + form := url.Values{} + form.Set("SAMLRequest", "not-valid-base64!!!") + + req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + + _, err := conn.HandleSLO(w, req) + if err == nil { + t.Error("expected error for invalid base64") + } + }) + + t.Run("InvalidXML", func(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("not xml at all")) + form := url.Values{} + form.Set("SAMLRequest", encoded) + + req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + + _, err := conn.HandleSLO(w, req) + if err == nil { + t.Error("expected error for invalid XML") + } + }) + + t.Run("MissingNameID", func(t *testing.T) { + logoutXML := ` + https://idp.example.com + +` + encoded := base64.StdEncoding.EncodeToString([]byte(logoutXML)) + + form := url.Values{} + form.Set("SAMLRequest", encoded) + + req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + + _, err := conn.HandleSLO(w, req) + if err == nil { + t.Error("expected error for missing NameID") + } + }) + + t.Run("WrongHTTPMethod", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/saml/slo/test", nil) + w := httptest.NewRecorder() + + _, err := conn.HandleSLO(w, req) + if err == nil { + t.Error("expected error for GET method") + } + }) + + t.Run("DifferentNameIDValues", func(t *testing.T) { + testCases := []struct { + name string + nameIDVal string + wantNameID string + }{ + {"email format", "admin@corp.example.com", "admin@corp.example.com"}, + {"persistent ID", "AQIC5w...", "AQIC5w..."}, + {"transient ID", "_ce3d2948b4cf20146dee0a0b3dd6f69b6cf86f62d7", "_ce3d2948b4cf20146dee0a0b3dd6f69b6cf86f62d7"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + logoutXML := makeLogoutRequest(tc.nameIDVal) + encoded := base64.StdEncoding.EncodeToString([]byte(logoutXML)) + + form := url.Values{} + form.Set("SAMLRequest", encoded) + + req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + + nameID, err := conn.HandleSLO(w, req) + if err != nil { + t.Fatalf("HandleSLO failed: %v", err) + } + if nameID != tc.wantNameID { + t.Errorf("expected nameID %q, got %q", tc.wantNameID, nameID) + } + }) + } + }) +} diff --git a/connector/saml/types.go b/connector/saml/types.go index c8d7e7f3..d63bcbe0 100644 --- a/connector/saml/types.go +++ b/connector/saml/types.go @@ -275,3 +275,21 @@ func (a attribute) String() string { // "groups" = ["engineering", "docs"] return fmt.Sprintf("%q = %q", a.Name, values) } + +// logoutRequest represents a SAML 2.0 LogoutRequest message sent by the IdP +// to initiate Single Logout. +type logoutRequest struct { + XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol LogoutRequest"` + ID string `xml:"ID,attr"` + Version string `xml:"Version,attr"` + IssueInstant xmlTime `xml:"IssueInstant,attr"` + Destination string `xml:"Destination,attr,omitempty"` + Issuer string `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"` + NameID nameID `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"` + SessionIndex []sessionIndex `xml:"SessionIndex,omitempty"` +} + +// sessionIndex represents a SAML SessionIndex element in a LogoutRequest. +type sessionIndex struct { + Value string `xml:",chardata"` +} diff --git a/server/handlers.go b/server/handlers.go index 1292b8ee..0bd73de1 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1054,9 +1054,10 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au return nil, err } offlineSessions := storage.OfflineSessions{ - UserID: refresh.Claims.UserID, - ConnID: refresh.ConnectorID, - Refresh: make(map[string]*storage.RefreshTokenRef), + UserID: refresh.Claims.UserID, + ConnID: refresh.ConnectorID, + Refresh: make(map[string]*storage.RefreshTokenRef), + ConnectorData: refresh.ConnectorData, } offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef @@ -1082,6 +1083,9 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au // Update existing OfflineSession obj with new RefreshTokenRef. if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { old.Refresh[tokenRef.ClientID] = &tokenRef + if len(refresh.ConnectorData) > 0 { + old.ConnectorData = refresh.ConnectorData + } return old, nil }); err != nil { s.logger.ErrorContext(ctx, "failed to update offline session", "err", err) @@ -1500,3 +1504,84 @@ func usernamePrompt(conn connector.PasswordConnector) string { } return "Username" } + +func (s *Server) handleSAMLSLO(w http.ResponseWriter, r *http.Request) { + connID, err := url.PathUnescape(mux.Vars(r)["connector"]) + if err != nil { + s.logger.ErrorContext(r.Context(), "SAML SLO: failed to parse connector ID", "err", err) + http.Error(w, "Missing connector ID", http.StatusBadRequest) + return + } + + ctx := r.Context() + + conn, err := s.getConnector(ctx, connID) + if err != nil { + s.logger.ErrorContext(ctx, "SAML SLO: failed to get connector", "connector_id", connID, "err", err) + http.Error(w, "Connector not found", http.StatusNotFound) + return + } + + sloConnector, ok := conn.Connector.(connector.SAMLSLOConnector) + if !ok { + s.logger.ErrorContext(ctx, "SAML SLO: connector does not support SLO", "connector_id", connID) + http.Error(w, "Connector does not support SLO", http.StatusBadRequest) + return + } + + nameID, err := sloConnector.HandleSLO(w, r) + if err != nil { + s.logger.ErrorContext(ctx, "SAML SLO: failed to process LogoutRequest", "connector_id", connID, "err", err) + http.Error(w, "Failed to process LogoutRequest", http.StatusBadRequest) + return + } + + s.logger.InfoContext(ctx, "SAML SLO: processing logout", "connector_id", connID, "name_id", nameID) + + if err := s.invalidateUserSessions(ctx, connID, nameID); err != nil { + s.logger.ErrorContext(ctx, "SAML SLO: failed to invalidate sessions", "connector_id", connID, "name_id", nameID, "err", err) + http.Error(w, "Failed to invalidate sessions", http.StatusInternalServerError) + return + } + + s.logger.InfoContext(ctx, "SAML SLO: successfully invalidated sessions", "connector_id", connID, "name_id", nameID) + w.WriteHeader(http.StatusOK) +} + +// invalidateUserSessions removes all refresh tokens for a user identified by +// their NameID from a specific connector. It uses OfflineSessions to find +// the user's refresh tokens and deletes them. +func (s *Server) invalidateUserSessions(ctx context.Context, connectorID, nameID string) error { + // NameID from SAML is used as the user ID in the connector. + // OfflineSessions are keyed by (userID, connectorID). + // The userID in OfflineSessions is the connector-specific user ID (NameID for SAML). + + offlineSessions, err := s.storage.GetOfflineSessions(ctx, nameID, connectorID) + if err != nil { + if err == storage.ErrNotFound { + s.logger.InfoContext(ctx, "SAML SLO: no offline sessions found", "name_id", nameID, "connector_id", connectorID) + return nil // No sessions to invalidate + } + return fmt.Errorf("failed to get offline sessions: %v", err) + } + + // Delete all refresh tokens associated with this offline session + for _, tokenRef := range offlineSessions.Refresh { + if err := s.storage.DeleteRefresh(ctx, tokenRef.ID); err != nil { + if err == storage.ErrNotFound { + continue // Token already deleted + } + s.logger.ErrorContext(ctx, "SAML SLO: failed to delete refresh token", "token_id", tokenRef.ID, "err", err) + // Continue deleting other tokens even if one fails + } + } + + // Delete the offline session itself + if err := s.storage.DeleteOfflineSessions(ctx, nameID, connectorID); err != nil { + if err != storage.ErrNotFound { + return fmt.Errorf("failed to delete offline sessions: %v", err) + } + } + + return nil +} diff --git a/server/handlers_test.go b/server/handlers_test.go index 0514d85c..b50e2fe8 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -21,6 +21,8 @@ import ( "golang.org/x/crypto/bcrypt" "golang.org/x/oauth2" + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" ) @@ -892,3 +894,461 @@ func setNonEmpty(vals url.Values, key, value string) { vals.Set(key, value) } } + +// mockSAMLSLOConnector implements connector.SAMLConnector and connector.SAMLSLOConnector for testing. +type mockSAMLSLOConnector struct { + sloNameID string + sloErr error +} + +func (m *mockSAMLSLOConnector) POSTData(s connector.Scopes, requestID string) (ssoURL, samlRequest string, err error) { + return "", "", nil +} + +func (m *mockSAMLSLOConnector) HandlePOST(s connector.Scopes, samlResponse, inResponseTo string) (connector.Identity, error) { + return connector.Identity{}, nil +} + +func (m *mockSAMLSLOConnector) HandleSLO(w http.ResponseWriter, r *http.Request) (string, error) { + return m.sloNameID, m.sloErr +} + +// mockNonSLOConnector implements connector.CallbackConnector but NOT SAMLSLOConnector. +type mockNonSLOConnector struct{} + +func (m *mockNonSLOConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { + return "", nil +} + +func (m *mockNonSLOConnector) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) { + return connector.Identity{}, nil +} + +// registerTestConnector creates a connector in storage and registers it in the server's connectors map. +func registerTestConnector(t *testing.T, s *Server, connID string, c connector.Connector) { + t.Helper() + ctx := t.Context() + + storageConn := storage.Connector{ + ID: connID, + Type: "saml", + Name: "Test SAML", + ResourceVersion: "1", + } + if err := s.storage.CreateConnector(ctx, storageConn); err != nil { + t.Fatalf("failed to create connector in storage: %v", err) + } + + s.mu.Lock() + s.connectors[connID] = Connector{ + ResourceVersion: "1", + Connector: c, + } + s.mu.Unlock() +} + +func TestHandleSAMLSLO(t *testing.T) { + t.Run("SuccessfulSLO", func(t *testing.T) { + httpServer, server := newTestServer(t, nil) + defer httpServer.Close() + + ctx := t.Context() + connID := "saml-slo-test" + nameID := "user@example.com" + + mockConn := &mockSAMLSLOConnector{ + sloNameID: nameID, + } + registerTestConnector(t, server, connID, mockConn) + + // Create refresh tokens and offline session + refreshToken1 := storage.RefreshToken{ + ID: "refresh-token-1", + Token: "token-1", + CreatedAt: time.Now(), + LastUsed: time.Now(), + ClientID: "test-client", + ConnectorID: connID, + Claims: storage.Claims{ + UserID: nameID, + Username: "testuser", + Email: nameID, + }, + Nonce: "nonce-1", + } + refreshToken2 := storage.RefreshToken{ + ID: "refresh-token-2", + Token: "token-2", + CreatedAt: time.Now(), + LastUsed: time.Now(), + ClientID: "test-client-2", + ConnectorID: connID, + Claims: storage.Claims{ + UserID: nameID, + Username: "testuser", + Email: nameID, + }, + Nonce: "nonce-2", + } + require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken1)) + require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken2)) + + offlineSession := storage.OfflineSessions{ + UserID: nameID, + ConnID: connID, + Refresh: map[string]*storage.RefreshTokenRef{ + refreshToken1.ClientID: { + ID: refreshToken1.ID, + ClientID: refreshToken1.ClientID, + CreatedAt: refreshToken1.CreatedAt, + LastUsed: refreshToken1.LastUsed, + }, + refreshToken2.ClientID: { + ID: refreshToken2.ID, + ClientID: refreshToken2.ClientID, + CreatedAt: refreshToken2.CreatedAt, + LastUsed: refreshToken2.LastUsed, + }, + }, + } + require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession)) + + // Send POST to /saml/slo/{connector} + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) + server.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code, "expected HTTP 200, got %d: %s", rr.Code, rr.Body.String()) + + // Verify refresh tokens are deleted + _, err := server.storage.GetRefresh(ctx, refreshToken1.ID) + require.ErrorIs(t, err, storage.ErrNotFound, "refresh token 1 should be deleted") + + _, err = server.storage.GetRefresh(ctx, refreshToken2.ID) + require.ErrorIs(t, err, storage.ErrNotFound, "refresh token 2 should be deleted") + + // Verify offline session is deleted + _, err = server.storage.GetOfflineSessions(ctx, nameID, connID) + require.ErrorIs(t, err, storage.ErrNotFound, "offline session should be deleted") + }) + + t.Run("NoExistingSessions", func(t *testing.T) { + httpServer, server := newTestServer(t, nil) + defer httpServer.Close() + + connID := "saml-slo-nosession" + nameID := "nouser@example.com" + + mockConn := &mockSAMLSLOConnector{ + sloNameID: nameID, + } + registerTestConnector(t, server, connID, mockConn) + + // No refresh tokens or offline sessions created — should handle gracefully + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) + server.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code, "expected HTTP 200 for no sessions, got %d: %s", rr.Code, rr.Body.String()) + }) + + t.Run("ConnectorNotFound", func(t *testing.T) { + httpServer, server := newTestServer(t, nil) + defer httpServer.Close() + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/saml/slo/nonexistent-connector", nil) + server.ServeHTTP(rr, req) + + require.Equal(t, http.StatusNotFound, rr.Code, "expected HTTP 404 for missing connector") + }) + + t.Run("ConnectorDoesNotSupportSLO", func(t *testing.T) { + httpServer, server := newTestServer(t, nil) + defer httpServer.Close() + + connID := "saml-no-slo" + mockConn := &mockNonSLOConnector{} + registerTestConnector(t, server, connID, mockConn) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) + server.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code, "expected HTTP 400 for connector without SLO support") + }) + + t.Run("HandleSLOError", func(t *testing.T) { + httpServer, server := newTestServer(t, nil) + defer httpServer.Close() + + connID := "saml-slo-error" + mockConn := &mockSAMLSLOConnector{ + sloNameID: "", + sloErr: errors.New("invalid SAMLRequest"), + } + registerTestConnector(t, server, connID, mockConn) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) + server.ServeHTTP(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code, "expected HTTP 400 for invalid SAMLRequest") + }) + + t.Run("MultipleTokensPartialDeletion", func(t *testing.T) { + httpServer, server := newTestServer(t, nil) + defer httpServer.Close() + + ctx := t.Context() + connID := "saml-slo-partial" + nameID := "partial@example.com" + + mockConn := &mockSAMLSLOConnector{ + sloNameID: nameID, + } + registerTestConnector(t, server, connID, mockConn) + + // Create only one refresh token but reference two in offline session + // (simulating a token that was already deleted) + refreshToken := storage.RefreshToken{ + ID: "existing-token", + Token: "token-existing", + CreatedAt: time.Now(), + LastUsed: time.Now(), + ClientID: "client-existing", + ConnectorID: connID, + Claims: storage.Claims{ + UserID: nameID, + Username: "partialuser", + Email: nameID, + }, + Nonce: "nonce-existing", + } + require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken)) + + offlineSession := storage.OfflineSessions{ + UserID: nameID, + ConnID: connID, + Refresh: map[string]*storage.RefreshTokenRef{ + "client-existing": { + ID: "existing-token", + ClientID: "client-existing", + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + "client-missing": { + ID: "already-deleted-token", + ClientID: "client-missing", + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + }, + } + require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession)) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) + server.ServeHTTP(rr, req) + + require.Equal(t, http.StatusOK, rr.Code, "expected HTTP 200 even with partial deletion") + + // Verify existing token is deleted + _, err := server.storage.GetRefresh(ctx, "existing-token") + require.ErrorIs(t, err, storage.ErrNotFound, "existing refresh token should be deleted") + + // Verify offline session is deleted + _, err = server.storage.GetOfflineSessions(ctx, nameID, connID) + require.ErrorIs(t, err, storage.ErrNotFound, "offline session should be deleted") + }) + + t.Run("RefreshAfterSLO", func(t *testing.T) { + // Test that refresh token is rejected after SLO invalidation. + // This is the key integration test: SLO → refresh → error. + httpServer, server := newTestServer(t, func(c *Config) { + c.RefreshTokenPolicy = &RefreshTokenPolicy{rotateRefreshTokens: true} + }) + defer httpServer.Close() + + ctx := t.Context() + connID := "saml-slo-refresh" + nameID := "slo-refresh@example.com" + + mockConn := &mockSAMLSLOConnector{ + sloNameID: nameID, + } + registerTestConnector(t, server, connID, mockConn) + + // Create client for refresh token request + client := storage.Client{ + ID: "slo-test-client", + Secret: "slo-test-secret", + RedirectURIs: []string{"https://example.com/callback"}, + Name: "SLO Test Client", + } + require.NoError(t, server.storage.CreateClient(ctx, client)) + + // Create refresh token + refreshToken := storage.RefreshToken{ + ID: "slo-refresh-token", + Token: "slo-token-value", + CreatedAt: time.Now(), + LastUsed: time.Now(), + ClientID: client.ID, + ConnectorID: connID, + Scopes: []string{"openid", "email", "offline_access"}, + Claims: storage.Claims{ + UserID: nameID, + Username: "slouser", + Email: nameID, + EmailVerified: true, + }, + ConnectorData: []byte(`{"userID":"` + nameID + `","username":"slouser","email":"` + nameID + `","emailVerified":true}`), + Nonce: "slo-nonce", + } + require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken)) + + offlineSession := storage.OfflineSessions{ + UserID: nameID, + ConnID: connID, + Refresh: map[string]*storage.RefreshTokenRef{ + client.ID: { + ID: refreshToken.ID, + ClientID: client.ID, + CreatedAt: refreshToken.CreatedAt, + LastUsed: refreshToken.LastUsed, + }, + }, + } + require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession)) + + // Step 1: Verify refresh token exists before SLO + _, err := server.storage.GetRefresh(ctx, refreshToken.ID) + require.NoError(t, err, "refresh token should exist before SLO") + + // Step 2: Send SLO request + sloRR := httptest.NewRecorder() + sloReq := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil) + server.ServeHTTP(sloRR, sloReq) + require.Equal(t, http.StatusOK, sloRR.Code, "SLO should succeed") + + // Step 3: Verify refresh token is deleted + _, err = server.storage.GetRefresh(ctx, refreshToken.ID) + require.ErrorIs(t, err, storage.ErrNotFound, "refresh token should be deleted after SLO") + + // Step 4: Attempt to use the refresh token — should fail + tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: refreshToken.ID, Token: refreshToken.Token}) + require.NoError(t, err) + + u, err := url.Parse(server.issuerURL.String()) + require.NoError(t, err) + u.Path = path.Join(u.Path, "/token") + + v := url.Values{} + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", tokenData) + + refreshReq, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode())) + refreshReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + refreshReq.SetBasicAuth(client.ID, client.Secret) + + refreshRR := httptest.NewRecorder() + server.ServeHTTP(refreshRR, refreshReq) + + // Refresh should fail because the token was deleted by SLO + require.NotEqual(t, http.StatusOK, refreshRR.Code, + "refresh should fail after SLO, got status %d: %s", refreshRR.Code, refreshRR.Body.String()) + }) + + t.Run("ConnectorDataPersistence", func(t *testing.T) { + // Test that ConnectorData is correctly stored in refresh token + // and can be used for subsequent refresh operations. + httpServer, server := newTestServer(t, func(c *Config) { + c.RefreshTokenPolicy = &RefreshTokenPolicy{rotateRefreshTokens: true} + }) + defer httpServer.Close() + + ctx := t.Context() + connID := "saml-conndata" + + // Create a mock SAML connector that also implements RefreshConnector + mockConn := &mockSAMLRefreshConnector{ + refreshIdentity: connector.Identity{ + UserID: "refreshed-user", + Username: "refreshed-name", + Email: "refreshed@example.com", + EmailVerified: true, + Groups: []string{"refreshed-group"}, + }, + } + registerTestConnector(t, server, connID, mockConn) + + // Create client + client := storage.Client{ + ID: "conndata-client", + Secret: "conndata-secret", + RedirectURIs: []string{"https://example.com/callback"}, + Name: "ConnData Test Client", + } + require.NoError(t, server.storage.CreateClient(ctx, client)) + + // Create refresh token with ConnectorData (simulating what HandlePOST would store) + connectorData := []byte(`{"userID":"user-123","username":"testuser","email":"test@example.com","emailVerified":true,"groups":["admin","dev"]}`) + refreshToken := storage.RefreshToken{ + ID: "conndata-refresh", + Token: "conndata-token", + CreatedAt: time.Now(), + LastUsed: time.Now(), + ClientID: client.ID, + ConnectorID: connID, + Scopes: []string{"openid", "email", "offline_access"}, + Claims: storage.Claims{ + UserID: "user-123", + Username: "testuser", + Email: "test@example.com", + EmailVerified: true, + Groups: []string{"admin", "dev"}, + }, + ConnectorData: connectorData, + Nonce: "conndata-nonce", + } + require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken)) + + offlineSession := storage.OfflineSessions{ + UserID: "user-123", + ConnID: connID, + Refresh: map[string]*storage.RefreshTokenRef{client.ID: {ID: refreshToken.ID, ClientID: client.ID}}, + ConnectorData: connectorData, + } + require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession)) + + // Verify ConnectorData is stored correctly + storedToken, err := server.storage.GetRefresh(ctx, refreshToken.ID) + require.NoError(t, err) + require.Equal(t, connectorData, storedToken.ConnectorData, + "ConnectorData should be persisted in refresh token storage") + + // Verify ConnectorData is stored in offline session + storedSession, err := server.storage.GetOfflineSessions(ctx, "user-123", connID) + require.NoError(t, err) + require.Equal(t, connectorData, storedSession.ConnectorData, + "ConnectorData should be persisted in offline session storage") + }) +} + +// mockSAMLRefreshConnector implements SAMLConnector + RefreshConnector for testing. +type mockSAMLRefreshConnector struct { + refreshIdentity connector.Identity +} + +func (m *mockSAMLRefreshConnector) POSTData(s connector.Scopes, requestID string) (ssoURL, samlRequest string, err error) { + return "", "", nil +} + +func (m *mockSAMLRefreshConnector) HandlePOST(s connector.Scopes, samlResponse, inResponseTo string) (connector.Identity, error) { + return connector.Identity{}, nil +} + +func (m *mockSAMLRefreshConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) { + return m.refreshIdentity, nil +} diff --git a/server/server.go b/server/server.go index e923e3e0..e76611a5 100644 --- a/server/server.go +++ b/server/server.go @@ -494,6 +494,7 @@ func newServer(ctx context.Context, c Config) (*Server, error) { // For easier connector-specific web server configuration, e.g. for the // "authproxy" connector. handleFunc("/callback/{connector}", s.handleConnectorCallback) + handleFunc("/saml/slo/{connector}", s.handleSAMLSLO) handleFunc("/approval", s.handleApproval) handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !c.HealthChecker.IsHealthy() {