|
|
|
|
@ -1,7 +1,6 @@
|
|
|
|
|
package google |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"context" |
|
|
|
|
"encoding/json" |
|
|
|
|
"fmt" |
|
|
|
|
"log/slog" |
|
|
|
|
@ -9,11 +8,14 @@ import (
|
|
|
|
|
"net/http/httptest" |
|
|
|
|
"net/url" |
|
|
|
|
"os" |
|
|
|
|
"regexp" |
|
|
|
|
"strings" |
|
|
|
|
"testing" |
|
|
|
|
|
|
|
|
|
"github.com/stretchr/testify/assert" |
|
|
|
|
"golang.org/x/oauth2" |
|
|
|
|
admin "google.golang.org/api/admin/directory/v1" |
|
|
|
|
cloudidentity "google.golang.org/api/cloudidentity/v1" |
|
|
|
|
"google.golang.org/api/option" |
|
|
|
|
|
|
|
|
|
"github.com/dexidp/dex/connector" |
|
|
|
|
@ -25,26 +27,65 @@ var (
|
|
|
|
|
// groups_2 groups_1
|
|
|
|
|
// │ ├────────┐
|
|
|
|
|
// └── user_1 user_2
|
|
|
|
|
testGroups = map[string][]*admin.Group{ |
|
|
|
|
adminTestGroups = map[string][]*admin.Group{ |
|
|
|
|
"user_1@dexidp.com": {{Email: "groups_2@dexidp.com"}, {Email: "groups_1@dexidp.com"}}, |
|
|
|
|
"user_2@dexidp.com": {{Email: "groups_1@dexidp.com"}}, |
|
|
|
|
"groups_1@dexidp.com": {{Email: "groups_0@dexidp.com"}}, |
|
|
|
|
"groups_2@dexidp.com": {{Email: "groups_0@dexidp.com"}}, |
|
|
|
|
"groups_0@dexidp.com": {}, |
|
|
|
|
} |
|
|
|
|
callCounter = make(map[string]int) |
|
|
|
|
adminCallCounter = make(map[string]int) |
|
|
|
|
|
|
|
|
|
cloudIdentityTestGroups = adminGroupsToMemberships(adminTestGroups) |
|
|
|
|
cloudIdentityCallCounter = make(map[string]int) |
|
|
|
|
cloudIdentityMemberKeyIdRE = regexp.MustCompile("member_key_id=='([^']+)'") |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
func adminGroupsToMemberships(testGroups map[string][]*admin.Group) map[string][]*cloudidentity.MembershipRelation { |
|
|
|
|
mapped := make(map[string][]*cloudidentity.MembershipRelation) |
|
|
|
|
for key, groups := range testGroups { |
|
|
|
|
mapped[key] = make([]*cloudidentity.MembershipRelation, 0, len(groups)) |
|
|
|
|
for _, g := range groups { |
|
|
|
|
mapped[key] = append(mapped[key], &cloudidentity.MembershipRelation{ |
|
|
|
|
GroupKey: &cloudidentity.EntityKey{Id: g.Email}, |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return mapped |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func testSetup() *httptest.Server { |
|
|
|
|
mux := http.NewServeMux() |
|
|
|
|
|
|
|
|
|
// https://developers.google.com/workspace/admin/directory/reference/rest
|
|
|
|
|
mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
w.Header().Add("Content-Type", "application/json") |
|
|
|
|
resp := admin.Groups{} |
|
|
|
|
|
|
|
|
|
userKey := r.URL.Query().Get("userKey") |
|
|
|
|
if groups, ok := testGroups[userKey]; ok { |
|
|
|
|
json.NewEncoder(w).Encode(admin.Groups{Groups: groups}) |
|
|
|
|
callCounter[userKey]++ |
|
|
|
|
adminCallCounter[userKey]++ |
|
|
|
|
if groups, ok := adminTestGroups[userKey]; ok { |
|
|
|
|
resp.Groups = groups |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
w.Header().Add("Content-Type", "application/json") |
|
|
|
|
json.NewEncoder(w).Encode(resp) |
|
|
|
|
}) |
|
|
|
|
// https://cloud.google.com/identity/docs/reference/rest
|
|
|
|
|
mux.HandleFunc("/v1/groups/-/memberships:searchDirectGroups", func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
resp := cloudidentity.SearchDirectGroupsResponse{} |
|
|
|
|
|
|
|
|
|
query := r.URL.Query().Get("query") |
|
|
|
|
memberKeyIdValue := cloudIdentityMemberKeyIdRE.FindSubmatch([]byte(query)) |
|
|
|
|
if len(memberKeyIdValue) > 1 { |
|
|
|
|
userKey := string(memberKeyIdValue[1]) |
|
|
|
|
cloudIdentityCallCounter[userKey]++ |
|
|
|
|
if memberships, ok := cloudIdentityTestGroups[userKey]; ok { |
|
|
|
|
resp.Memberships = memberships |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
w.Header().Add("Content-Type", "application/json") |
|
|
|
|
json.NewEncoder(w).Encode(resp) |
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
return httptest.NewServer(mux) |
|
|
|
|
@ -171,7 +212,65 @@ func TestOpen(t *testing.T) {
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestGetGroups(t *testing.T) { |
|
|
|
|
func TestGetGroupsUsingUser(t *testing.T) { |
|
|
|
|
ts := testSetup() |
|
|
|
|
defer ts.Close() |
|
|
|
|
|
|
|
|
|
conn, err := newConnector(&Config{ |
|
|
|
|
ClientID: "testClient", |
|
|
|
|
ClientSecret: "testSecret", |
|
|
|
|
RedirectURI: ts.URL + "/callback", |
|
|
|
|
Scopes: []string{"openid", "https://www.googleapis.com/auth/cloud-identity.groups.readonly"}, |
|
|
|
|
}) |
|
|
|
|
assert.Nil(t, err) |
|
|
|
|
|
|
|
|
|
conn.cloudIdentitySrv, err = cloudidentity.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) |
|
|
|
|
assert.Nil(t, err) |
|
|
|
|
type testCase struct { |
|
|
|
|
userKey string |
|
|
|
|
fetchTransitiveGroupMembership bool |
|
|
|
|
expectedGroups []string |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for name, testCase := range map[string]testCase{ |
|
|
|
|
"user1_non_transitive_lookup": { |
|
|
|
|
userKey: "user_1@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: false, |
|
|
|
|
expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
"user1_transitive_lookup": { |
|
|
|
|
userKey: "user_1@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: true, |
|
|
|
|
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
"user2_non_transitive_lookup": { |
|
|
|
|
userKey: "user_2@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: false, |
|
|
|
|
expectedGroups: []string{"groups_1@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
"user2_transitive_lookup": { |
|
|
|
|
userKey: "user_2@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: true, |
|
|
|
|
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
"user_not_found": { |
|
|
|
|
userKey: "user_3@dexidp.com", |
|
|
|
|
}, |
|
|
|
|
} { |
|
|
|
|
testCase := testCase |
|
|
|
|
cloudIdentityCallCounter = map[string]int{} |
|
|
|
|
t.Run(name, func(t *testing.T) { |
|
|
|
|
assert := assert.New(t) |
|
|
|
|
|
|
|
|
|
groups, err := conn.getGroupsUsingUser(t.Context(), &oauth2.Token{}, testCase.userKey, testCase.fetchTransitiveGroupMembership) |
|
|
|
|
assert.Nil(err) |
|
|
|
|
assert.ElementsMatch(testCase.expectedGroups, groups) |
|
|
|
|
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), cloudIdentityCallCounter) |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestGetGroupsUsingAdmin(t *testing.T) { |
|
|
|
|
ts := testSetup() |
|
|
|
|
defer ts.Close() |
|
|
|
|
|
|
|
|
|
@ -188,12 +287,11 @@ func TestGetGroups(t *testing.T) {
|
|
|
|
|
}) |
|
|
|
|
assert.Nil(t, err) |
|
|
|
|
|
|
|
|
|
conn.adminSrv[wildcardDomainToAdminEmail], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) |
|
|
|
|
conn.adminSrv[wildcardDomainToAdminEmail], err = admin.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) |
|
|
|
|
assert.Nil(t, err) |
|
|
|
|
type testCase struct { |
|
|
|
|
userKey string |
|
|
|
|
fetchTransitiveGroupMembership bool |
|
|
|
|
shouldErr bool |
|
|
|
|
expectedGroups []string |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@ -201,42 +299,33 @@ func TestGetGroups(t *testing.T) {
|
|
|
|
|
"user1_non_transitive_lookup": { |
|
|
|
|
userKey: "user_1@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: false, |
|
|
|
|
shouldErr: false, |
|
|
|
|
expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
"user1_transitive_lookup": { |
|
|
|
|
userKey: "user_1@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: true, |
|
|
|
|
shouldErr: false, |
|
|
|
|
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
"user2_non_transitive_lookup": { |
|
|
|
|
userKey: "user_2@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: false, |
|
|
|
|
shouldErr: false, |
|
|
|
|
expectedGroups: []string{"groups_1@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
"user2_transitive_lookup": { |
|
|
|
|
userKey: "user_2@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: true, |
|
|
|
|
shouldErr: false, |
|
|
|
|
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
} { |
|
|
|
|
testCase := testCase |
|
|
|
|
callCounter = map[string]int{} |
|
|
|
|
adminCallCounter = map[string]int{} |
|
|
|
|
t.Run(name, func(t *testing.T) { |
|
|
|
|
assert := assert.New(t) |
|
|
|
|
lookup := make(map[string]struct{}) |
|
|
|
|
|
|
|
|
|
groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) |
|
|
|
|
if testCase.shouldErr { |
|
|
|
|
assert.NotNil(err) |
|
|
|
|
} else { |
|
|
|
|
assert.Nil(err) |
|
|
|
|
} |
|
|
|
|
groups, err := conn.getGroupsUsingAdmin(t.Context(), testCase.userKey, testCase.fetchTransitiveGroupMembership) |
|
|
|
|
assert.Nil(err) |
|
|
|
|
assert.ElementsMatch(testCase.expectedGroups, groups) |
|
|
|
|
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter) |
|
|
|
|
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), adminCallCounter) |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
@ -258,7 +347,7 @@ func TestDomainToAdminEmailConfig(t *testing.T) {
|
|
|
|
|
}) |
|
|
|
|
assert.Nil(t, err) |
|
|
|
|
|
|
|
|
|
conn.adminSrv["dexidp.com"], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) |
|
|
|
|
conn.adminSrv["dexidp.com"], err = admin.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) |
|
|
|
|
assert.Nil(t, err) |
|
|
|
|
type testCase struct { |
|
|
|
|
userKey string |
|
|
|
|
@ -280,18 +369,17 @@ func TestDomainToAdminEmailConfig(t *testing.T) {
|
|
|
|
|
}, |
|
|
|
|
} { |
|
|
|
|
testCase := testCase |
|
|
|
|
callCounter = map[string]int{} |
|
|
|
|
adminCallCounter = map[string]int{} |
|
|
|
|
t.Run(name, func(t *testing.T) { |
|
|
|
|
assert := assert.New(t) |
|
|
|
|
lookup := make(map[string]struct{}) |
|
|
|
|
|
|
|
|
|
_, err := conn.getGroups(testCase.userKey, true, lookup) |
|
|
|
|
_, err := conn.getGroupsUsingAdmin(t.Context(), testCase.userKey, true) |
|
|
|
|
if testCase.expectedErr != "" { |
|
|
|
|
assert.ErrorContains(err, testCase.expectedErr) |
|
|
|
|
} else { |
|
|
|
|
assert.Nil(err) |
|
|
|
|
} |
|
|
|
|
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter) |
|
|
|
|
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), adminCallCounter) |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
@ -337,6 +425,7 @@ func TestGCEWorkloadIdentity(t *testing.T) {
|
|
|
|
|
os.Setenv("GCE_METADATA_HOST", metadataServerHost) |
|
|
|
|
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", "") |
|
|
|
|
os.Setenv("HOME", "/tmp") |
|
|
|
|
os.Setenv("APPDATA", "/tmp") |
|
|
|
|
|
|
|
|
|
gceMetadataFlags["failOnEmailRequest"] = true |
|
|
|
|
_, err := newConnector(&Config{ |
|
|
|
|
@ -358,7 +447,7 @@ func TestGCEWorkloadIdentity(t *testing.T) {
|
|
|
|
|
}) |
|
|
|
|
assert.Nil(t, err) |
|
|
|
|
|
|
|
|
|
conn.adminSrv["dexidp.com"], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) |
|
|
|
|
conn.adminSrv["dexidp.com"], err = admin.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) |
|
|
|
|
assert.Nil(t, err) |
|
|
|
|
type testCase struct { |
|
|
|
|
userKey string |
|
|
|
|
@ -381,9 +470,8 @@ func TestGCEWorkloadIdentity(t *testing.T) {
|
|
|
|
|
} { |
|
|
|
|
t.Run(name, func(t *testing.T) { |
|
|
|
|
assert := assert.New(t) |
|
|
|
|
lookup := make(map[string]struct{}) |
|
|
|
|
|
|
|
|
|
_, err := conn.getGroups(testCase.userKey, true, lookup) |
|
|
|
|
_, err := conn.getGroupsUsingAdmin(t.Context(), testCase.userKey, true) |
|
|
|
|
if testCase.expectedErr != "" { |
|
|
|
|
assert.ErrorContains(err, testCase.expectedErr) |
|
|
|
|
} else { |
|
|
|
|
|