From 47538f8d085d9d1fbd3e446b7c7478a2f9a3ce85 Mon Sep 17 00:00:00 2001 From: Michael Dudzinski Date: Tue, 10 Oct 2023 14:36:19 +0200 Subject: [PATCH] feat: [Google connector] use Cloud Identity API for fetching groups A new Google connector option, `useCloudIdentityApi`, has been introduced. If the value is `true`, dex will use cloud identity api to fetch groups. In particular, no user impersonation happens. The logic to obtain the credentials is based on Application Default Credentials. Alternatively, the user is allowed to pass a path to a credentials JSON file using `serviceAccountFilePath`. In both cases, the principal described linked to the credentials requires group read rights. In case of a Service Account, a custom admin role with this right need to be created in Google Workspace, and the Service Account needs to be assigned to this role. Moreover, Workload Identity Federation is supported as Application Default Credentials supports this use case. Make sure to include `Service Account Token Creator` for the linked service account in that case. Signed-off-by: Michael Dudzinski --- connector/google/google.go | 170 ++++++++++++++++++++++++---- connector/google/google_test.go | 195 +++++++++++++++++++++++++++++++- 2 files changed, 340 insertions(+), 25 deletions(-) diff --git a/connector/google/google.go b/connector/google/google.go index e17ec5bd..d42698e5 100644 --- a/connector/google/google.go +++ b/connector/google/google.go @@ -17,6 +17,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" + "google.golang.org/api/cloudidentity/v1" "google.golang.org/api/impersonate" "google.golang.org/api/option" @@ -53,11 +54,17 @@ type Config struct { // Deprecated: Use DomainToAdminEmail AdminEmail string - // Required if ServiceAccountFilePath + // If ServiceAccountFilePath is set, this value is ignored if UseCloudIdentityAPI is set. Otherwise, it's required. // The map workspace domain to email of a GSuite super user which the service account will impersonate // when listing groups DomainToAdminEmail map[string]string + // If set, Cloud Identity API is used to fetch groups for a user. In particular, no user impersonation takes place. + // If ServiceAccountFilePath is not set, Application Default Credentials will be used. Otherwise, credentials will + // be generated from the file placed at the specified path. + // Defaults to false. + UseCloudIdentityAPI bool `json:"useCloudIdentityAPI"` + // If this field is true, fetch direct group membership and transitive group membership FetchTransitiveGroupMembership bool `json:"fetchTransitiveGroupMembership"` @@ -66,6 +73,14 @@ type Config struct { PromptType *string `json:"promptType"` } +func validateConfigForCloudIdentity(c *Config, logger *slog.Logger) error { + if len(c.DomainToAdminEmail) > 0 || len(c.AdminEmail) > 0 { + logger.Warn("For cloud identity calls \"DomainToAdminEmail\" and \"AdminEmail\" are ignored. It's safe to remove both configuration options.") + } + + return nil +} + // Open returns a connector which can be used to login users through Google. func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) { logger = logger.With(slog.Group("connector", "type", "google", "id", id)) @@ -94,22 +109,38 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, adminSrv := make(map[string]*admin.Service) - // We know impersonation is required when using a service account credential - // TODO: or is it? - if len(c.DomainToAdminEmail) == 0 && c.ServiceAccountFilePath != "" { - cancel() - return nil, fmt.Errorf("directory service requires the domainToAdminEmail option to be configured") - } + var groupsMembershipsService *cloudidentity.GroupsMembershipsService - if (len(c.DomainToAdminEmail) > 0) || slices.Contains(scopes, "groups") { - for domain, adminEmail := range c.DomainToAdminEmail { - srv, err := createDirectoryService(c.ServiceAccountFilePath, adminEmail, logger) - if err != nil { - cancel() - return nil, fmt.Errorf("could not create directory service: %v", err) - } + if c.UseCloudIdentityAPI { + err = validateConfigForCloudIdentity(c, logger) + if err != nil { + cancel() + return nil, err + } - adminSrv[domain] = srv + groupsMembershipsService, err = createGroupsMembershipsService(c.ServiceAccountFilePath, logger) + if err != nil { + cancel() + return nil, fmt.Errorf("could not create groups memebership service: %v", err) + } + } else { + // We know impersonation is required when using a service account credential + // TODO: or is it? + if len(c.DomainToAdminEmail) == 0 && c.ServiceAccountFilePath != "" { + cancel() + return nil, fmt.Errorf("directory service requires the domainToAdminEmail option to be configured") + } + + if (len(c.DomainToAdminEmail) > 0) || slices.Contains(scopes, "groups") { + for domain, adminEmail := range c.DomainToAdminEmail { + srv, err := createDirectoryService(c.ServiceAccountFilePath, adminEmail, logger) + if err != nil { + cancel() + return nil, fmt.Errorf("could not create directory service: %v", err) + } + + adminSrv[domain] = srv + } } } @@ -137,8 +168,10 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, groups: c.Groups, serviceAccountFilePath: c.ServiceAccountFilePath, domainToAdminEmail: c.DomainToAdminEmail, + useCloudIdentityAPI: c.UseCloudIdentityAPI, fetchTransitiveGroupMembership: c.FetchTransitiveGroupMembership, adminSrv: adminSrv, + groupsMembershipsService: groupsMembershipsService, promptType: promptType, }, nil } @@ -158,8 +191,10 @@ type googleConnector struct { groups []string serviceAccountFilePath string domainToAdminEmail map[string]string + useCloudIdentityAPI bool fetchTransitiveGroupMembership bool adminSrv map[string]*admin.Service + groupsMembershipsService *cloudidentity.GroupsMembershipsService promptType string } @@ -262,11 +297,22 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector } var groups []string - if s.Groups && len(c.adminSrv) > 0 { + if s.Groups { checkedGroups := make(map[string]struct{}) - groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups) - if err != nil { - return identity, fmt.Errorf("google: could not retrieve groups: %v", err) + if c.useCloudIdentityAPI { + if c.groupsMembershipsService != nil { + groups, err = c.getGroupsFromCloudIdentityAPI(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups) + if err != nil { + return identity, fmt.Errorf("google: could not retrieve groups from Cloud Identity API: %v", err) + } + } + } else { + if len(c.adminSrv) > 0 { + groups, err = c.getGroupsFromAdminAPI(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups) + if err != nil { + return identity, fmt.Errorf("google: could not retrieve groups form Admin API: %v", err) + } + } } if len(c.groups) > 0 { @@ -288,9 +334,9 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector return identity, nil } -// getGroups creates a connection to the admin directory service and lists +// getGroupsFromAdminAPI creates a connection to the admin directory service and lists // all groups the user is a member of -func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) { +func (c *googleConnector) getGroupsFromAdminAPI(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) { var userGroups []string var err error groupsList := &admin.Groups{} @@ -321,7 +367,53 @@ func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership } // getGroups takes a user's email/alias as well as a group's email/alias - transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups) + transitiveGroups, err := c.getGroupsFromAdminAPI(group.Email, fetchTransitiveGroupMembership, checkedGroups) + if err != nil { + return nil, fmt.Errorf("could not list transitive groups: %v", err) + } + + userGroups = append(userGroups, transitiveGroups...) + } + + if groupsList.NextPageToken == "" { + break + } + } + + return userGroups, nil +} + +// getGroupsFromCloudIdentityAPI creates a connection to the cloud identity service and lists +// all groups the user is a member of +func (c *googleConnector) getGroupsFromCloudIdentityAPI(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) { + var userGroups []string + var err error + groupsList := &cloudidentity.SearchDirectGroupsResponse{} + groupsMembershipService := c.groupsMembershipsService + + for { + query := fmt.Sprintf("member_key_id=='%s'", email) + groupsList, err = groupsMembershipService.SearchDirectGroups("groups/-"). + Query(query).PageToken(groupsList.NextPageToken).Do() + if err != nil { + return nil, fmt.Errorf("could not list groups: %v", err) + } + + for _, membership := range groupsList.Memberships { + groupEmail := strings.ToLower(membership.GroupKey.Id) + if _, exists := checkedGroups[groupEmail]; exists { + continue + } + + checkedGroups[groupEmail] = struct{}{} + // TODO (joelspeed): Make desired group key configurable + userGroups = append(userGroups, groupEmail) + + if !fetchTransitiveGroupMembership { + continue + } + + transitiveGroups, err := c.getGroupsFromCloudIdentityAPI(groupEmail, fetchTransitiveGroupMembership, checkedGroups) if err != nil { return nil, fmt.Errorf("could not list transitive groups: %v", err) } @@ -455,3 +547,37 @@ func createDirectoryService(serviceAccountFilePath, email string, logger *slog.L return admin.NewService(ctx, option.WithHTTPClient(config.Client(ctx))) } + +func createGroupsMembershipsService(serviceAccountFilePath string, logger *slog.Logger) (service *cloudidentity.GroupsMembershipsService, err error) { + ctx := context.Background() + var credentials *google.Credentials + + var cloudIdentityService *cloudidentity.Service + + if serviceAccountFilePath == "" { + logger.Info("Using Application Default Credentials") + cloudIdentityService, err = cloudidentity.NewService(ctx, option.WithScopes(cloudidentity.CloudIdentityGroupsReadonlyScope)) + if err != nil { + return nil, fmt.Errorf("error creating cloud identity service: %v", err) + } + } else { + logger.Info("Using credentials file at", "sa_path", serviceAccountFilePath) + + jsonCredentials, err := os.ReadFile(serviceAccountFilePath) + if err != nil { + return nil, fmt.Errorf("error reading credentials from file: %v", err) + } + + credentials, err = google.CredentialsFromJSON(ctx, jsonCredentials, cloudidentity.CloudIdentityGroupsReadonlyScope) + if err != nil { + return nil, fmt.Errorf("failed creating credentials from file: %w", err) + } + + cloudIdentityService, err = cloudidentity.NewService(ctx, option.WithCredentials(credentials)) + if err != nil { + return nil, fmt.Errorf("error creating cloud identity service: %v", err) + } + } + + return cloudidentity.NewGroupsMembershipsService(cloudIdentityService), nil +} diff --git a/connector/google/google_test.go b/connector/google/google_test.go index 8cc79739..bea30c4e 100644 --- a/connector/google/google_test.go +++ b/connector/google/google_test.go @@ -9,11 +9,13 @@ import ( "net/http/httptest" "net/url" "os" + "regexp" "strings" "testing" "github.com/stretchr/testify/assert" admin "google.golang.org/api/admin/directory/v1" + "google.golang.org/api/cloudidentity/v1" "google.golang.org/api/option" "github.com/dexidp/dex/connector" @@ -37,6 +39,7 @@ var ( func testSetup() *httptest.Server { mux := http.NewServeMux() + re := regexp.MustCompile(`^member_key_id\s*==\s*'(.+)'$`) mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "application/json") @@ -47,6 +50,27 @@ func testSetup() *httptest.Server { } }) + mux.HandleFunc("/v1/groups/-/memberships:searchDirectGroups", func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + query := r.URL.Query().Get("query") + userKey := re.FindStringSubmatch(query)[1] + if groups, ok := testGroups[userKey]; ok { + var memberships []*cloudidentity.MembershipRelation + + for _, group := range groups { + memberships = append( + memberships, + &cloudidentity.MembershipRelation{ + GroupKey: &cloudidentity.EntityKey{Id: group.Email}, + }, + ) + } + + json.NewEncoder(w).Encode(cloudidentity.SearchDirectGroupsResponse{Memberships: memberships}) + callCounter[userKey]++ + } + }) + return httptest.NewServer(mux) } @@ -75,12 +99,35 @@ func tempServiceAccountKey() (string, error) { "project_id": "sample-project", "private_key_id": "sample-key-id", "private_key": "-----BEGIN PRIVATE KEY-----\nsample-key\n-----END PRIVATE KEY-----\n", + "client_email": "service-account@example.com", "client_id": "sample-client-id", "client_x509_cert_url": "localhost", }) return fd.Name(), err } +func tempWorkloadIdentityFederation() (string, error) { + fd, err := os.CreateTemp("", "workload_identity_federation") + if err != nil { + return "", err + } + defer fd.Close() + err = json.NewEncoder(fd).Encode(map[string]any{ + "type": "external_account", + "audience": "//iam.googleapis.com/projects/111111111111/locations/global/workloadIdentityPools/aws-pool/providers/aws-provider", + "subject_token_type": "urn:ietf:params:aws:token-type:aws4_request", + "service_account_impersonation_url": "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/sa@my-google-project.iam.gserviceaccount.com:generateAccessToken", + "token_url": "https://sts.googleapis.com/v1/token", + "credential_source": map[string]string{ + "environment_id": "aws1", + "region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone", + "url": "http://169.254.169.254/latest/meta-data/iam/security-credentials", + "regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + }, + }) + return fd.Name(), err +} + func TestOpen(t *testing.T) { ts := testSetup() defer ts.Close() @@ -229,7 +276,149 @@ func TestGetGroups(t *testing.T) { assert := assert.New(t) lookup := make(map[string]struct{}) - groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) + groups, err := conn.getGroupsFromAdminAPI(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) + if testCase.shouldErr { + assert.NotNil(err) + } else { + assert.Nil(err) + } + assert.ElementsMatch(testCase.expectedGroups, groups) + t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter) + }) + } +} + +func TestGetGroupsWithCloudIdentityApi(t *testing.T) { + ts := testSetup() + defer ts.Close() + + serviceAccountFilePath, err := tempServiceAccountKey() + assert.Nil(t, err) + + os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath) + conn, err := newConnector(&Config{ + ClientID: "testClient", + ClientSecret: "testSecret", + RedirectURI: ts.URL + "/callback", + Scopes: []string{"openid", "groups"}, + UseCloudIdentityAPI: true, + }) + assert.Nil(t, err) + + cloudIdentityService, err := cloudidentity.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) + assert.Nil(t, err) + conn.groupsMembershipsService = cloudidentity.NewGroupsMembershipsService(cloudIdentityService) + type testCase struct { + userKey string + fetchTransitiveGroupMembership bool + shouldErr bool + expectedGroups []string + } + + for name, testCase := range map[string]testCase{ + "user1_non_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: false, + shouldErr: false, + expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user1_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: true, + shouldErr: false, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user2_non_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: false, + shouldErr: false, + expectedGroups: []string{"groups_1@dexidp.com"}, + }, + "user2_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: true, + shouldErr: false, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"}, + }, + } { + testCase := testCase + callCounter = map[string]int{} + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + lookup := make(map[string]struct{}) + + groups, err := conn.getGroupsFromCloudIdentityAPI(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) + if testCase.shouldErr { + assert.NotNil(err) + } else { + assert.Nil(err) + } + assert.ElementsMatch(testCase.expectedGroups, groups) + t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter) + }) + } +} + +func TestGetGroupsWithCloudIdentityApiAndWorkloadIdentityFederation(t *testing.T) { + ts := testSetup() + defer ts.Close() + + workloadIdentityFederationFilePath, err := tempWorkloadIdentityFederation() + assert.Nil(t, err) + + os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", workloadIdentityFederationFilePath) + conn, err := newConnector(&Config{ + ClientID: "testClient", + ClientSecret: "testSecret", + RedirectURI: ts.URL + "/callback", + Scopes: []string{"openid", "groups"}, + UseCloudIdentityAPI: true, + }) + assert.Nil(t, err) + + cloudIdentityService, err := cloudidentity.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) + assert.Nil(t, err) + conn.groupsMembershipsService = cloudidentity.NewGroupsMembershipsService(cloudIdentityService) + type testCase struct { + userKey string + fetchTransitiveGroupMembership bool + shouldErr bool + expectedGroups []string + } + + for name, testCase := range map[string]testCase{ + "user1_non_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: false, + shouldErr: false, + expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user1_transitive_lookup": { + userKey: "user_1@dexidp.com", + fetchTransitiveGroupMembership: true, + shouldErr: false, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"}, + }, + "user2_non_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: false, + shouldErr: false, + expectedGroups: []string{"groups_1@dexidp.com"}, + }, + "user2_transitive_lookup": { + userKey: "user_2@dexidp.com", + fetchTransitiveGroupMembership: true, + shouldErr: false, + expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"}, + }, + } { + testCase := testCase + callCounter = map[string]int{} + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + lookup := make(map[string]struct{}) + + groups, err := conn.getGroupsFromCloudIdentityAPI(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) if testCase.shouldErr { assert.NotNil(err) } else { @@ -285,7 +474,7 @@ func TestDomainToAdminEmailConfig(t *testing.T) { assert := assert.New(t) lookup := make(map[string]struct{}) - _, err := conn.getGroups(testCase.userKey, true, lookup) + _, err := conn.getGroupsFromAdminAPI(testCase.userKey, true, lookup) if testCase.expectedErr != "" { assert.ErrorContains(err, testCase.expectedErr) } else { @@ -383,7 +572,7 @@ func TestGCEWorkloadIdentity(t *testing.T) { assert := assert.New(t) lookup := make(map[string]struct{}) - _, err := conn.getGroups(testCase.userKey, true, lookup) + _, err := conn.getGroupsFromAdminAPI(testCase.userKey, true, lookup) if testCase.expectedErr != "" { assert.ErrorContains(err, testCase.expectedErr) } else {