diff --git a/connector/ldap/ldap.go b/connector/ldap/ldap.go index 0e3d26f0..aa5d785c 100644 --- a/connector/ldap/ldap.go +++ b/connector/ldap/ldap.go @@ -37,7 +37,9 @@ import ( // # Would translate to the query "(&(objectClass=person)(|(uid=)(mail=)))" // baseDN: cn=users,dc=example,dc=com // filter: "(objectClass=person)" -// username: uid,mail +// username: +// - uid +// - mail // idAttr: uid // emailAttr: mail // nameAttr: name @@ -58,6 +60,27 @@ import ( // nameAttr: name // +// UsernameAttributes represents one or more LDAP attributes to match against +// the username input. It supports unmarshaling from both a single string +// (e.g. "uid") and a list of strings (e.g. ["uid", "mail"]). +type UsernameAttributes []string + +func (u *UsernameAttributes) UnmarshalJSON(data []byte) error { + var arr []string + if err := json.Unmarshal(data, &arr); err == nil { + *u = arr + return nil + } + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("username must be a string or list of strings") + } + if s != "" { + *u = UsernameAttributes{s} + } + return nil +} + // UserMatcher holds information about user and group matching. type UserMatcher struct { UserAttr string `json:"userAttr"` @@ -110,9 +133,10 @@ type Config struct { // Optional filter to apply when searching the directory. For example "(objectClass=person)" Filter string `json:"filter"` - // Attributes (comma-separated) to match (OR)against the inputted username. This will be translated and combined - // with the other filter as "(|(=)(=))". - Username string `json:"username"` + // Attribute(s) to match against the inputted username. Accepts a single string + // or a list of strings. When multiple attributes are specified, an OR filter is + // constructed: "(|(=)(=))". + Username UsernameAttributes `json:"username"` // Can either be: // * "sub" - search the whole sub tree @@ -240,7 +264,6 @@ func (c *Config) openConnector(logger *slog.Logger) (*ldapConnector, error) { }{ {"host", c.Host}, {"userSearch.baseDN", c.UserSearch.BaseDN}, - {"userSearch.username", c.UserSearch.Username}, } for _, field := range requiredFields { @@ -249,6 +272,10 @@ func (c *Config) openConnector(logger *slog.Logger) (*ldapConnector, error) { } } + if len(c.UserSearch.Username) == 0 { + return nil, fmt.Errorf("ldap: missing required field %q", "userSearch.username") + } + var ( host string err error @@ -296,7 +323,7 @@ func (c *Config) openConnector(logger *slog.Logger) (*ldapConnector, error) { // TODO(nabokihms): remove it after deleting deprecated groupSearch options c.GroupSearch.UserMatchers = userMatchers(c, logger) - return &ldapConnector{*c, userSearchScope, groupSearchScope, tlsConfig, logger}, nil + return &ldapConnector{*c, userSearchScope, groupSearchScope, tlsConfig, c.UserSearch.Username, logger}, nil } type ldapConnector struct { @@ -307,6 +334,8 @@ type ldapConnector struct { tlsConfig *tls.Config + usernameAttrs []string + logger *slog.Logger } @@ -422,15 +451,9 @@ func (c *ldapConnector) userEntry(conn *ldap.Conn, username string) (user ldap.E var filter string escapedUsername := ldap.EscapeFilter(username) - // Split username attribute by comma to support multiple search attributes - usernameAttrs := strings.Split(c.UserSearch.Username, ",") - - attrFilters := make([]string, 0, len(usernameAttrs)) - for _, attr := range usernameAttrs { - attr = strings.TrimSpace(attr) - if attr != "" { - attrFilters = append(attrFilters, fmt.Sprintf("(%s=%s)", attr, escapedUsername)) - } + attrFilters := make([]string, 0, len(c.usernameAttrs)) + for _, attr := range c.usernameAttrs { + attrFilters = append(attrFilters, fmt.Sprintf("(%s=%s)", attr, escapedUsername)) } if len(attrFilters) == 1 { filter = attrFilters[0] // Skip OR wrapper for single attribute @@ -455,10 +478,7 @@ func (c *ldapConnector) userEntry(conn *ldap.Conn, username string) (user ldap.E }, } - for _, attr := range usernameAttrs { - attr = strings.TrimSpace(attr) - req.Attributes = append(req.Attributes, attr) - } + req.Attributes = append(req.Attributes, c.usernameAttrs...) for _, matcher := range c.GroupSearch.UserMatchers { req.Attributes = append(req.Attributes, matcher.UserAttr) diff --git a/connector/ldap/ldap_test.go b/connector/ldap/ldap_test.go index 240911ae..3335d56b 100644 --- a/connector/ldap/ldap_test.go +++ b/connector/ldap/ldap_test.go @@ -45,7 +45,7 @@ func TestQuery(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} tests := []subtest{ { @@ -105,7 +105,7 @@ func TestQueryWithEmailSuffix(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailSuffix = "test.example.com" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} tests := []subtest{ { @@ -141,7 +141,7 @@ func TestUserFilter(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} c.UserSearch.Filter = "(ou:dn:=Seattle)" tests := []subtest{ @@ -190,7 +190,7 @@ func TestUsernameWithMultipleAttributes(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn,mail" + c.UserSearch.Username = UsernameAttributes{"cn", "mail"} c.UserSearch.Filter = "(ou:dn:=Seattle)" tests := []subtest{ @@ -227,7 +227,7 @@ func TestGroupQuery(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} c.GroupSearch.BaseDN = "ou=Groups,ou=TestGroupQuery,dc=example,dc=org" c.GroupSearch.UserMatchers = []UserMatcher{ { @@ -275,7 +275,7 @@ func TestGroupsOnUserEntity(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} c.GroupSearch.BaseDN = "ou=Groups,ou=TestGroupsOnUserEntity,dc=example,dc=org" c.GroupSearch.UserMatchers = []UserMatcher{ { @@ -321,7 +321,7 @@ func TestGroupFilter(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} c.GroupSearch.BaseDN = "ou=TestGroupFilter,dc=example,dc=org" c.GroupSearch.UserMatchers = []UserMatcher{ { @@ -370,7 +370,7 @@ func TestGroupToUserMatchers(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} c.GroupSearch.BaseDN = "ou=TestGroupToUserMatchers,dc=example,dc=org" c.GroupSearch.UserMatchers = []UserMatcher{ { @@ -426,7 +426,7 @@ func TestDeprecatedGroupToUserMatcher(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} c.GroupSearch.BaseDN = "ou=TestDeprecatedGroupToUserMatcher,dc=example,dc=org" c.GroupSearch.UserAttr = "DN" c.GroupSearch.GroupAttr = "member" @@ -471,7 +471,7 @@ func TestStartTLS(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} tests := []subtest{ { @@ -495,7 +495,7 @@ func TestInsecureSkipVerify(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} tests := []subtest{ { @@ -519,7 +519,7 @@ func TestLDAPS(t *testing.T) { c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} tests := []subtest{ { @@ -562,13 +562,43 @@ func TestUsernamePrompt(t *testing.T) { } } +func TestUsernameAttributesUnmarshal(t *testing.T) { + tests := []struct { + name string + json string + want UsernameAttributes + wantErr bool + }{ + {name: "single string", json: `"uid"`, want: UsernameAttributes{"uid"}}, + {name: "array of strings", json: `["uid","mail"]`, want: UsernameAttributes{"uid", "mail"}}, + {name: "single element array", json: `["cn"]`, want: UsernameAttributes{"cn"}}, + {name: "empty string", json: `""`, want: nil}, + {name: "invalid type", json: `123`, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got UsernameAttributes + err := got.UnmarshalJSON([]byte(tt.json)) + if (err != nil) != tt.wantErr { + t.Fatalf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + if diff := pretty.Compare(tt.want, got); diff != "" { + t.Errorf("unexpected result: %s", diff) + } + } + }) + } +} + func TestNestedGroups(t *testing.T) { c := &Config{} c.UserSearch.BaseDN = "ou=People,ou=TestNestedGroups,dc=example,dc=org" c.UserSearch.NameAttr = "cn" c.UserSearch.EmailAttr = "mail" c.UserSearch.IDAttr = "DN" - c.UserSearch.Username = "cn" + c.UserSearch.Username = UsernameAttributes{"cn"} c.GroupSearch.BaseDN = "ou=TestNestedGroups,dc=example,dc=org" c.GroupSearch.UserMatchers = []UserMatcher{