diff --git a/connector/google/google.go b/connector/google/google.go index 4a8599c0..3e2c181b 100644 --- a/connector/google/google.go +++ b/connector/google/google.go @@ -17,6 +17,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" + "google.golang.org/api/cloudidentity/v1" "google.golang.org/api/impersonate" "google.golang.org/api/option" @@ -53,11 +54,17 @@ type Config struct { // Deprecated: Use DomainToAdminEmail AdminEmail string - // Required if ServiceAccountFilePath + // If ServiceAccountFilePath is set, this value is ignored if UseCloudIdentityAPI is set. Otherwise, it's required. // The map workspace domain to email of a GSuite super user which the service account will impersonate // when listing groups DomainToAdminEmail map[string]string + // If set, Cloud Identity API is used to fetch groups for a user. In particular, no user impersonation takes place. + // If ServiceAccountFilePath is not set, Application Default Credentials will be used. Otherwise, credentials will + // be generated from the file placed at the specified path. + // Defaults to false. + UseCloudIdentityAPI bool `json:"useCloudIdentityAPI"` + // If this field is true, fetch direct group membership and transitive group membership FetchTransitiveGroupMembership bool `json:"fetchTransitiveGroupMembership"` @@ -66,6 +73,14 @@ type Config struct { PromptType *string `json:"promptType"` } +func validateConfigForCloudIdentity(c *Config, logger *slog.Logger) error { + if len(c.DomainToAdminEmail) > 0 || len(c.AdminEmail) > 0 { + logger.Warn("For cloud identity calls \"DomainToAdminEmail\" and \"AdminEmail\" are ignored. It's safe to remove both configuration options.") + } + + return nil +} + // Open returns a connector which can be used to login users through Google. func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) { logger = logger.With(slog.Group("connector", "type", "google", "id", id)) @@ -94,22 +109,38 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, adminSrv := make(map[string]*admin.Service) - // We know impersonation is required when using a service account credential - // TODO: or is it? - if len(c.DomainToAdminEmail) == 0 && c.ServiceAccountFilePath != "" { - cancel() - return nil, fmt.Errorf("directory service requires the domainToAdminEmail option to be configured") - } + var groupsMembershipsService *cloudidentity.GroupsMembershipsService - if (len(c.DomainToAdminEmail) > 0) || slices.Contains(scopes, "groups") { - for domain, adminEmail := range c.DomainToAdminEmail { - srv, err := createDirectoryService(c.ServiceAccountFilePath, adminEmail, logger) - if err != nil { - cancel() - return nil, fmt.Errorf("could not create directory service: %v", err) - } + if c.UseCloudIdentityAPI { + err = validateConfigForCloudIdentity(c, logger) + if err != nil { + cancel() + return nil, err + } - adminSrv[domain] = srv + groupsMembershipsService, err = createGroupsMembershipsService(c.ServiceAccountFilePath, logger) + if err != nil { + cancel() + return nil, fmt.Errorf("could not create groups memebership service: %v", err) + } + } else { + // We know impersonation is required when using a service account credential + // TODO: or is it? + if len(c.DomainToAdminEmail) == 0 && c.ServiceAccountFilePath != "" { + cancel() + return nil, fmt.Errorf("directory service requires the domainToAdminEmail option to be configured") + } + + if (len(c.DomainToAdminEmail) > 0) || slices.Contains(scopes, "groups") { + for domain, adminEmail := range c.DomainToAdminEmail { + srv, err := createDirectoryService(c.ServiceAccountFilePath, adminEmail, logger) + if err != nil { + cancel() + return nil, fmt.Errorf("could not create directory service: %v", err) + } + + adminSrv[domain] = srv + } } } @@ -137,8 +168,10 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, groups: c.Groups, serviceAccountFilePath: c.ServiceAccountFilePath, domainToAdminEmail: c.DomainToAdminEmail, + useCloudIdentityAPI: c.UseCloudIdentityAPI, fetchTransitiveGroupMembership: c.FetchTransitiveGroupMembership, adminSrv: adminSrv, + groupsMembershipsService: groupsMembershipsService, promptType: promptType, }, nil } @@ -158,8 +191,10 @@ type googleConnector struct { groups []string serviceAccountFilePath string domainToAdminEmail map[string]string + useCloudIdentityAPI bool fetchTransitiveGroupMembership bool adminSrv map[string]*admin.Service + groupsMembershipsService *cloudidentity.GroupsMembershipsService promptType string } @@ -262,11 +297,22 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector } var groups []string - if s.Groups && len(c.adminSrv) > 0 { + if s.Groups { checkedGroups := make(map[string]struct{}) - groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups) - if err != nil { - return identity, fmt.Errorf("google: could not retrieve groups: %v", err) + if c.useCloudIdentityAPI { + if c.groupsMembershipsService != nil { + groups, err = c.getGroupsFromCloudIdentityAPI(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups) + if err != nil { + return identity, fmt.Errorf("google: could not retrieve groups from Cloud Identity API: %v", err) + } + } + } else { + if len(c.adminSrv) > 0 { + groups, err = c.getGroupsFromAdminAPI(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups) + if err != nil { + return identity, fmt.Errorf("google: could not retrieve groups form Admin API: %v", err) + } + } } if len(c.groups) > 0 { @@ -288,9 +334,9 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector return identity, nil } -// getGroups creates a connection to the admin directory service and lists +// getGroupsFromAdminAPI creates a connection to the admin directory service and lists // all groups the user is a member of -func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) { +func (c *googleConnector) getGroupsFromAdminAPI(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) { var userGroups []string var err error groupsList := &admin.Groups{} @@ -321,7 +367,53 @@ func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership } // getGroups takes a user's email/alias as well as a group's email/alias - transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups) + transitiveGroups, err := c.getGroupsFromAdminAPI(group.Email, fetchTransitiveGroupMembership, checkedGroups) + if err != nil { + return nil, fmt.Errorf("could not list transitive groups: %v", err) + } + + userGroups = append(userGroups, transitiveGroups...) + } + + if groupsList.NextPageToken == "" { + break + } + } + + return userGroups, nil +} + +// getGroupsFromCloudIdentityAPI creates a connection to the cloud identity service and lists +// all groups the user is a member of +func (c *googleConnector) getGroupsFromCloudIdentityAPI(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) { + var userGroups []string + var err error + groupsList := &cloudidentity.SearchDirectGroupsResponse{} + groupsMembershipService := c.groupsMembershipsService + + for { + query := fmt.Sprintf("member_key_id=='%s'", email) + groupsList, err = groupsMembershipService.SearchDirectGroups("groups/-"). + Query(query).PageToken(groupsList.NextPageToken).Do() + if err != nil { + return nil, fmt.Errorf("could not list groups: %v", err) + } + + for _, membership := range groupsList.Memberships { + groupEmail := strings.ToLower(membership.GroupKey.Id) + if _, exists := checkedGroups[groupEmail]; exists { + continue + } + + checkedGroups[groupEmail] = struct{}{} + // TODO (joelspeed): Make desired group key configurable + userGroups = append(userGroups, groupEmail) + + if !fetchTransitiveGroupMembership { + continue + } + + transitiveGroups, err := c.getGroupsFromCloudIdentityAPI(groupEmail, fetchTransitiveGroupMembership, checkedGroups) if err != nil { return nil, fmt.Errorf("could not list transitive groups: %v", err) } @@ -455,3 +547,37 @@ func createDirectoryService(serviceAccountFilePath, email string, logger *slog.L return admin.NewService(ctx, option.WithHTTPClient(config.Client(ctx))) } + +func createGroupsMembershipsService(serviceAccountFilePath string, logger *slog.Logger) (service *cloudidentity.GroupsMembershipsService, err error) { + ctx := context.Background() + var credentials *google.Credentials + + var cloudIdentityService *cloudidentity.Service + + if serviceAccountFilePath == "" { + logger.Info("Using Application Default Credentials") + cloudIdentityService, err = cloudidentity.NewService(ctx, option.WithScopes(cloudidentity.CloudIdentityGroupsReadonlyScope)) + if err != nil { + return nil, fmt.Errorf("error creating cloud identity service: %v", err) + } + } else { + logger.Info("Using credentials file at", "sa_path", serviceAccountFilePath) + + jsonCredentials, err := os.ReadFile(serviceAccountFilePath) + if err != nil { + return nil, fmt.Errorf("error reading credentials from file: %v", err) + } + + credentials, err = google.CredentialsFromJSON(ctx, jsonCredentials, cloudidentity.CloudIdentityGroupsReadonlyScope) + if err != nil { + return nil, fmt.Errorf("failed creating credentials from file: %w", err) + } + + cloudIdentityService, err = cloudidentity.NewService(ctx, option.WithCredentials(credentials)) + if err != nil { + return nil, fmt.Errorf("error creating cloud identity service: %v", err) + } + } + + return cloudidentity.NewGroupsMembershipsService(cloudIdentityService), nil +} diff --git a/connector/google/google_test.go b/connector/google/google_test.go index ce0e017c..39284bf1 100644 --- a/connector/google/google_test.go +++ b/connector/google/google_test.go @@ -9,11 +9,13 @@ import ( "net/http/httptest" "net/url" "os" + "regexp" "strings" "testing" "github.com/stretchr/testify/assert" admin "google.golang.org/api/admin/directory/v1" + "google.golang.org/api/cloudidentity/v1" "google.golang.org/api/option" "github.com/dexidp/dex/connector" @@ -37,6 +39,7 @@ var ( func testSetup() *httptest.Server { mux := http.NewServeMux() + re := regexp.MustCompile(`^member_key_id\s*==\s*'(.+)'$`) mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "application/json") @@ -47,6 +50,27 @@ func testSetup() *httptest.Server { } }) + mux.HandleFunc("/v1/groups/-/memberships:searchDirectGroups", func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + query := r.URL.Query().Get("query") + userKey := re.FindStringSubmatch(query)[1] + if groups, ok := testGroups[userKey]; ok { + var memberships []*cloudidentity.MembershipRelation + + for _, group := range groups { + memberships = append( + memberships, + &cloudidentity.MembershipRelation{ + GroupKey: &cloudidentity.EntityKey{Id: group.Email}, + }, + ) + } + + json.NewEncoder(w).Encode(cloudidentity.SearchDirectGroupsResponse{Memberships: memberships}) + callCounter[userKey]++ + } + }) + return httptest.NewServer(mux) } @@ -75,12 +99,35 @@ func tempServiceAccountKey() (string, error) { "project_id": "sample-project", "private_key_id": "sample-key-id", "private_key": "-----BEGIN PRIVATE KEY-----\nsample-key\n-----END PRIVATE KEY-----\n", + "client_email": "service-account@example.com", "client_id": "sample-client-id", "client_x509_cert_url": "localhost", }) return fd.Name(), err } +func tempWorkloadIdentityFederation() (string, error) { + fd, err := os.CreateTemp("", "workload_identity_federation") + if err != nil { + return "", err + } + defer fd.Close() + err = json.NewEncoder(fd).Encode(map[string]any{ + "type": "external_account", + "audience": "//iam.googleapis.com/projects/111111111111/locations/global/workloadIdentityPools/aws-pool/providers/aws-provider", + "subject_token_type": "urn:ietf:params:aws:token-type:aws4_request", + "service_account_impersonation_url": "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/sa@my-google-project.iam.gserviceaccount.com:generateAccessToken", + "token_url": "https://sts.googleapis.com/v1/token", + "credential_source": map[string]string{ + "environment_id": "aws1", + "region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone", + "url": "http://169.254.169.254/latest/meta-data/iam/security-credentials", + "regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + }, + }) + return fd.Name(), err +} + func TestOpen(t *testing.T) { ts := testSetup() defer ts.Close() @@ -229,7 +276,149 @@ func TestGetGroups(t *testing.T) { assert := assert.New(t) lookup := make(map[string]struct{}) - groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) + groups, err := conn.getGroupsFromAdminAPI(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) + if testCase.shouldErr { + assert.NotNil(err) + } else { + assert.Nil(err) + } + assert.ElementsMatch(testCase.expectedGroups, groups) + t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter) + }) + } +} + +func TestGetGroupsWithCloudIdentityApi(t *testing.T) { + ts := testSetup() + defer ts.Close() + + serviceAccountFilePath, err := tempServiceAccountKey() + assert.Nil(t, err) + + os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath) + conn, err := newConnector(&Config{ + ClientID: "testClient", + ClientSecret: "testSecret", + RedirectURI: ts.URL + "/callback", + Scopes: []string{"openid", "groups"}, + UseCloudIdentityAPI: true, + }) + assert.Nil(t, err) + + cloudIdentityService, err := cloudidentity.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) + assert.Nil(t, err) + conn.groupsMembershipsService = cloudidentity.NewGroupsMembershipsService(cloudIdentityService) + type testCase struct { + userKey string + fetchTransitiveGroupMembership bool + shouldErr bool + expectedGroups []string + } + + for name, testCase := range map[string]testCase{ + "user1_non_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: false, + shouldErr: false, + expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user1_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: true, + shouldErr: false, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user2_non_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: false, + shouldErr: false, + expectedGroups: []string{"groups_1@dexidp.com"}, + }, + "user2_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: true, + shouldErr: false, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"}, + }, + } { + testCase := testCase + callCounter = map[string]int{} + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + lookup := make(map[string]struct{}) + + groups, err := conn.getGroupsFromCloudIdentityAPI(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) + if testCase.shouldErr { + assert.NotNil(err) + } else { + assert.Nil(err) + } + assert.ElementsMatch(testCase.expectedGroups, groups) + t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter) + }) + } +} + +func TestGetGroupsWithCloudIdentityApiAndWorkloadIdentityFederation(t *testing.T) { + ts := testSetup() + defer ts.Close() + + workloadIdentityFederationFilePath, err := tempWorkloadIdentityFederation() + assert.Nil(t, err) + + os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", workloadIdentityFederationFilePath) + conn, err := newConnector(&Config{ + ClientID: "testClient", + ClientSecret: "testSecret", + RedirectURI: ts.URL + "/callback", + Scopes: []string{"openid", "groups"}, + UseCloudIdentityAPI: true, + }) + assert.Nil(t, err) + + cloudIdentityService, err := cloudidentity.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) + assert.Nil(t, err) + conn.groupsMembershipsService = cloudidentity.NewGroupsMembershipsService(cloudIdentityService) + type testCase struct { + userKey string + fetchTransitiveGroupMembership bool + shouldErr bool + expectedGroups []string + } + + for name, testCase := range map[string]testCase{ + "user1_non_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: false, + shouldErr: false, + expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user1_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: true, + shouldErr: false, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user2_non_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: false, + shouldErr: false, + expectedGroups: []string{"groups_1@dexidp.com"}, + }, + "user2_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: true, + shouldErr: false, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"}, + }, + } { + testCase := testCase + callCounter = map[string]int{} + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + lookup := make(map[string]struct{}) + + groups, err := conn.getGroupsFromCloudIdentityAPI(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) if testCase.shouldErr { assert.NotNil(err) } else { @@ -285,7 +474,7 @@ func TestDomainToAdminEmailConfig(t *testing.T) { assert := assert.New(t) lookup := make(map[string]struct{}) - _, err := conn.getGroups(testCase.userKey, true, lookup) + _, err := conn.getGroupsFromAdminAPI(testCase.userKey, true, lookup) if testCase.expectedErr != "" { assert.ErrorContains(err, testCase.expectedErr) } else { @@ -383,7 +572,7 @@ func TestGCEWorkloadIdentity(t *testing.T) { assert := assert.New(t) lookup := make(map[string]struct{}) - _, err := conn.getGroups(testCase.userKey, true, lookup) + _, err := conn.getGroupsFromAdminAPI(testCase.userKey, true, lookup) if testCase.expectedErr != "" { assert.ErrorContains(err, testCase.expectedErr) } else {