Browse Source

feat(saml): support refresh tokens and SLO

Signed-off-by: Ivan Zvyagintsev <ivan.zvyagintsev@flant.com>
pull/4565/head
Ivan Zvyagintsev 4 weeks ago
parent
commit
c9b8a48c15
  1. 10
      connector/connector.go
  2. 142
      connector/saml/saml.go
  3. 480
      connector/saml/saml_test.go
  4. 18
      connector/saml/types.go
  5. 91
      server/handlers.go
  6. 460
      server/handlers_test.go
  7. 1
      server/server.go

10
connector/connector.go

@ -92,6 +92,16 @@ type SAMLConnector interface {
HandlePOST(s Scopes, samlResponse, inResponseTo string) (identity Identity, err error)
}
// SAMLSLOConnector represents a SAML connector that supports Single Logout (SLO).
// When an IdP sends a LogoutRequest, the connector parses it and returns the
// NameID of the user being logged out.
type SAMLSLOConnector interface {
// HandleSLO processes a SAML LogoutRequest from the IdP.
// It validates the request signature and returns the NameID (user identifier)
// that should be used to invalidate the user's sessions.
HandleSLO(w http.ResponseWriter, r *http.Request) (nameID string, err error)
}
// RefreshConnector is a connector that can update the client claims.
type RefreshConnector interface {
// Refresh is called when a client attempts to claim a refresh token. The

142
connector/saml/saml.go

@ -3,12 +3,15 @@ package saml
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"encoding/xml"
"fmt"
"log/slog"
"net/http"
"os"
"strings"
"sync"
@ -82,6 +85,11 @@ type Config struct {
InsecureSkipSignatureValidation bool `json:"insecureSkipSignatureValidation"`
// InsecureSkipSLOSignatureValidation skips signature validation on SLO requests.
// This is insecure and should only be used for testing or when the IdP
// does not sign LogoutRequests.
InsecureSkipSLOSignatureValidation bool `json:"insecureSkipSLOSignatureValidation"`
// Assertion attribute names to lookup various claims with.
UsernameAttr string `json:"usernameAttr"`
EmailAttr string `json:"emailAttr"`
@ -162,6 +170,8 @@ func (c *Config) openConnector(logger *slog.Logger) (*provider, error) {
logger: logger,
nameIDPolicyFormat: c.NameIDPolicyFormat,
insecureSkipSLOSignatureValidation: c.InsecureSkipSLOSignatureValidation,
}
if p.nameIDPolicyFormat == "" {
@ -252,9 +262,47 @@ type provider struct {
nameIDPolicyFormat string
insecureSkipSLOSignatureValidation bool
logger *slog.Logger
}
// Compile-time check that provider implements RefreshConnector
var _ connector.RefreshConnector = (*provider)(nil)
// Compile-time check that provider implements SAMLSLOConnector
var _ connector.SAMLSLOConnector = (*provider)(nil)
// cachedIdentity stores the identity from SAML assertion for refresh token support.
// Since SAML has no native refresh mechanism, we cache the identity obtained during
// the initial authentication and return it on subsequent refresh requests.
type cachedIdentity struct {
UserID string `json:"userId"`
Username string `json:"username"`
PreferredUsername string `json:"preferredUsername"`
Email string `json:"email"`
EmailVerified bool `json:"emailVerified"`
Groups []string `json:"groups,omitempty"`
}
// marshalCachedIdentity serializes the identity into ConnectorData for refresh token support.
func marshalCachedIdentity(ident connector.Identity) (connector.Identity, error) {
ci := cachedIdentity{
UserID: ident.UserID,
Username: ident.Username,
PreferredUsername: ident.PreferredUsername,
Email: ident.Email,
EmailVerified: ident.EmailVerified,
Groups: ident.Groups,
}
connectorData, err := json.Marshal(ci)
if err != nil {
return ident, fmt.Errorf("saml: failed to marshal cached identity: %v", err)
}
ident.ConnectorData = connectorData
return ident, nil
}
func (p *provider) POSTData(s connector.Scopes, id string) (action, value string, err error) {
r := &authnRequest{
ProtocolBinding: bindingPOST,
@ -405,7 +453,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
if len(p.allowedGroups) == 0 && (!s.Groups || p.groupsAttr == "") {
// Groups not requested or not configured. We're done.
return ident, nil
return marshalCachedIdentity(ident)
}
if len(p.allowedGroups) > 0 && (!s.Groups || p.groupsAttr == "") {
@ -431,7 +479,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
if len(p.allowedGroups) == 0 {
// No allowed groups set, just return the ident
return ident, nil
return marshalCachedIdentity(ident)
}
// Look for membership in one of the allowed groups
@ -447,6 +495,35 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
}
// Otherwise, we're good
return marshalCachedIdentity(ident)
}
// Refresh implements connector.RefreshConnector.
// Since SAML has no native refresh mechanism, this method returns the cached
// identity from the initial SAML assertion stored in ConnectorData.
func (p *provider) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
if len(ident.ConnectorData) == 0 {
return ident, fmt.Errorf("saml: no connector data available for refresh")
}
var ci cachedIdentity
if err := json.Unmarshal(ident.ConnectorData, &ci); err != nil {
return ident, fmt.Errorf("saml: failed to unmarshal cached identity: %v", err)
}
ident.UserID = ci.UserID
ident.Username = ci.Username
ident.PreferredUsername = ci.PreferredUsername
ident.Email = ci.Email
ident.EmailVerified = ci.EmailVerified
// Only populate groups if the client requested the groups scope.
if s.Groups {
ident.Groups = ci.Groups
} else {
ident.Groups = nil
}
return ident, nil
}
@ -645,3 +722,64 @@ func before(now, notBefore time.Time) bool {
func after(now, notOnOrAfter time.Time) bool {
return now.After(notOnOrAfter.Add(allowedClockDrift))
}
// validateSignature validates the XML digital signature of the given raw XML data.
func (p *provider) validateSignature(rawXML []byte) ([]byte, error) {
doc := etree.NewDocument()
if err := doc.ReadFromBytes(rawXML); err != nil {
return nil, fmt.Errorf("failed to parse XML: %v", err)
}
// Find the Signature element
root := doc.Root()
if root == nil {
return nil, fmt.Errorf("empty XML document")
}
_, err := p.validator.Validate(root)
if err != nil {
return nil, fmt.Errorf("signature validation failed: %v", err)
}
return rawXML, nil
}
// HandleSLO processes a SAML LogoutRequest from the IdP.
// It validates the request, extracts the NameID, and returns it for session invalidation.
func (p *provider) HandleSLO(w http.ResponseWriter, r *http.Request) (string, error) {
if r.Method != http.MethodPost {
return "", fmt.Errorf("saml slo: expected POST method, got %s", r.Method)
}
if err := r.ParseForm(); err != nil {
return "", fmt.Errorf("saml slo: failed to parse form: %v", err)
}
samlRequest := r.FormValue("SAMLRequest")
if samlRequest == "" {
return "", fmt.Errorf("saml slo: missing SAMLRequest parameter")
}
rawRequest, err := base64.StdEncoding.DecodeString(samlRequest)
if err != nil {
return "", fmt.Errorf("saml slo: failed to decode SAMLRequest: %v", err)
}
// Validate signature unless explicitly skipped
if !p.insecureSkipSLOSignatureValidation {
if _, err := p.validateSignature(rawRequest); err != nil {
return "", fmt.Errorf("saml slo: signature validation failed: %v", err)
}
}
var req logoutRequest
if err := xml.Unmarshal(rawRequest, &req); err != nil {
return "", fmt.Errorf("saml slo: failed to unmarshal LogoutRequest: %v", err)
}
if req.NameID.Value == "" {
return "", fmt.Errorf("saml slo: LogoutRequest missing NameID")
}
return req.NameID.Value, nil
}

480
connector/saml/saml_test.go

@ -1,13 +1,20 @@
package saml
import (
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"log/slog"
"net/http"
"net/http/httptest"
"net/url"
"os"
"sort"
"strings"
"testing"
"time"
@ -448,6 +455,24 @@ func (r responseTest) run(t *testing.T) {
}
sort.Strings(ident.Groups)
sort.Strings(r.wantIdent.Groups)
// Verify ConnectorData contains valid cached identity, then clear it
// for the main identity comparison (ConnectorData is an implementation
// detail of refresh token support).
if len(ident.ConnectorData) > 0 {
var ci cachedIdentity
if err := json.Unmarshal(ident.ConnectorData, &ci); err != nil {
t.Fatalf("failed to unmarshal ConnectorData: %v", err)
}
if ci.UserID != ident.UserID {
t.Errorf("cached identity UserID mismatch: got %q, want %q", ci.UserID, ident.UserID)
}
if ci.Email != ident.Email {
t.Errorf("cached identity Email mismatch: got %q, want %q", ci.Email, ident.Email)
}
}
ident.ConnectorData = nil
if diff := pretty.Compare(ident, r.wantIdent); diff != "" {
t.Error(diff)
}
@ -589,3 +614,458 @@ func TestVerifySignedMessageAndSignedAssertion(t *testing.T) {
func TestVerifyUnsignedMessageAndUnsignedAssertion(t *testing.T) {
runVerify(t, "testdata/idp-cert.pem", "testdata/idp-resp.xml", false)
}
func TestSAMLRefresh(t *testing.T) {
// Create a provider using the same pattern as existing tests.
c := Config{
CA: "testdata/ca.crt",
UsernameAttr: "Name",
EmailAttr: "email",
GroupsAttr: "groups",
RedirectURI: "http://127.0.0.1:5556/dex/callback",
SSOURL: "http://foo.bar/",
}
conn, err := c.openConnector(slog.New(slog.DiscardHandler))
if err != nil {
t.Fatal(err)
}
t.Run("SuccessfulRefresh", func(t *testing.T) {
ci := cachedIdentity{
UserID: "test-user-id",
Username: "testuser",
PreferredUsername: "testuser",
Email: "test@example.com",
EmailVerified: true,
Groups: []string{"group1", "group2"},
}
connectorData, err := json.Marshal(ci)
if err != nil {
t.Fatal(err)
}
ident := connector.Identity{
UserID: "old-id",
Username: "old-name",
ConnectorData: connectorData,
}
refreshed, err := conn.Refresh(context.Background(), connector.Scopes{Groups: true}, ident)
if err != nil {
t.Fatalf("Refresh failed: %v", err)
}
if refreshed.UserID != "test-user-id" {
t.Errorf("expected UserID %q, got %q", "test-user-id", refreshed.UserID)
}
if refreshed.Username != "testuser" {
t.Errorf("expected Username %q, got %q", "testuser", refreshed.Username)
}
if refreshed.PreferredUsername != "testuser" {
t.Errorf("expected PreferredUsername %q, got %q", "testuser", refreshed.PreferredUsername)
}
if refreshed.Email != "test@example.com" {
t.Errorf("expected Email %q, got %q", "test@example.com", refreshed.Email)
}
if !refreshed.EmailVerified {
t.Error("expected EmailVerified to be true")
}
if len(refreshed.Groups) != 2 || refreshed.Groups[0] != "group1" || refreshed.Groups[1] != "group2" {
t.Errorf("expected groups [group1, group2], got %v", refreshed.Groups)
}
// ConnectorData should be preserved through refresh
if len(refreshed.ConnectorData) == 0 {
t.Error("expected ConnectorData to be preserved")
}
})
t.Run("RefreshPreservesConnectorData", func(t *testing.T) {
ci := cachedIdentity{
UserID: "user-123",
Username: "alice",
Email: "alice@example.com",
EmailVerified: true,
}
connectorData, err := json.Marshal(ci)
if err != nil {
t.Fatal(err)
}
ident := connector.Identity{
UserID: "old-id",
ConnectorData: connectorData,
}
refreshed, err := conn.Refresh(context.Background(), connector.Scopes{}, ident)
if err != nil {
t.Fatalf("Refresh failed: %v", err)
}
// Verify the refreshed identity can be refreshed again (round-trip)
var roundTrip cachedIdentity
if err := json.Unmarshal(refreshed.ConnectorData, &roundTrip); err != nil {
t.Fatalf("failed to unmarshal ConnectorData after refresh: %v", err)
}
if roundTrip.UserID != "user-123" {
t.Errorf("round-trip UserID mismatch: got %q, want %q", roundTrip.UserID, "user-123")
}
})
t.Run("EmptyConnectorData", func(t *testing.T) {
ident := connector.Identity{
UserID: "test-id",
ConnectorData: nil,
}
_, err := conn.Refresh(context.Background(), connector.Scopes{}, ident)
if err == nil {
t.Error("expected error for empty ConnectorData")
}
})
t.Run("InvalidJSON", func(t *testing.T) {
ident := connector.Identity{
UserID: "test-id",
ConnectorData: []byte("not-json"),
}
_, err := conn.Refresh(context.Background(), connector.Scopes{}, ident)
if err == nil {
t.Error("expected error for invalid JSON")
}
})
t.Run("HandlePOSTThenRefresh", func(t *testing.T) {
// Full integration: HandlePOST → get ConnectorData → Refresh → verify identity
now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z")
if err != nil {
t.Fatal(err)
}
conn.now = func() time.Time { return now }
resp, err := os.ReadFile("testdata/good-resp.xml")
if err != nil {
t.Fatal(err)
}
samlResp := base64.StdEncoding.EncodeToString(resp)
scopes := connector.Scopes{
OfflineAccess: true,
Groups: true,
}
ident, err := conn.HandlePOST(scopes, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m")
if err != nil {
t.Fatalf("HandlePOST failed: %v", err)
}
if len(ident.ConnectorData) == 0 {
t.Fatal("expected ConnectorData to be set after HandlePOST")
}
// Now refresh using the ConnectorData from HandlePOST
refreshed, err := conn.Refresh(context.Background(), scopes, ident)
if err != nil {
t.Fatalf("Refresh failed: %v", err)
}
if refreshed.UserID != ident.UserID {
t.Errorf("UserID mismatch: got %q, want %q", refreshed.UserID, ident.UserID)
}
if refreshed.Username != ident.Username {
t.Errorf("Username mismatch: got %q, want %q", refreshed.Username, ident.Username)
}
if refreshed.Email != ident.Email {
t.Errorf("Email mismatch: got %q, want %q", refreshed.Email, ident.Email)
}
if refreshed.EmailVerified != ident.EmailVerified {
t.Errorf("EmailVerified mismatch: got %v, want %v", refreshed.EmailVerified, ident.EmailVerified)
}
sort.Strings(refreshed.Groups)
sort.Strings(ident.Groups)
if len(refreshed.Groups) != len(ident.Groups) {
t.Errorf("Groups length mismatch: got %d, want %d", len(refreshed.Groups), len(ident.Groups))
}
for i := range ident.Groups {
if i < len(refreshed.Groups) && refreshed.Groups[i] != ident.Groups[i] {
t.Errorf("Groups[%d] mismatch: got %q, want %q", i, refreshed.Groups[i], ident.Groups[i])
}
}
})
t.Run("HandlePOSTThenDoubleRefresh", func(t *testing.T) {
// Verify that refresh tokens can be chained: HandlePOST → Refresh → Refresh
now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z")
if err != nil {
t.Fatal(err)
}
conn.now = func() time.Time { return now }
resp, err := os.ReadFile("testdata/good-resp.xml")
if err != nil {
t.Fatal(err)
}
samlResp := base64.StdEncoding.EncodeToString(resp)
scopes := connector.Scopes{OfflineAccess: true, Groups: true}
ident, err := conn.HandlePOST(scopes, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m")
if err != nil {
t.Fatalf("HandlePOST failed: %v", err)
}
// First refresh
refreshed1, err := conn.Refresh(context.Background(), scopes, ident)
if err != nil {
t.Fatalf("first Refresh failed: %v", err)
}
if len(refreshed1.ConnectorData) == 0 {
t.Fatal("expected ConnectorData after first refresh")
}
// Second refresh using output of first refresh
refreshed2, err := conn.Refresh(context.Background(), scopes, refreshed1)
if err != nil {
t.Fatalf("second Refresh failed: %v", err)
}
// All fields should match original
if refreshed2.UserID != ident.UserID {
t.Errorf("UserID mismatch after double refresh: got %q, want %q", refreshed2.UserID, ident.UserID)
}
if refreshed2.Email != ident.Email {
t.Errorf("Email mismatch after double refresh: got %q, want %q", refreshed2.Email, ident.Email)
}
if refreshed2.Username != ident.Username {
t.Errorf("Username mismatch after double refresh: got %q, want %q", refreshed2.Username, ident.Username)
}
})
t.Run("HandlePOSTWithAssertionSignedThenRefresh", func(t *testing.T) {
// Test with assertion-signed.xml (signature on assertion, not response)
now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z")
if err != nil {
t.Fatal(err)
}
conn.now = func() time.Time { return now }
resp, err := os.ReadFile("testdata/assertion-signed.xml")
if err != nil {
t.Fatal(err)
}
samlResp := base64.StdEncoding.EncodeToString(resp)
scopes := connector.Scopes{OfflineAccess: true, Groups: true}
ident, err := conn.HandlePOST(scopes, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m")
if err != nil {
t.Fatalf("HandlePOST with assertion-signed failed: %v", err)
}
if len(ident.ConnectorData) == 0 {
t.Fatal("expected ConnectorData after HandlePOST with assertion-signed")
}
refreshed, err := conn.Refresh(context.Background(), scopes, ident)
if err != nil {
t.Fatalf("Refresh after assertion-signed HandlePOST failed: %v", err)
}
if refreshed.Email != ident.Email {
t.Errorf("Email mismatch: got %q, want %q", refreshed.Email, ident.Email)
}
if refreshed.Username != ident.Username {
t.Errorf("Username mismatch: got %q, want %q", refreshed.Username, ident.Username)
}
})
t.Run("HandlePOSTRefreshWithoutGroupsScope", func(t *testing.T) {
// Verify that groups are NOT returned when groups scope is not requested during refresh
now, err := time.Parse(timeFormat, "2017-04-04T04:34:59.330Z")
if err != nil {
t.Fatal(err)
}
conn.now = func() time.Time { return now }
resp, err := os.ReadFile("testdata/good-resp.xml")
if err != nil {
t.Fatal(err)
}
samlResp := base64.StdEncoding.EncodeToString(resp)
// Initial auth WITH groups
scopesWithGroups := connector.Scopes{OfflineAccess: true, Groups: true}
ident, err := conn.HandlePOST(scopesWithGroups, samlResp, "6zmm5mguyebwvajyf2sdwwcw6m")
if err != nil {
t.Fatalf("HandlePOST failed: %v", err)
}
if len(ident.Groups) == 0 {
t.Fatal("expected groups in initial identity")
}
// Refresh WITHOUT groups scope
scopesNoGroups := connector.Scopes{OfflineAccess: true, Groups: false}
refreshed, err := conn.Refresh(context.Background(), scopesNoGroups, ident)
if err != nil {
t.Fatalf("Refresh failed: %v", err)
}
if len(refreshed.Groups) != 0 {
t.Errorf("expected no groups when groups scope not requested, got %v", refreshed.Groups)
}
// Refresh WITH groups scope — groups should be back
refreshedWithGroups, err := conn.Refresh(context.Background(), scopesWithGroups, ident)
if err != nil {
t.Fatalf("Refresh with groups failed: %v", err)
}
if len(refreshedWithGroups.Groups) == 0 {
t.Error("expected groups when groups scope is requested")
}
})
}
func TestSAMLHandleSLO(t *testing.T) {
c := Config{
CA: "testdata/ca.crt",
UsernameAttr: "Name",
EmailAttr: "email",
RedirectURI: "http://127.0.0.1:5556/dex/callback",
SSOURL: "http://foo.bar/",
InsecureSkipSLOSignatureValidation: true,
}
conn, err := c.openConnector(slog.New(slog.DiscardHandler))
if err != nil {
t.Fatal(err)
}
// Helper to create a LogoutRequest XML
makeLogoutRequest := func(nameID string) string {
return fmt.Sprintf(`<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="_test123" Version="2.0" IssueInstant="2024-01-01T00:00:00Z">
<saml:Issuer>https://idp.example.com</saml:Issuer>
<saml:NameID Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress">%s</saml:NameID>
</samlp:LogoutRequest>`, nameID)
}
t.Run("ValidLogoutRequest", func(t *testing.T) {
logoutXML := makeLogoutRequest("user@example.com")
encoded := base64.StdEncoding.EncodeToString([]byte(logoutXML))
form := url.Values{}
form.Set("SAMLRequest", encoded)
req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
nameID, err := conn.HandleSLO(w, req)
if err != nil {
t.Fatalf("HandleSLO failed: %v", err)
}
if nameID != "user@example.com" {
t.Errorf("expected nameID %q, got %q", "user@example.com", nameID)
}
})
t.Run("MissingSAMLRequest", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(""))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
_, err := conn.HandleSLO(w, req)
if err == nil {
t.Error("expected error for missing SAMLRequest")
}
})
t.Run("InvalidBase64", func(t *testing.T) {
form := url.Values{}
form.Set("SAMLRequest", "not-valid-base64!!!")
req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
_, err := conn.HandleSLO(w, req)
if err == nil {
t.Error("expected error for invalid base64")
}
})
t.Run("InvalidXML", func(t *testing.T) {
encoded := base64.StdEncoding.EncodeToString([]byte("not xml at all"))
form := url.Values{}
form.Set("SAMLRequest", encoded)
req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
_, err := conn.HandleSLO(w, req)
if err == nil {
t.Error("expected error for invalid XML")
}
})
t.Run("MissingNameID", func(t *testing.T) {
logoutXML := `<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" ID="_test123" Version="2.0" IssueInstant="2024-01-01T00:00:00Z">
<saml:Issuer>https://idp.example.com</saml:Issuer>
<saml:NameID></saml:NameID>
</samlp:LogoutRequest>`
encoded := base64.StdEncoding.EncodeToString([]byte(logoutXML))
form := url.Values{}
form.Set("SAMLRequest", encoded)
req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
_, err := conn.HandleSLO(w, req)
if err == nil {
t.Error("expected error for missing NameID")
}
})
t.Run("WrongHTTPMethod", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/saml/slo/test", nil)
w := httptest.NewRecorder()
_, err := conn.HandleSLO(w, req)
if err == nil {
t.Error("expected error for GET method")
}
})
t.Run("DifferentNameIDValues", func(t *testing.T) {
testCases := []struct {
name string
nameIDVal string
wantNameID string
}{
{"email format", "admin@corp.example.com", "admin@corp.example.com"},
{"persistent ID", "AQIC5w...", "AQIC5w..."},
{"transient ID", "_ce3d2948b4cf20146dee0a0b3dd6f69b6cf86f62d7", "_ce3d2948b4cf20146dee0a0b3dd6f69b6cf86f62d7"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
logoutXML := makeLogoutRequest(tc.nameIDVal)
encoded := base64.StdEncoding.EncodeToString([]byte(logoutXML))
form := url.Values{}
form.Set("SAMLRequest", encoded)
req := httptest.NewRequest(http.MethodPost, "/saml/slo/test", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
nameID, err := conn.HandleSLO(w, req)
if err != nil {
t.Fatalf("HandleSLO failed: %v", err)
}
if nameID != tc.wantNameID {
t.Errorf("expected nameID %q, got %q", tc.wantNameID, nameID)
}
})
}
})
}

18
connector/saml/types.go

@ -275,3 +275,21 @@ func (a attribute) String() string {
// "groups" = ["engineering", "docs"]
return fmt.Sprintf("%q = %q", a.Name, values)
}
// logoutRequest represents a SAML 2.0 LogoutRequest message sent by the IdP
// to initiate Single Logout.
type logoutRequest struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol LogoutRequest"`
ID string `xml:"ID,attr"`
Version string `xml:"Version,attr"`
IssueInstant xmlTime `xml:"IssueInstant,attr"`
Destination string `xml:"Destination,attr,omitempty"`
Issuer string `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
NameID nameID `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"`
SessionIndex []sessionIndex `xml:"SessionIndex,omitempty"`
}
// sessionIndex represents a SAML SessionIndex element in a LogoutRequest.
type sessionIndex struct {
Value string `xml:",chardata"`
}

91
server/handlers.go

@ -1054,9 +1054,10 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
return nil, err
}
offlineSessions := storage.OfflineSessions{
UserID: refresh.Claims.UserID,
ConnID: refresh.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
UserID: refresh.Claims.UserID,
ConnID: refresh.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: refresh.ConnectorData,
}
offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef
@ -1082,6 +1083,9 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
if len(refresh.ConnectorData) > 0 {
old.ConnectorData = refresh.ConnectorData
}
return old, nil
}); err != nil {
s.logger.ErrorContext(ctx, "failed to update offline session", "err", err)
@ -1500,3 +1504,84 @@ func usernamePrompt(conn connector.PasswordConnector) string {
}
return "Username"
}
func (s *Server) handleSAMLSLO(w http.ResponseWriter, r *http.Request) {
connID, err := url.PathUnescape(mux.Vars(r)["connector"])
if err != nil {
s.logger.ErrorContext(r.Context(), "SAML SLO: failed to parse connector ID", "err", err)
http.Error(w, "Missing connector ID", http.StatusBadRequest)
return
}
ctx := r.Context()
conn, err := s.getConnector(ctx, connID)
if err != nil {
s.logger.ErrorContext(ctx, "SAML SLO: failed to get connector", "connector_id", connID, "err", err)
http.Error(w, "Connector not found", http.StatusNotFound)
return
}
sloConnector, ok := conn.Connector.(connector.SAMLSLOConnector)
if !ok {
s.logger.ErrorContext(ctx, "SAML SLO: connector does not support SLO", "connector_id", connID)
http.Error(w, "Connector does not support SLO", http.StatusBadRequest)
return
}
nameID, err := sloConnector.HandleSLO(w, r)
if err != nil {
s.logger.ErrorContext(ctx, "SAML SLO: failed to process LogoutRequest", "connector_id", connID, "err", err)
http.Error(w, "Failed to process LogoutRequest", http.StatusBadRequest)
return
}
s.logger.InfoContext(ctx, "SAML SLO: processing logout", "connector_id", connID, "name_id", nameID)
if err := s.invalidateUserSessions(ctx, connID, nameID); err != nil {
s.logger.ErrorContext(ctx, "SAML SLO: failed to invalidate sessions", "connector_id", connID, "name_id", nameID, "err", err)
http.Error(w, "Failed to invalidate sessions", http.StatusInternalServerError)
return
}
s.logger.InfoContext(ctx, "SAML SLO: successfully invalidated sessions", "connector_id", connID, "name_id", nameID)
w.WriteHeader(http.StatusOK)
}
// invalidateUserSessions removes all refresh tokens for a user identified by
// their NameID from a specific connector. It uses OfflineSessions to find
// the user's refresh tokens and deletes them.
func (s *Server) invalidateUserSessions(ctx context.Context, connectorID, nameID string) error {
// NameID from SAML is used as the user ID in the connector.
// OfflineSessions are keyed by (userID, connectorID).
// The userID in OfflineSessions is the connector-specific user ID (NameID for SAML).
offlineSessions, err := s.storage.GetOfflineSessions(ctx, nameID, connectorID)
if err != nil {
if err == storage.ErrNotFound {
s.logger.InfoContext(ctx, "SAML SLO: no offline sessions found", "name_id", nameID, "connector_id", connectorID)
return nil // No sessions to invalidate
}
return fmt.Errorf("failed to get offline sessions: %v", err)
}
// Delete all refresh tokens associated with this offline session
for _, tokenRef := range offlineSessions.Refresh {
if err := s.storage.DeleteRefresh(ctx, tokenRef.ID); err != nil {
if err == storage.ErrNotFound {
continue // Token already deleted
}
s.logger.ErrorContext(ctx, "SAML SLO: failed to delete refresh token", "token_id", tokenRef.ID, "err", err)
// Continue deleting other tokens even if one fails
}
}
// Delete the offline session itself
if err := s.storage.DeleteOfflineSessions(ctx, nameID, connectorID); err != nil {
if err != storage.ErrNotFound {
return fmt.Errorf("failed to delete offline sessions: %v", err)
}
}
return nil
}

460
server/handlers_test.go

@ -21,6 +21,8 @@ import (
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage"
)
@ -892,3 +894,461 @@ func setNonEmpty(vals url.Values, key, value string) {
vals.Set(key, value)
}
}
// mockSAMLSLOConnector implements connector.SAMLConnector and connector.SAMLSLOConnector for testing.
type mockSAMLSLOConnector struct {
sloNameID string
sloErr error
}
func (m *mockSAMLSLOConnector) POSTData(s connector.Scopes, requestID string) (ssoURL, samlRequest string, err error) {
return "", "", nil
}
func (m *mockSAMLSLOConnector) HandlePOST(s connector.Scopes, samlResponse, inResponseTo string) (connector.Identity, error) {
return connector.Identity{}, nil
}
func (m *mockSAMLSLOConnector) HandleSLO(w http.ResponseWriter, r *http.Request) (string, error) {
return m.sloNameID, m.sloErr
}
// mockNonSLOConnector implements connector.CallbackConnector but NOT SAMLSLOConnector.
type mockNonSLOConnector struct{}
func (m *mockNonSLOConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
return "", nil
}
func (m *mockNonSLOConnector) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) {
return connector.Identity{}, nil
}
// registerTestConnector creates a connector in storage and registers it in the server's connectors map.
func registerTestConnector(t *testing.T, s *Server, connID string, c connector.Connector) {
t.Helper()
ctx := t.Context()
storageConn := storage.Connector{
ID: connID,
Type: "saml",
Name: "Test SAML",
ResourceVersion: "1",
}
if err := s.storage.CreateConnector(ctx, storageConn); err != nil {
t.Fatalf("failed to create connector in storage: %v", err)
}
s.mu.Lock()
s.connectors[connID] = Connector{
ResourceVersion: "1",
Connector: c,
}
s.mu.Unlock()
}
func TestHandleSAMLSLO(t *testing.T) {
t.Run("SuccessfulSLO", func(t *testing.T) {
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()
ctx := t.Context()
connID := "saml-slo-test"
nameID := "user@example.com"
mockConn := &mockSAMLSLOConnector{
sloNameID: nameID,
}
registerTestConnector(t, server, connID, mockConn)
// Create refresh tokens and offline session
refreshToken1 := storage.RefreshToken{
ID: "refresh-token-1",
Token: "token-1",
CreatedAt: time.Now(),
LastUsed: time.Now(),
ClientID: "test-client",
ConnectorID: connID,
Claims: storage.Claims{
UserID: nameID,
Username: "testuser",
Email: nameID,
},
Nonce: "nonce-1",
}
refreshToken2 := storage.RefreshToken{
ID: "refresh-token-2",
Token: "token-2",
CreatedAt: time.Now(),
LastUsed: time.Now(),
ClientID: "test-client-2",
ConnectorID: connID,
Claims: storage.Claims{
UserID: nameID,
Username: "testuser",
Email: nameID,
},
Nonce: "nonce-2",
}
require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken1))
require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken2))
offlineSession := storage.OfflineSessions{
UserID: nameID,
ConnID: connID,
Refresh: map[string]*storage.RefreshTokenRef{
refreshToken1.ClientID: {
ID: refreshToken1.ID,
ClientID: refreshToken1.ClientID,
CreatedAt: refreshToken1.CreatedAt,
LastUsed: refreshToken1.LastUsed,
},
refreshToken2.ClientID: {
ID: refreshToken2.ID,
ClientID: refreshToken2.ClientID,
CreatedAt: refreshToken2.CreatedAt,
LastUsed: refreshToken2.LastUsed,
},
},
}
require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession))
// Send POST to /saml/slo/{connector}
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil)
server.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code, "expected HTTP 200, got %d: %s", rr.Code, rr.Body.String())
// Verify refresh tokens are deleted
_, err := server.storage.GetRefresh(ctx, refreshToken1.ID)
require.ErrorIs(t, err, storage.ErrNotFound, "refresh token 1 should be deleted")
_, err = server.storage.GetRefresh(ctx, refreshToken2.ID)
require.ErrorIs(t, err, storage.ErrNotFound, "refresh token 2 should be deleted")
// Verify offline session is deleted
_, err = server.storage.GetOfflineSessions(ctx, nameID, connID)
require.ErrorIs(t, err, storage.ErrNotFound, "offline session should be deleted")
})
t.Run("NoExistingSessions", func(t *testing.T) {
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()
connID := "saml-slo-nosession"
nameID := "nouser@example.com"
mockConn := &mockSAMLSLOConnector{
sloNameID: nameID,
}
registerTestConnector(t, server, connID, mockConn)
// No refresh tokens or offline sessions created — should handle gracefully
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil)
server.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code, "expected HTTP 200 for no sessions, got %d: %s", rr.Code, rr.Body.String())
})
t.Run("ConnectorNotFound", func(t *testing.T) {
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/saml/slo/nonexistent-connector", nil)
server.ServeHTTP(rr, req)
require.Equal(t, http.StatusNotFound, rr.Code, "expected HTTP 404 for missing connector")
})
t.Run("ConnectorDoesNotSupportSLO", func(t *testing.T) {
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()
connID := "saml-no-slo"
mockConn := &mockNonSLOConnector{}
registerTestConnector(t, server, connID, mockConn)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil)
server.ServeHTTP(rr, req)
require.Equal(t, http.StatusBadRequest, rr.Code, "expected HTTP 400 for connector without SLO support")
})
t.Run("HandleSLOError", func(t *testing.T) {
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()
connID := "saml-slo-error"
mockConn := &mockSAMLSLOConnector{
sloNameID: "",
sloErr: errors.New("invalid SAMLRequest"),
}
registerTestConnector(t, server, connID, mockConn)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil)
server.ServeHTTP(rr, req)
require.Equal(t, http.StatusBadRequest, rr.Code, "expected HTTP 400 for invalid SAMLRequest")
})
t.Run("MultipleTokensPartialDeletion", func(t *testing.T) {
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()
ctx := t.Context()
connID := "saml-slo-partial"
nameID := "partial@example.com"
mockConn := &mockSAMLSLOConnector{
sloNameID: nameID,
}
registerTestConnector(t, server, connID, mockConn)
// Create only one refresh token but reference two in offline session
// (simulating a token that was already deleted)
refreshToken := storage.RefreshToken{
ID: "existing-token",
Token: "token-existing",
CreatedAt: time.Now(),
LastUsed: time.Now(),
ClientID: "client-existing",
ConnectorID: connID,
Claims: storage.Claims{
UserID: nameID,
Username: "partialuser",
Email: nameID,
},
Nonce: "nonce-existing",
}
require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken))
offlineSession := storage.OfflineSessions{
UserID: nameID,
ConnID: connID,
Refresh: map[string]*storage.RefreshTokenRef{
"client-existing": {
ID: "existing-token",
ClientID: "client-existing",
CreatedAt: time.Now(),
LastUsed: time.Now(),
},
"client-missing": {
ID: "already-deleted-token",
ClientID: "client-missing",
CreatedAt: time.Now(),
LastUsed: time.Now(),
},
},
}
require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession))
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil)
server.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Code, "expected HTTP 200 even with partial deletion")
// Verify existing token is deleted
_, err := server.storage.GetRefresh(ctx, "existing-token")
require.ErrorIs(t, err, storage.ErrNotFound, "existing refresh token should be deleted")
// Verify offline session is deleted
_, err = server.storage.GetOfflineSessions(ctx, nameID, connID)
require.ErrorIs(t, err, storage.ErrNotFound, "offline session should be deleted")
})
t.Run("RefreshAfterSLO", func(t *testing.T) {
// Test that refresh token is rejected after SLO invalidation.
// This is the key integration test: SLO → refresh → error.
httpServer, server := newTestServer(t, func(c *Config) {
c.RefreshTokenPolicy = &RefreshTokenPolicy{rotateRefreshTokens: true}
})
defer httpServer.Close()
ctx := t.Context()
connID := "saml-slo-refresh"
nameID := "slo-refresh@example.com"
mockConn := &mockSAMLSLOConnector{
sloNameID: nameID,
}
registerTestConnector(t, server, connID, mockConn)
// Create client for refresh token request
client := storage.Client{
ID: "slo-test-client",
Secret: "slo-test-secret",
RedirectURIs: []string{"https://example.com/callback"},
Name: "SLO Test Client",
}
require.NoError(t, server.storage.CreateClient(ctx, client))
// Create refresh token
refreshToken := storage.RefreshToken{
ID: "slo-refresh-token",
Token: "slo-token-value",
CreatedAt: time.Now(),
LastUsed: time.Now(),
ClientID: client.ID,
ConnectorID: connID,
Scopes: []string{"openid", "email", "offline_access"},
Claims: storage.Claims{
UserID: nameID,
Username: "slouser",
Email: nameID,
EmailVerified: true,
},
ConnectorData: []byte(`{"userID":"` + nameID + `","username":"slouser","email":"` + nameID + `","emailVerified":true}`),
Nonce: "slo-nonce",
}
require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken))
offlineSession := storage.OfflineSessions{
UserID: nameID,
ConnID: connID,
Refresh: map[string]*storage.RefreshTokenRef{
client.ID: {
ID: refreshToken.ID,
ClientID: client.ID,
CreatedAt: refreshToken.CreatedAt,
LastUsed: refreshToken.LastUsed,
},
},
}
require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession))
// Step 1: Verify refresh token exists before SLO
_, err := server.storage.GetRefresh(ctx, refreshToken.ID)
require.NoError(t, err, "refresh token should exist before SLO")
// Step 2: Send SLO request
sloRR := httptest.NewRecorder()
sloReq := httptest.NewRequest(http.MethodPost, "/saml/slo/"+connID, nil)
server.ServeHTTP(sloRR, sloReq)
require.Equal(t, http.StatusOK, sloRR.Code, "SLO should succeed")
// Step 3: Verify refresh token is deleted
_, err = server.storage.GetRefresh(ctx, refreshToken.ID)
require.ErrorIs(t, err, storage.ErrNotFound, "refresh token should be deleted after SLO")
// Step 4: Attempt to use the refresh token — should fail
tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: refreshToken.ID, Token: refreshToken.Token})
require.NoError(t, err)
u, err := url.Parse(server.issuerURL.String())
require.NoError(t, err)
u.Path = path.Join(u.Path, "/token")
v := url.Values{}
v.Add("grant_type", "refresh_token")
v.Add("refresh_token", tokenData)
refreshReq, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode()))
refreshReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
refreshReq.SetBasicAuth(client.ID, client.Secret)
refreshRR := httptest.NewRecorder()
server.ServeHTTP(refreshRR, refreshReq)
// Refresh should fail because the token was deleted by SLO
require.NotEqual(t, http.StatusOK, refreshRR.Code,
"refresh should fail after SLO, got status %d: %s", refreshRR.Code, refreshRR.Body.String())
})
t.Run("ConnectorDataPersistence", func(t *testing.T) {
// Test that ConnectorData is correctly stored in refresh token
// and can be used for subsequent refresh operations.
httpServer, server := newTestServer(t, func(c *Config) {
c.RefreshTokenPolicy = &RefreshTokenPolicy{rotateRefreshTokens: true}
})
defer httpServer.Close()
ctx := t.Context()
connID := "saml-conndata"
// Create a mock SAML connector that also implements RefreshConnector
mockConn := &mockSAMLRefreshConnector{
refreshIdentity: connector.Identity{
UserID: "refreshed-user",
Username: "refreshed-name",
Email: "refreshed@example.com",
EmailVerified: true,
Groups: []string{"refreshed-group"},
},
}
registerTestConnector(t, server, connID, mockConn)
// Create client
client := storage.Client{
ID: "conndata-client",
Secret: "conndata-secret",
RedirectURIs: []string{"https://example.com/callback"},
Name: "ConnData Test Client",
}
require.NoError(t, server.storage.CreateClient(ctx, client))
// Create refresh token with ConnectorData (simulating what HandlePOST would store)
connectorData := []byte(`{"userID":"user-123","username":"testuser","email":"test@example.com","emailVerified":true,"groups":["admin","dev"]}`)
refreshToken := storage.RefreshToken{
ID: "conndata-refresh",
Token: "conndata-token",
CreatedAt: time.Now(),
LastUsed: time.Now(),
ClientID: client.ID,
ConnectorID: connID,
Scopes: []string{"openid", "email", "offline_access"},
Claims: storage.Claims{
UserID: "user-123",
Username: "testuser",
Email: "test@example.com",
EmailVerified: true,
Groups: []string{"admin", "dev"},
},
ConnectorData: connectorData,
Nonce: "conndata-nonce",
}
require.NoError(t, server.storage.CreateRefresh(ctx, refreshToken))
offlineSession := storage.OfflineSessions{
UserID: "user-123",
ConnID: connID,
Refresh: map[string]*storage.RefreshTokenRef{client.ID: {ID: refreshToken.ID, ClientID: client.ID}},
ConnectorData: connectorData,
}
require.NoError(t, server.storage.CreateOfflineSessions(ctx, offlineSession))
// Verify ConnectorData is stored correctly
storedToken, err := server.storage.GetRefresh(ctx, refreshToken.ID)
require.NoError(t, err)
require.Equal(t, connectorData, storedToken.ConnectorData,
"ConnectorData should be persisted in refresh token storage")
// Verify ConnectorData is stored in offline session
storedSession, err := server.storage.GetOfflineSessions(ctx, "user-123", connID)
require.NoError(t, err)
require.Equal(t, connectorData, storedSession.ConnectorData,
"ConnectorData should be persisted in offline session storage")
})
}
// mockSAMLRefreshConnector implements SAMLConnector + RefreshConnector for testing.
type mockSAMLRefreshConnector struct {
refreshIdentity connector.Identity
}
func (m *mockSAMLRefreshConnector) POSTData(s connector.Scopes, requestID string) (ssoURL, samlRequest string, err error) {
return "", "", nil
}
func (m *mockSAMLRefreshConnector) HandlePOST(s connector.Scopes, samlResponse, inResponseTo string) (connector.Identity, error) {
return connector.Identity{}, nil
}
func (m *mockSAMLRefreshConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
return m.refreshIdentity, nil
}

1
server/server.go

@ -494,6 +494,7 @@ func newServer(ctx context.Context, c Config) (*Server, error) {
// For easier connector-specific web server configuration, e.g. for the
// "authproxy" connector.
handleFunc("/callback/{connector}", s.handleConnectorCallback)
handleFunc("/saml/slo/{connector}", s.handleSAMLSLO)
handleFunc("/approval", s.handleApproval)
handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !c.HealthChecker.IsHealthy() {

Loading…
Cancel
Save