diff --git a/connector/google/google.go b/connector/google/google.go index e17ec5bd..b14ed2a5 100644 --- a/connector/google/google.go +++ b/connector/google/google.go @@ -17,6 +17,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" + cloudidentity "google.golang.org/api/cloudidentity/v1" "google.golang.org/api/impersonate" "google.golang.org/api/option" @@ -160,6 +161,7 @@ type googleConnector struct { domainToAdminEmail map[string]string fetchTransitiveGroupMembership bool adminSrv map[string]*admin.Service + cloudIdentitySrv *cloudidentity.Service promptType string } @@ -261,19 +263,15 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector } } - var groups []string - if s.Groups && len(c.adminSrv) > 0 { - checkedGroups := make(map[string]struct{}) - groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups) - if err != nil { - return identity, fmt.Errorf("google: could not retrieve groups: %v", err) - } + groups, err := c.getGroups(ctx, s, token, claims.Email, c.fetchTransitiveGroupMembership) + if err != nil { + return identity, fmt.Errorf("google: could not retrieve groups: %v", err) + } - if len(c.groups) > 0 { - groups = pkg_groups.Filter(groups, c.groups) - if len(groups) == 0 { - return identity, fmt.Errorf("google: user %q is not in any of the required groups", claims.Username) - } + if len(c.groups) > 0 { + groups = pkg_groups.Filter(groups, c.groups) + if len(groups) == 0 { + return identity, fmt.Errorf("google: user %q is not in any of the required groups", claims.Username) } } @@ -288,49 +286,131 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector return identity, nil } -// getGroups creates a connection to the admin directory service and lists -// all groups the user is a member of -func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) { - var userGroups []string - var err error - groupsList := &admin.Groups{} - domain := c.extractDomainFromEmail(email) - adminSrv, err := c.findAdminService(domain) +// getGroups makes a best effort attempt to retrieve groups the provided email is a member of. +func (c *googleConnector) getGroups(ctx context.Context, s connector.Scopes, token *oauth2.Token, email string, fetchTransitiveGroupMembership bool) ([]string, error) { + if c.hasCloudIdentityGroupScope() { + return c.getGroupsUsingUser(ctx, token, email, fetchTransitiveGroupMembership) + } + if s.Groups && c.adminSrv != nil { + return c.getGroupsUsingAdmin(ctx, email, fetchTransitiveGroupMembership) + } + return nil, nil +} + +// hasCloudIdentityGroupScope returns true if the provided scopes contain the required scope +// to retrieve groups from the Cloud Identity API. +func (c *googleConnector) hasCloudIdentityGroupScope() bool { + return slices.ContainsFunc(c.oauth2Config.Scopes, func(s string) bool { + return s == cloudidentity.CloudPlatformScope || s == cloudidentity.CloudIdentityGroupsScope || s == cloudidentity.CloudIdentityGroupsReadonlyScope + }) +} + +// getGroupsUsingUser uses the Cloud Identity API to retrieve groups the provided email is a member +// of. This method returns groups from direct and transitive memberships. +// +// In contrast to getGroupsUsingAdmin, this method does NOT require the authenticated client +// to have been granted domain-wide delegation. Instead, it relies on the OAuth access token +// having the appropriate permissions from requesting the Cloud Identity API scope +// (see hasCloudIdentityGroupScope). +func (c *googleConnector) getGroupsUsingUser(ctx context.Context, token *oauth2.Token, email string, fetchTransitiveGroupMembership bool) ([]string, error) { + svc, err := c.createCloudIdentityService(ctx, token) if err != nil { return nil, err } - for { - groupsList, err = adminSrv.Groups.List(). - UserKey(email).PageToken(groupsList.NextPageToken).Do() - if err != nil { - return nil, fmt.Errorf("could not list groups: %v", err) - } - - for _, group := range groupsList.Groups { - if _, exists := checkedGroups[group.Email]; exists { - continue + var userGroups []string + resp := &cloudidentity.SearchDirectGroupsResponse{} + + checkedGroups := make(map[string]struct{}) + stack := []string{email} + for len(stack) > 0 { + n := len(stack) - 1 + curEmail := stack[n] + stack = stack[:n] + + query := fmt.Sprintf("member_key_id=='%s'", curEmail) + for { + resp, err = svc.Groups.Memberships.SearchDirectGroups("groups/-"). + Context(ctx).Query(query).PageToken(resp.NextPageToken).Do() + if err != nil { + return nil, fmt.Errorf("could not list groups: %v", err) } - checkedGroups[group.Email] = struct{}{} - // TODO (joelspeed): Make desired group key configurable - userGroups = append(userGroups, group.Email) + for _, m := range resp.Memberships { + group := m.GroupKey.Id + if _, exists := checkedGroups[group]; exists { + continue + } + checkedGroups[group] = struct{}{} + userGroups = append(userGroups, group) + if !fetchTransitiveGroupMembership { + continue + } + stack = append(stack, group) + } - if !fetchTransitiveGroupMembership { - continue + if resp.NextPageToken == "" { + break } + } + } + + return userGroups, nil +} + +// createCloudIdentityService is a small helper useful for testing by allowing to override +// the cloud identity service. +func (c *googleConnector) createCloudIdentityService(ctx context.Context, token *oauth2.Token) (*cloudidentity.Service, error) { + if c.cloudIdentitySrv != nil { + return c.cloudIdentitySrv, nil + } + return cloudidentity.NewService(ctx, option.WithHTTPClient(c.oauth2Config.Client(ctx, token))) +} + +// getGroupsUsingAdmin uses the Admin SDK API to retrieve groups the provided email is a member of. +// +// This method requires the authenticated client to have been granted domain-wide delegation. +func (c *googleConnector) getGroupsUsingAdmin(ctx context.Context, email string, fetchTransitiveGroupMembership bool) ([]string, error) { + domain := c.extractDomainFromEmail(email) + adminSrv, err := c.findAdminService(domain) + if err != nil { + return nil, err + } - // getGroups takes a user's email/alias as well as a group's email/alias - transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups) + var userGroups []string + resp := &admin.Groups{} + + checkedGroups := make(map[string]struct{}) + stack := []string{email} + for len(stack) > 0 { + n := len(stack) - 1 + curEmail := stack[n] + stack = stack[:n] + + for { + resp, err = adminSrv.Groups.List(). + Context(ctx).UserKey(curEmail).PageToken(resp.NextPageToken).Do() if err != nil { - return nil, fmt.Errorf("could not list transitive groups: %v", err) + return nil, fmt.Errorf("could not list groups: %v", err) } - userGroups = append(userGroups, transitiveGroups...) - } + for _, g := range resp.Groups { + group := g.Email + if _, exists := checkedGroups[group]; exists { + continue + } + checkedGroups[group] = struct{}{} + // TODO (joelspeed): Make desired group key configurable + userGroups = append(userGroups, group) + if !fetchTransitiveGroupMembership { + continue + } + stack = append(stack, group) + } - if groupsList.NextPageToken == "" { - break + if resp.NextPageToken == "" { + break + } } } diff --git a/connector/google/google_test.go b/connector/google/google_test.go index 8cc79739..46f77ff7 100644 --- a/connector/google/google_test.go +++ b/connector/google/google_test.go @@ -1,7 +1,6 @@ package google import ( - "context" "encoding/json" "fmt" "log/slog" @@ -9,11 +8,14 @@ import ( "net/http/httptest" "net/url" "os" + "regexp" "strings" "testing" "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" admin "google.golang.org/api/admin/directory/v1" + cloudidentity "google.golang.org/api/cloudidentity/v1" "google.golang.org/api/option" "github.com/dexidp/dex/connector" @@ -25,26 +27,65 @@ var ( // groups_2 groups_1 // │ ├────────┐ // └── user_1 user_2 - testGroups = map[string][]*admin.Group{ + adminTestGroups = map[string][]*admin.Group{ "user_1@dexidp.com": {{Email: "groups_2@dexidp.com"}, {Email: "groups_1@dexidp.com"}}, "user_2@dexidp.com": {{Email: "groups_1@dexidp.com"}}, "groups_1@dexidp.com": {{Email: "groups_0@dexidp.com"}}, "groups_2@dexidp.com": {{Email: "groups_0@dexidp.com"}}, "groups_0@dexidp.com": {}, } - callCounter = make(map[string]int) + adminCallCounter = make(map[string]int) + + cloudIdentityTestGroups = adminGroupsToMemberships(adminTestGroups) + cloudIdentityCallCounter = make(map[string]int) + cloudIdentityMemberKeyIdRE = regexp.MustCompile("member_key_id=='([^']+)'") ) +func adminGroupsToMemberships(testGroups map[string][]*admin.Group) map[string][]*cloudidentity.MembershipRelation { + mapped := make(map[string][]*cloudidentity.MembershipRelation) + for key, groups := range testGroups { + mapped[key] = make([]*cloudidentity.MembershipRelation, 0, len(groups)) + for _, g := range groups { + mapped[key] = append(mapped[key], &cloudidentity.MembershipRelation{ + GroupKey: &cloudidentity.EntityKey{Id: g.Email}, + }) + } + } + return mapped +} + func testSetup() *httptest.Server { mux := http.NewServeMux() + // https://developers.google.com/workspace/admin/directory/reference/rest mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") + resp := admin.Groups{} + userKey := r.URL.Query().Get("userKey") - if groups, ok := testGroups[userKey]; ok { - json.NewEncoder(w).Encode(admin.Groups{Groups: groups}) - callCounter[userKey]++ + adminCallCounter[userKey]++ + if groups, ok := adminTestGroups[userKey]; ok { + resp.Groups = groups } + + w.Header().Add("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + }) + // https://cloud.google.com/identity/docs/reference/rest + mux.HandleFunc("/v1/groups/-/memberships:searchDirectGroups", func(w http.ResponseWriter, r *http.Request) { + resp := cloudidentity.SearchDirectGroupsResponse{} + + query := r.URL.Query().Get("query") + memberKeyIdValue := cloudIdentityMemberKeyIdRE.FindSubmatch([]byte(query)) + if len(memberKeyIdValue) > 1 { + userKey := string(memberKeyIdValue[1]) + cloudIdentityCallCounter[userKey]++ + if memberships, ok := cloudIdentityTestGroups[userKey]; ok { + resp.Memberships = memberships + } + } + + w.Header().Add("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) }) return httptest.NewServer(mux) @@ -171,7 +212,65 @@ func TestOpen(t *testing.T) { } } -func TestGetGroups(t *testing.T) { +func TestGetGroupsUsingUser(t *testing.T) { + ts := testSetup() + defer ts.Close() + + conn, err := newConnector(&Config{ + ClientID: "testClient", + ClientSecret: "testSecret", + RedirectURI: ts.URL + "/callback", + Scopes: []string{"openid", "https://www.googleapis.com/auth/cloud-identity.groups.readonly"}, + }) + assert.Nil(t, err) + + conn.cloudIdentitySrv, err = cloudidentity.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) + assert.Nil(t, err) + type testCase struct { + userKey string + fetchTransitiveGroupMembership bool + expectedGroups []string + } + + for name, testCase := range map[string]testCase{ + "user1_non_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: false, + expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user1_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: true, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user2_non_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: false, + expectedGroups: []string{"groups_1@dexidp.com"}, + }, + "user2_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: true, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"}, + }, + "user_not_found": { + userKey: "user_3@dexidp.com", + }, + } { + testCase := testCase + cloudIdentityCallCounter = map[string]int{} + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + groups, err := conn.getGroupsUsingUser(t.Context(), &oauth2.Token{}, testCase.userKey, testCase.fetchTransitiveGroupMembership) + assert.Nil(err) + assert.ElementsMatch(testCase.expectedGroups, groups) + t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), cloudIdentityCallCounter) + }) + } +} + +func TestGetGroupsUsingAdmin(t *testing.T) { ts := testSetup() defer ts.Close() @@ -188,12 +287,11 @@ func TestGetGroups(t *testing.T) { }) assert.Nil(t, err) - conn.adminSrv[wildcardDomainToAdminEmail], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) + conn.adminSrv[wildcardDomainToAdminEmail], err = admin.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) assert.Nil(t, err) type testCase struct { userKey string fetchTransitiveGroupMembership bool - shouldErr bool expectedGroups []string } @@ -201,42 +299,33 @@ func TestGetGroups(t *testing.T) { "user1_non_transitive_lookup": { userKey: "user_1@dexidp.com", fetchTransitiveGroupMembership: false, - shouldErr: false, expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"}, }, "user1_transitive_lookup": { userKey: "user_1@dexidp.com", fetchTransitiveGroupMembership: true, - shouldErr: false, expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"}, }, "user2_non_transitive_lookup": { userKey: "user_2@dexidp.com", fetchTransitiveGroupMembership: false, - shouldErr: false, expectedGroups: []string{"groups_1@dexidp.com"}, }, "user2_transitive_lookup": { userKey: "user_2@dexidp.com", fetchTransitiveGroupMembership: true, - shouldErr: false, expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"}, }, } { testCase := testCase - callCounter = map[string]int{} + adminCallCounter = map[string]int{} t.Run(name, func(t *testing.T) { assert := assert.New(t) - lookup := make(map[string]struct{}) - groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) - if testCase.shouldErr { - assert.NotNil(err) - } else { - assert.Nil(err) - } + groups, err := conn.getGroupsUsingAdmin(t.Context(), testCase.userKey, testCase.fetchTransitiveGroupMembership) + assert.Nil(err) assert.ElementsMatch(testCase.expectedGroups, groups) - t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter) + t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), adminCallCounter) }) } } @@ -258,7 +347,7 @@ func TestDomainToAdminEmailConfig(t *testing.T) { }) assert.Nil(t, err) - conn.adminSrv["dexidp.com"], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) + conn.adminSrv["dexidp.com"], err = admin.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) assert.Nil(t, err) type testCase struct { userKey string @@ -280,18 +369,17 @@ func TestDomainToAdminEmailConfig(t *testing.T) { }, } { testCase := testCase - callCounter = map[string]int{} + adminCallCounter = map[string]int{} t.Run(name, func(t *testing.T) { assert := assert.New(t) - lookup := make(map[string]struct{}) - _, err := conn.getGroups(testCase.userKey, true, lookup) + _, err := conn.getGroupsUsingAdmin(t.Context(), testCase.userKey, true) if testCase.expectedErr != "" { assert.ErrorContains(err, testCase.expectedErr) } else { assert.Nil(err) } - t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter) + t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), adminCallCounter) }) } } @@ -337,6 +425,7 @@ func TestGCEWorkloadIdentity(t *testing.T) { os.Setenv("GCE_METADATA_HOST", metadataServerHost) os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", "") os.Setenv("HOME", "/tmp") + os.Setenv("APPDATA", "/tmp") gceMetadataFlags["failOnEmailRequest"] = true _, err := newConnector(&Config{ @@ -358,7 +447,7 @@ func TestGCEWorkloadIdentity(t *testing.T) { }) assert.Nil(t, err) - conn.adminSrv["dexidp.com"], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) + conn.adminSrv["dexidp.com"], err = admin.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) assert.Nil(t, err) type testCase struct { userKey string @@ -381,9 +470,8 @@ func TestGCEWorkloadIdentity(t *testing.T) { } { t.Run(name, func(t *testing.T) { assert := assert.New(t) - lookup := make(map[string]struct{}) - _, err := conn.getGroups(testCase.userKey, true, lookup) + _, err := conn.getGroupsUsingAdmin(t.Context(), testCase.userKey, true) if testCase.expectedErr != "" { assert.ErrorContains(err, testCase.expectedErr) } else { diff --git a/go.mod b/go.mod index 2c37d08c..9d52ac5c 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( go.etcd.io/etcd/client/pkg/v3 v3.6.7 go.etcd.io/etcd/client/v3 v3.6.7 golang.org/x/crypto v0.46.0 - golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741 + golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 golang.org/x/net v0.48.0 golang.org/x/oauth2 v0.34.0 google.golang.org/api v0.257.0 diff --git a/go.sum b/go.sum index ad306a9e..c5e3e043 100644 --- a/go.sum +++ b/go.sum @@ -233,8 +233,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741 h1:fGZugkZk2UgYBxtpKmvub51Yno1LJDeEsRp2xGD+0gY= -golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= +golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI= +golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=