|
|
|
|
@ -2,10 +2,8 @@
|
|
|
|
|
package saml |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"crypto/rand" |
|
|
|
|
"crypto/x509" |
|
|
|
|
"encoding/base64" |
|
|
|
|
"encoding/hex" |
|
|
|
|
"encoding/pem" |
|
|
|
|
"encoding/xml" |
|
|
|
|
"errors" |
|
|
|
|
@ -270,12 +268,22 @@ func (p *provider) POSTData(s connector.Scopes, id string) (action, value string
|
|
|
|
|
return p.ssoURL, base64.StdEncoding.EncodeToString(data), nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// HandlePOST interprets a request from a SAML provider attempting to verify a
|
|
|
|
|
// user's identity.
|
|
|
|
|
//
|
|
|
|
|
// The steps taken are:
|
|
|
|
|
//
|
|
|
|
|
// * Verify signature on XML document (or verify sig on assertion elements).
|
|
|
|
|
// * Verify various parts of the Assertion element. Conditions, audience, etc.
|
|
|
|
|
// * Map the Assertion's attribute elements to user info.
|
|
|
|
|
//
|
|
|
|
|
func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo string) (ident connector.Identity, err error) { |
|
|
|
|
rawResp, err := base64.StdEncoding.DecodeString(samlResponse) |
|
|
|
|
if err != nil { |
|
|
|
|
return ident, fmt.Errorf("decode response: %v", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Root element is allowed to not be signed if the Assertion element is.
|
|
|
|
|
rootElementSigned := true |
|
|
|
|
if p.validator != nil { |
|
|
|
|
rawResp, rootElementSigned, err = verifyResponseSig(p.validator, rawResp) |
|
|
|
|
@ -289,6 +297,8 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
|
|
|
|
|
return ident, fmt.Errorf("unmarshal response: %v", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// If the root element isn't signed, there's no reason to inspect these
|
|
|
|
|
// elements. They're not verified.
|
|
|
|
|
if rootElementSigned { |
|
|
|
|
if p.ssoIssuer != "" && resp.Issuer != nil && resp.Issuer.Issuer != p.ssoIssuer { |
|
|
|
|
return ident, fmt.Errorf("expected Issuer value %s, got %s", p.ssoIssuer, resp.Issuer.Issuer) |
|
|
|
|
@ -303,10 +313,14 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
|
|
|
|
|
// Destination is optional.
|
|
|
|
|
if resp.Destination != "" && resp.Destination != p.redirectURI { |
|
|
|
|
return ident, fmt.Errorf("expected destination %q got %q", p.redirectURI, resp.Destination) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Status is a required element.
|
|
|
|
|
if resp.Status == nil { |
|
|
|
|
return ident, fmt.Errorf("Response did not contain a Status element") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if err = p.validateStatus(&resp); err != nil { |
|
|
|
|
if err = p.validateStatus(resp.Status); err != nil { |
|
|
|
|
return ident, err |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
@ -315,16 +329,25 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
|
|
|
|
|
if assertion == nil { |
|
|
|
|
return ident, fmt.Errorf("response did not contain an assertion") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Subject is usually optional, but we need it for the user ID, so complain
|
|
|
|
|
// if it's not present.
|
|
|
|
|
subject := assertion.Subject |
|
|
|
|
if subject == nil { |
|
|
|
|
return ident, fmt.Errorf("response did not contain a subject") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if err = p.validateConditions(assertion); err != nil { |
|
|
|
|
// Validate that the response is to the request we originally sent.
|
|
|
|
|
if err = p.validateSubject(subject, inResponseTo); err != nil { |
|
|
|
|
return ident, err |
|
|
|
|
} |
|
|
|
|
if err = p.validateSubjectConfirmation(subject); err != nil { |
|
|
|
|
return ident, err |
|
|
|
|
|
|
|
|
|
// Conditions element is optional, but must be validated if present.
|
|
|
|
|
if assertion.Conditions != nil { |
|
|
|
|
// Validate that dex is the intended audience of this response.
|
|
|
|
|
if err = p.validateConditions(assertion.Conditions); err != nil { |
|
|
|
|
return ident, err |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
switch { |
|
|
|
|
@ -336,53 +359,57 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
|
|
|
|
|
return ident, fmt.Errorf("subject does not contain an NameID element") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// After verifying the assertion, map data in the attribute statements to
|
|
|
|
|
// various user info.
|
|
|
|
|
attributes := assertion.AttributeStatement |
|
|
|
|
if attributes == nil { |
|
|
|
|
return ident, fmt.Errorf("response did not contain a AttributeStatement") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Grab the email.
|
|
|
|
|
if ident.Email, _ = attributes.get(p.emailAttr); ident.Email == "" { |
|
|
|
|
return ident, fmt.Errorf("no attribute with name %q: %s", p.emailAttr, attributes.names()) |
|
|
|
|
} |
|
|
|
|
// TODO(ericchiang): Does SAML have an email_verified equivalent?
|
|
|
|
|
ident.EmailVerified = true |
|
|
|
|
|
|
|
|
|
// Grab the username.
|
|
|
|
|
if ident.Username, _ = attributes.get(p.usernameAttr); ident.Username == "" { |
|
|
|
|
return ident, fmt.Errorf("no attribute with name %q: %s", p.usernameAttr, attributes.names()) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if s.Groups && p.groupsAttr != "" { |
|
|
|
|
if p.groupsDelim != "" { |
|
|
|
|
groupsStr, ok := attributes.get(p.groupsAttr) |
|
|
|
|
if !ok { |
|
|
|
|
return ident, fmt.Errorf("no attribute with name %q: %s", p.groupsAttr, attributes.names()) |
|
|
|
|
} |
|
|
|
|
// TODO(ericchiang): Do we need to further trim whitespace?
|
|
|
|
|
ident.Groups = strings.Split(groupsStr, p.groupsDelim) |
|
|
|
|
} else { |
|
|
|
|
groups, ok := attributes.all(p.groupsAttr) |
|
|
|
|
if !ok { |
|
|
|
|
return ident, fmt.Errorf("no attribute with name %q: %s", p.groupsAttr, attributes.names()) |
|
|
|
|
} |
|
|
|
|
ident.Groups = groups |
|
|
|
|
} |
|
|
|
|
if !s.Groups || p.groupsAttr == "" { |
|
|
|
|
// Groups not requested or not configured. We're done.
|
|
|
|
|
return ident, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Grab the groups.
|
|
|
|
|
if p.groupsDelim != "" { |
|
|
|
|
groupsStr, ok := attributes.get(p.groupsAttr) |
|
|
|
|
if !ok { |
|
|
|
|
return ident, fmt.Errorf("no attribute with name %q: %s", p.groupsAttr, attributes.names()) |
|
|
|
|
} |
|
|
|
|
// TODO(ericchiang): Do we need to further trim whitespace?
|
|
|
|
|
ident.Groups = strings.Split(groupsStr, p.groupsDelim) |
|
|
|
|
} else { |
|
|
|
|
groups, ok := attributes.all(p.groupsAttr) |
|
|
|
|
if !ok { |
|
|
|
|
return ident, fmt.Errorf("no attribute with name %q: %s", p.groupsAttr, attributes.names()) |
|
|
|
|
} |
|
|
|
|
ident.Groups = groups |
|
|
|
|
} |
|
|
|
|
return ident, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Validate that the StatusCode of the Response is success.
|
|
|
|
|
// Otherwise return a human readable message to the end user
|
|
|
|
|
func (p *provider) validateStatus(resp *response) error { |
|
|
|
|
// Status is mandatory in the Response type
|
|
|
|
|
status := resp.Status |
|
|
|
|
if status == nil { |
|
|
|
|
return fmt.Errorf("response did not contain a Status") |
|
|
|
|
} |
|
|
|
|
// validateStatus verifies that the response has a good status code or
|
|
|
|
|
// formats a human readble error based on the bad status.
|
|
|
|
|
func (p *provider) validateStatus(status *status) error { |
|
|
|
|
// StatusCode is mandatory in the Status type
|
|
|
|
|
statusCode := status.StatusCode |
|
|
|
|
if statusCode == nil { |
|
|
|
|
return fmt.Errorf("response did not contain a StatusCode") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if statusCode.Value != statusCodeSuccess { |
|
|
|
|
parts := strings.Split(statusCode.Value, ":") |
|
|
|
|
lastPart := parts[len(parts)-1] |
|
|
|
|
@ -396,96 +423,107 @@ func (p *provider) validateStatus(resp *response) error {
|
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Multiple subject SubjectConfirmation can be in the assertion
|
|
|
|
|
// and at least one SubjectConfirmation must be valid.
|
|
|
|
|
// validateSubject ensures the response is to the request we expect.
|
|
|
|
|
//
|
|
|
|
|
// This is described in the spec "Profiles for the OASIS Security
|
|
|
|
|
// Assertion Markup Language" in section 3.3 Bearer.
|
|
|
|
|
// see https://www.oasis-open.org/committees/download.php/35389/sstc-saml-profiles-errata-2.0-wd-06-diff.pdf
|
|
|
|
|
func (p *provider) validateSubjectConfirmation(subject *subject) error { |
|
|
|
|
validSubjectConfirmation := false |
|
|
|
|
subjectConfirmations := subject.SubjectConfirmations |
|
|
|
|
if subjectConfirmations != nil && len(subjectConfirmations) > 0 { |
|
|
|
|
for _, subjectConfirmation := range subjectConfirmations { |
|
|
|
|
// skip if method is wrong
|
|
|
|
|
method := subjectConfirmation.Method |
|
|
|
|
if method != "" && method != subjectConfirmationMethodBearer { |
|
|
|
|
continue |
|
|
|
|
//
|
|
|
|
|
// Some of these fields are optional, but we're going to be strict here since
|
|
|
|
|
// we have no other way of guarenteeing that this is actually the response to
|
|
|
|
|
// the request we expect.
|
|
|
|
|
func (p *provider) validateSubject(subject *subject, inResponseTo string) error { |
|
|
|
|
// Optional according to the spec, but again, we're going to be strict here.
|
|
|
|
|
if len(subject.SubjectConfirmations) == 0 { |
|
|
|
|
return fmt.Errorf("Subject contained no SubjectConfrimations") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var errs []error |
|
|
|
|
// One of these must match our assumptions, not all.
|
|
|
|
|
for _, c := range subject.SubjectConfirmations { |
|
|
|
|
err := func() error { |
|
|
|
|
if c.Method != subjectConfirmationMethodBearer { |
|
|
|
|
return fmt.Errorf("unexpected subject confirmation method: %v", c.Method) |
|
|
|
|
} |
|
|
|
|
subjectConfirmationData := subjectConfirmation.SubjectConfirmationData |
|
|
|
|
if subjectConfirmationData == nil { |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
data := c.SubjectConfirmationData |
|
|
|
|
if data == nil { |
|
|
|
|
return fmt.Errorf("SubjectConfirmation contained no SubjectConfirmationData") |
|
|
|
|
} |
|
|
|
|
inResponseTo := subjectConfirmationData.InResponseTo |
|
|
|
|
if inResponseTo != "" { |
|
|
|
|
// TODO also validate InResponseTo if present
|
|
|
|
|
if data.InResponseTo != inResponseTo { |
|
|
|
|
return fmt.Errorf("expected SubjectConfirmationData InResponseTo value %q, got %q", inResponseTo, data.InResponseTo) |
|
|
|
|
} |
|
|
|
|
// only validate that subjectConfirmationData is not expired
|
|
|
|
|
|
|
|
|
|
notBefore := time.Time(data.NotBefore) |
|
|
|
|
notOnOrAfter := time.Time(data.NotOnOrAfter) |
|
|
|
|
now := p.now() |
|
|
|
|
notOnOrAfter := time.Time(subjectConfirmationData.NotOnOrAfter) |
|
|
|
|
if !notOnOrAfter.IsZero() { |
|
|
|
|
if now.After(notOnOrAfter) { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
if !notBefore.IsZero() && before(now, notBefore) { |
|
|
|
|
return fmt.Errorf("at %s got response that cannot be processed before %s", now, notBefore) |
|
|
|
|
} |
|
|
|
|
// validate recipient if present
|
|
|
|
|
recipient := subjectConfirmationData.Recipient |
|
|
|
|
if recipient != "" && recipient != p.redirectURI { |
|
|
|
|
continue |
|
|
|
|
if !notOnOrAfter.IsZero() && after(now, notOnOrAfter) { |
|
|
|
|
return fmt.Errorf("at %s got response that cannot be processed because it expired at %s", now, notOnOrAfter) |
|
|
|
|
} |
|
|
|
|
validSubjectConfirmation = true |
|
|
|
|
if r := data.Recipient; r != "" && r != p.redirectURI { |
|
|
|
|
return fmt.Errorf("expected Recipient %q got %q", p.redirectURI, r) |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
}() |
|
|
|
|
if err == nil { |
|
|
|
|
// Subject is valid.
|
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
errs = append(errs, err) |
|
|
|
|
} |
|
|
|
|
if !validSubjectConfirmation { |
|
|
|
|
return fmt.Errorf("no valid SubjectConfirmation was found on this Response") |
|
|
|
|
|
|
|
|
|
if len(errs) == 1 { |
|
|
|
|
return fmt.Errorf("failed to validate subject confirmation: %v", errs[0]) |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
return fmt.Errorf("failed to validate subject confirmation: %v", errs) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Validates the Conditions element and all of it's content
|
|
|
|
|
// validationConditions ensures that dex is the intended audience
|
|
|
|
|
// for the request, and not another service provider.
|
|
|
|
|
//
|
|
|
|
|
// See: https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf
|
|
|
|
|
// "2.3.3 Element <Assertion>"
|
|
|
|
|
func (p *provider) validateConditions(assertion *assertion) error { |
|
|
|
|
// Checks if a Conditions element exists
|
|
|
|
|
conditions := assertion.Conditions |
|
|
|
|
if conditions == nil { |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
// Validates Assertion timestamps
|
|
|
|
|
func (p *provider) validateConditions(conditions *conditions) error { |
|
|
|
|
// Ensure the conditions haven't expired.
|
|
|
|
|
now := p.now() |
|
|
|
|
notBefore := time.Time(conditions.NotBefore) |
|
|
|
|
if !notBefore.IsZero() { |
|
|
|
|
if now.Add(allowedClockDrift).Before(notBefore) { |
|
|
|
|
return fmt.Errorf("at %s got response that cannot be processed before %s", now, notBefore) |
|
|
|
|
} |
|
|
|
|
if !notBefore.IsZero() && before(now, notBefore) { |
|
|
|
|
return fmt.Errorf("at %s got response that cannot be processed before %s", now, notBefore) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
notOnOrAfter := time.Time(conditions.NotOnOrAfter) |
|
|
|
|
if !notOnOrAfter.IsZero() { |
|
|
|
|
if now.After(notOnOrAfter.Add(allowedClockDrift)) { |
|
|
|
|
return fmt.Errorf("at %s got response that cannot be processed because it expired at %s", now, notOnOrAfter) |
|
|
|
|
} |
|
|
|
|
if !notOnOrAfter.IsZero() && after(now, notOnOrAfter) { |
|
|
|
|
return fmt.Errorf("at %s got response that cannot be processed because it expired at %s", now, notOnOrAfter) |
|
|
|
|
} |
|
|
|
|
// Validates audience
|
|
|
|
|
audienceValue := p.entityIssuer |
|
|
|
|
if audienceValue == "" { |
|
|
|
|
audienceValue = p.redirectURI |
|
|
|
|
} |
|
|
|
|
audienceRestriction := conditions.AudienceRestriction |
|
|
|
|
if audienceRestriction != nil { |
|
|
|
|
audiences := audienceRestriction.Audiences |
|
|
|
|
if audiences != nil && len(audiences) > 0 { |
|
|
|
|
values := make([]string, len(audiences)) |
|
|
|
|
issuerInAudiences := false |
|
|
|
|
for i, audience := range audiences { |
|
|
|
|
if audience.Value == audienceValue { |
|
|
|
|
issuerInAudiences = true |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
values[i] = audience.Value |
|
|
|
|
} |
|
|
|
|
if !issuerInAudiences { |
|
|
|
|
return fmt.Errorf("required audience %s was not in Response audiences %s", audienceValue, values) |
|
|
|
|
|
|
|
|
|
// Sometimes, dex's issuer string can be different than the redirect URI,
|
|
|
|
|
// but if dex's issuer isn't explicitly provided assume the redirect URI.
|
|
|
|
|
expAud := p.entityIssuer |
|
|
|
|
if expAud == "" { |
|
|
|
|
expAud = p.redirectURI |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// AudienceRestriction elements indicate the intended audience(s) of an
|
|
|
|
|
// assertion. If dex isn't in these audiences, reject the assertion.
|
|
|
|
|
//
|
|
|
|
|
// Note that if there are multiple AudienceRestriction elements, each must
|
|
|
|
|
// individually contain dex in their audience list.
|
|
|
|
|
for _, r := range conditions.AudienceRestriction { |
|
|
|
|
values := make([]string, len(r.Audiences)) |
|
|
|
|
issuerInAudiences := false |
|
|
|
|
for i, aud := range r.Audiences { |
|
|
|
|
if aud.Value == expAud { |
|
|
|
|
issuerInAudiences = true |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
values[i] = aud.Value |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if !issuerInAudiences { |
|
|
|
|
return fmt.Errorf("required audience %s was not in Response audiences %s", expAud, values) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
@ -544,24 +582,14 @@ func verifyResponseSig(validator *dsig.ValidationContext, data []byte) (signed [
|
|
|
|
|
return signed, false, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func uuidv4() string { |
|
|
|
|
u := make([]byte, 16) |
|
|
|
|
if _, err := rand.Read(u); err != nil { |
|
|
|
|
panic(err) |
|
|
|
|
} |
|
|
|
|
u[6] = (u[6] | 0x40) & 0x4F |
|
|
|
|
u[8] = (u[8] | 0x80) & 0xBF |
|
|
|
|
|
|
|
|
|
r := make([]byte, 36) |
|
|
|
|
r[8] = '-' |
|
|
|
|
r[13] = '-' |
|
|
|
|
r[18] = '-' |
|
|
|
|
r[23] = '-' |
|
|
|
|
hex.Encode(r, u[0:4]) |
|
|
|
|
hex.Encode(r[9:], u[4:6]) |
|
|
|
|
hex.Encode(r[14:], u[6:8]) |
|
|
|
|
hex.Encode(r[19:], u[8:10]) |
|
|
|
|
hex.Encode(r[24:], u[10:]) |
|
|
|
|
|
|
|
|
|
return string(r) |
|
|
|
|
// before determines if a given time is before the current time, with an
|
|
|
|
|
// allowed clock drift.
|
|
|
|
|
func before(now, notBefore time.Time) bool { |
|
|
|
|
return now.Add(allowedClockDrift).Before(notBefore) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// after determines if a given time is after the current time, with an
|
|
|
|
|
// allowed clock drift.
|
|
|
|
|
func after(now, notOnOrAfter time.Time) bool { |
|
|
|
|
return now.After(notOnOrAfter.Add(allowedClockDrift)) |
|
|
|
|
} |
|
|
|
|
|