Browse Source

Merge 47538f8d08 into 93985dedff

pull/4169/merge
Michael Dudzinski 1 day ago committed by GitHub
parent
commit
f089e6cd00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 170
      connector/google/google.go
  2. 195
      connector/google/google_test.go

170
connector/google/google.go

@ -17,6 +17,7 @@ import (
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/cloudidentity/v1"
"google.golang.org/api/impersonate"
"google.golang.org/api/option"
@ -53,11 +54,17 @@ type Config struct {
// Deprecated: Use DomainToAdminEmail
AdminEmail string
// Required if ServiceAccountFilePath
// If ServiceAccountFilePath is set, this value is ignored if UseCloudIdentityAPI is set. Otherwise, it's required.
// The map workspace domain to email of a GSuite super user which the service account will impersonate
// when listing groups
DomainToAdminEmail map[string]string
// If set, Cloud Identity API is used to fetch groups for a user. In particular, no user impersonation takes place.
// If ServiceAccountFilePath is not set, Application Default Credentials will be used. Otherwise, credentials will
// be generated from the file placed at the specified path.
// Defaults to false.
UseCloudIdentityAPI bool `json:"useCloudIdentityAPI"`
// If this field is true, fetch direct group membership and transitive group membership
FetchTransitiveGroupMembership bool `json:"fetchTransitiveGroupMembership"`
@ -66,6 +73,14 @@ type Config struct {
PromptType *string `json:"promptType"`
}
func validateConfigForCloudIdentity(c *Config, logger *slog.Logger) error {
if len(c.DomainToAdminEmail) > 0 || len(c.AdminEmail) > 0 {
logger.Warn("For cloud identity calls \"DomainToAdminEmail\" and \"AdminEmail\" are ignored. It's safe to remove both configuration options.")
}
return nil
}
// Open returns a connector which can be used to login users through Google.
func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) {
logger = logger.With(slog.Group("connector", "type", "google", "id", id))
@ -94,22 +109,38 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
adminSrv := make(map[string]*admin.Service)
// We know impersonation is required when using a service account credential
// TODO: or is it?
if len(c.DomainToAdminEmail) == 0 && c.ServiceAccountFilePath != "" {
cancel()
return nil, fmt.Errorf("directory service requires the domainToAdminEmail option to be configured")
}
var groupsMembershipsService *cloudidentity.GroupsMembershipsService
if (len(c.DomainToAdminEmail) > 0) || slices.Contains(scopes, "groups") {
for domain, adminEmail := range c.DomainToAdminEmail {
srv, err := createDirectoryService(c.ServiceAccountFilePath, adminEmail, logger)
if err != nil {
cancel()
return nil, fmt.Errorf("could not create directory service: %v", err)
}
if c.UseCloudIdentityAPI {
err = validateConfigForCloudIdentity(c, logger)
if err != nil {
cancel()
return nil, err
}
adminSrv[domain] = srv
groupsMembershipsService, err = createGroupsMembershipsService(c.ServiceAccountFilePath, logger)
if err != nil {
cancel()
return nil, fmt.Errorf("could not create groups memebership service: %v", err)
}
} else {
// We know impersonation is required when using a service account credential
// TODO: or is it?
if len(c.DomainToAdminEmail) == 0 && c.ServiceAccountFilePath != "" {
cancel()
return nil, fmt.Errorf("directory service requires the domainToAdminEmail option to be configured")
}
if (len(c.DomainToAdminEmail) > 0) || slices.Contains(scopes, "groups") {
for domain, adminEmail := range c.DomainToAdminEmail {
srv, err := createDirectoryService(c.ServiceAccountFilePath, adminEmail, logger)
if err != nil {
cancel()
return nil, fmt.Errorf("could not create directory service: %v", err)
}
adminSrv[domain] = srv
}
}
}
@ -137,8 +168,10 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
groups: c.Groups,
serviceAccountFilePath: c.ServiceAccountFilePath,
domainToAdminEmail: c.DomainToAdminEmail,
useCloudIdentityAPI: c.UseCloudIdentityAPI,
fetchTransitiveGroupMembership: c.FetchTransitiveGroupMembership,
adminSrv: adminSrv,
groupsMembershipsService: groupsMembershipsService,
promptType: promptType,
}, nil
}
@ -158,8 +191,10 @@ type googleConnector struct {
groups []string
serviceAccountFilePath string
domainToAdminEmail map[string]string
useCloudIdentityAPI bool
fetchTransitiveGroupMembership bool
adminSrv map[string]*admin.Service
groupsMembershipsService *cloudidentity.GroupsMembershipsService
promptType string
}
@ -262,11 +297,22 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
}
var groups []string
if s.Groups && len(c.adminSrv) > 0 {
if s.Groups {
checkedGroups := make(map[string]struct{})
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
if c.useCloudIdentityAPI {
if c.groupsMembershipsService != nil {
groups, err = c.getGroupsFromCloudIdentityAPI(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return identity, fmt.Errorf("google: could not retrieve groups from Cloud Identity API: %v", err)
}
}
} else {
if len(c.adminSrv) > 0 {
groups, err = c.getGroupsFromAdminAPI(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return identity, fmt.Errorf("google: could not retrieve groups form Admin API: %v", err)
}
}
}
if len(c.groups) > 0 {
@ -288,9 +334,9 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
return identity, nil
}
// getGroups creates a connection to the admin directory service and lists
// getGroupsFromAdminAPI creates a connection to the admin directory service and lists
// all groups the user is a member of
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
func (c *googleConnector) getGroupsFromAdminAPI(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
var userGroups []string
var err error
groupsList := &admin.Groups{}
@ -321,7 +367,53 @@ func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership
}
// getGroups takes a user's email/alias as well as a group's email/alias
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups)
transitiveGroups, err := c.getGroupsFromAdminAPI(group.Email, fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return nil, fmt.Errorf("could not list transitive groups: %v", err)
}
userGroups = append(userGroups, transitiveGroups...)
}
if groupsList.NextPageToken == "" {
break
}
}
return userGroups, nil
}
// getGroupsFromCloudIdentityAPI creates a connection to the cloud identity service and lists
// all groups the user is a member of
func (c *googleConnector) getGroupsFromCloudIdentityAPI(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
var userGroups []string
var err error
groupsList := &cloudidentity.SearchDirectGroupsResponse{}
groupsMembershipService := c.groupsMembershipsService
for {
query := fmt.Sprintf("member_key_id=='%s'", email)
groupsList, err = groupsMembershipService.SearchDirectGroups("groups/-").
Query(query).PageToken(groupsList.NextPageToken).Do()
if err != nil {
return nil, fmt.Errorf("could not list groups: %v", err)
}
for _, membership := range groupsList.Memberships {
groupEmail := strings.ToLower(membership.GroupKey.Id)
if _, exists := checkedGroups[groupEmail]; exists {
continue
}
checkedGroups[groupEmail] = struct{}{}
// TODO (joelspeed): Make desired group key configurable
userGroups = append(userGroups, groupEmail)
if !fetchTransitiveGroupMembership {
continue
}
transitiveGroups, err := c.getGroupsFromCloudIdentityAPI(groupEmail, fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return nil, fmt.Errorf("could not list transitive groups: %v", err)
}
@ -455,3 +547,37 @@ func createDirectoryService(serviceAccountFilePath, email string, logger *slog.L
return admin.NewService(ctx, option.WithHTTPClient(config.Client(ctx)))
}
func createGroupsMembershipsService(serviceAccountFilePath string, logger *slog.Logger) (service *cloudidentity.GroupsMembershipsService, err error) {
ctx := context.Background()
var credentials *google.Credentials
var cloudIdentityService *cloudidentity.Service
if serviceAccountFilePath == "" {
logger.Info("Using Application Default Credentials")
cloudIdentityService, err = cloudidentity.NewService(ctx, option.WithScopes(cloudidentity.CloudIdentityGroupsReadonlyScope))
if err != nil {
return nil, fmt.Errorf("error creating cloud identity service: %v", err)
}
} else {
logger.Info("Using credentials file at", "sa_path", serviceAccountFilePath)
jsonCredentials, err := os.ReadFile(serviceAccountFilePath)
if err != nil {
return nil, fmt.Errorf("error reading credentials from file: %v", err)
}
credentials, err = google.CredentialsFromJSON(ctx, jsonCredentials, cloudidentity.CloudIdentityGroupsReadonlyScope)
if err != nil {
return nil, fmt.Errorf("failed creating credentials from file: %w", err)
}
cloudIdentityService, err = cloudidentity.NewService(ctx, option.WithCredentials(credentials))
if err != nil {
return nil, fmt.Errorf("error creating cloud identity service: %v", err)
}
}
return cloudidentity.NewGroupsMembershipsService(cloudIdentityService), nil
}

195
connector/google/google_test.go

@ -9,11 +9,13 @@ import (
"net/http/httptest"
"net/url"
"os"
"regexp"
"strings"
"testing"
"github.com/stretchr/testify/assert"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/cloudidentity/v1"
"google.golang.org/api/option"
"github.com/dexidp/dex/connector"
@ -37,6 +39,7 @@ var (
func testSetup() *httptest.Server {
mux := http.NewServeMux()
re := regexp.MustCompile(`^member_key_id\s*==\s*'(.+)'$`)
mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
@ -47,6 +50,27 @@ func testSetup() *httptest.Server {
}
})
mux.HandleFunc("/v1/groups/-/memberships:searchDirectGroups", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
query := r.URL.Query().Get("query")
userKey := re.FindStringSubmatch(query)[1]
if groups, ok := testGroups[userKey]; ok {
var memberships []*cloudidentity.MembershipRelation
for _, group := range groups {
memberships = append(
memberships,
&cloudidentity.MembershipRelation{
GroupKey: &cloudidentity.EntityKey{Id: group.Email},
},
)
}
json.NewEncoder(w).Encode(cloudidentity.SearchDirectGroupsResponse{Memberships: memberships})
callCounter[userKey]++
}
})
return httptest.NewServer(mux)
}
@ -75,12 +99,35 @@ func tempServiceAccountKey() (string, error) {
"project_id": "sample-project",
"private_key_id": "sample-key-id",
"private_key": "-----BEGIN PRIVATE KEY-----\nsample-key\n-----END PRIVATE KEY-----\n",
"client_email": "service-account@example.com",
"client_id": "sample-client-id",
"client_x509_cert_url": "localhost",
})
return fd.Name(), err
}
func tempWorkloadIdentityFederation() (string, error) {
fd, err := os.CreateTemp("", "workload_identity_federation")
if err != nil {
return "", err
}
defer fd.Close()
err = json.NewEncoder(fd).Encode(map[string]any{
"type": "external_account",
"audience": "//iam.googleapis.com/projects/111111111111/locations/global/workloadIdentityPools/aws-pool/providers/aws-provider",
"subject_token_type": "urn:ietf:params:aws:token-type:aws4_request",
"service_account_impersonation_url": "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/sa@my-google-project.iam.gserviceaccount.com:generateAccessToken",
"token_url": "https://sts.googleapis.com/v1/token",
"credential_source": map[string]string{
"environment_id": "aws1",
"region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone",
"url": "http://169.254.169.254/latest/meta-data/iam/security-credentials",
"regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
},
})
return fd.Name(), err
}
func TestOpen(t *testing.T) {
ts := testSetup()
defer ts.Close()
@ -229,7 +276,149 @@ func TestGetGroups(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})
groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
groups, err := conn.getGroupsFromAdminAPI(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
if testCase.shouldErr {
assert.NotNil(err)
} else {
assert.Nil(err)
}
assert.ElementsMatch(testCase.expectedGroups, groups)
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
})
}
}
func TestGetGroupsWithCloudIdentityApi(t *testing.T) {
ts := testSetup()
defer ts.Close()
serviceAccountFilePath, err := tempServiceAccountKey()
assert.Nil(t, err)
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath)
conn, err := newConnector(&Config{
ClientID: "testClient",
ClientSecret: "testSecret",
RedirectURI: ts.URL + "/callback",
Scopes: []string{"openid", "groups"},
UseCloudIdentityAPI: true,
})
assert.Nil(t, err)
cloudIdentityService, err := cloudidentity.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
assert.Nil(t, err)
conn.groupsMembershipsService = cloudidentity.NewGroupsMembershipsService(cloudIdentityService)
type testCase struct {
userKey string
fetchTransitiveGroupMembership bool
shouldErr bool
expectedGroups []string
}
for name, testCase := range map[string]testCase{
"user1_non_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user1_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user2_non_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"groups_1@dexidp.com"},
},
"user2_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"},
},
} {
testCase := testCase
callCounter = map[string]int{}
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})
groups, err := conn.getGroupsFromCloudIdentityAPI(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
if testCase.shouldErr {
assert.NotNil(err)
} else {
assert.Nil(err)
}
assert.ElementsMatch(testCase.expectedGroups, groups)
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
})
}
}
func TestGetGroupsWithCloudIdentityApiAndWorkloadIdentityFederation(t *testing.T) {
ts := testSetup()
defer ts.Close()
workloadIdentityFederationFilePath, err := tempWorkloadIdentityFederation()
assert.Nil(t, err)
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", workloadIdentityFederationFilePath)
conn, err := newConnector(&Config{
ClientID: "testClient",
ClientSecret: "testSecret",
RedirectURI: ts.URL + "/callback",
Scopes: []string{"openid", "groups"},
UseCloudIdentityAPI: true,
})
assert.Nil(t, err)
cloudIdentityService, err := cloudidentity.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
assert.Nil(t, err)
conn.groupsMembershipsService = cloudidentity.NewGroupsMembershipsService(cloudIdentityService)
type testCase struct {
userKey string
fetchTransitiveGroupMembership bool
shouldErr bool
expectedGroups []string
}
for name, testCase := range map[string]testCase{
"user1_non_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user1_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user2_non_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"groups_1@dexidp.com"},
},
"user2_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"},
},
} {
testCase := testCase
callCounter = map[string]int{}
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})
groups, err := conn.getGroupsFromCloudIdentityAPI(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
if testCase.shouldErr {
assert.NotNil(err)
} else {
@ -285,7 +474,7 @@ func TestDomainToAdminEmailConfig(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})
_, err := conn.getGroups(testCase.userKey, true, lookup)
_, err := conn.getGroupsFromAdminAPI(testCase.userKey, true, lookup)
if testCase.expectedErr != "" {
assert.ErrorContains(err, testCase.expectedErr)
} else {
@ -383,7 +572,7 @@ func TestGCEWorkloadIdentity(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})
_, err := conn.getGroups(testCase.userKey, true, lookup)
_, err := conn.getGroupsFromAdminAPI(testCase.userKey, true, lookup)
if testCase.expectedErr != "" {
assert.ErrorContains(err, testCase.expectedErr)
} else {

Loading…
Cancel
Save