Browse Source

Merge efcb0d1f4e into e6740971b1

pull/4138/merge
Alexandre Barone 2 months ago committed by GitHub
parent
commit
d9d055374d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 164
      connector/google/google.go
  2. 150
      connector/google/google_test.go
  3. 2
      go.mod
  4. 4
      go.sum

164
connector/google/google.go

@ -17,6 +17,7 @@ import (
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1"
cloudidentity "google.golang.org/api/cloudidentity/v1"
"google.golang.org/api/impersonate"
"google.golang.org/api/option"
@ -160,6 +161,7 @@ type googleConnector struct {
domainToAdminEmail map[string]string
fetchTransitiveGroupMembership bool
adminSrv map[string]*admin.Service
cloudIdentitySrv *cloudidentity.Service
promptType string
}
@ -261,19 +263,15 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
}
}
var groups []string
if s.Groups && len(c.adminSrv) > 0 {
checkedGroups := make(map[string]struct{})
groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups)
if err != nil {
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
}
groups, err := c.getGroups(ctx, s, token, claims.Email, c.fetchTransitiveGroupMembership)
if err != nil {
return identity, fmt.Errorf("google: could not retrieve groups: %v", err)
}
if len(c.groups) > 0 {
groups = pkg_groups.Filter(groups, c.groups)
if len(groups) == 0 {
return identity, fmt.Errorf("google: user %q is not in any of the required groups", claims.Username)
}
if len(c.groups) > 0 {
groups = pkg_groups.Filter(groups, c.groups)
if len(groups) == 0 {
return identity, fmt.Errorf("google: user %q is not in any of the required groups", claims.Username)
}
}
@ -288,49 +286,131 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector
return identity, nil
}
// getGroups creates a connection to the admin directory service and lists
// all groups the user is a member of
func (c *googleConnector) getGroups(email string, fetchTransitiveGroupMembership bool, checkedGroups map[string]struct{}) ([]string, error) {
var userGroups []string
var err error
groupsList := &admin.Groups{}
domain := c.extractDomainFromEmail(email)
adminSrv, err := c.findAdminService(domain)
// getGroups makes a best effort attempt to retrieve groups the provided email is a member of.
func (c *googleConnector) getGroups(ctx context.Context, s connector.Scopes, token *oauth2.Token, email string, fetchTransitiveGroupMembership bool) ([]string, error) {
if c.hasCloudIdentityGroupScope() {
return c.getGroupsUsingUser(ctx, token, email, fetchTransitiveGroupMembership)
}
if s.Groups && c.adminSrv != nil {
return c.getGroupsUsingAdmin(ctx, email, fetchTransitiveGroupMembership)
}
return nil, nil
}
// hasCloudIdentityGroupScope returns true if the provided scopes contain the required scope
// to retrieve groups from the Cloud Identity API.
func (c *googleConnector) hasCloudIdentityGroupScope() bool {
return slices.ContainsFunc(c.oauth2Config.Scopes, func(s string) bool {
return s == cloudidentity.CloudPlatformScope || s == cloudidentity.CloudIdentityGroupsScope || s == cloudidentity.CloudIdentityGroupsReadonlyScope
})
}
// getGroupsUsingUser uses the Cloud Identity API to retrieve groups the provided email is a member
// of. This method returns groups from direct and transitive memberships.
//
// In contrast to getGroupsUsingAdmin, this method does NOT require the authenticated client
// to have been granted domain-wide delegation. Instead, it relies on the OAuth access token
// having the appropriate permissions from requesting the Cloud Identity API scope
// (see hasCloudIdentityGroupScope).
func (c *googleConnector) getGroupsUsingUser(ctx context.Context, token *oauth2.Token, email string, fetchTransitiveGroupMembership bool) ([]string, error) {
svc, err := c.createCloudIdentityService(ctx, token)
if err != nil {
return nil, err
}
for {
groupsList, err = adminSrv.Groups.List().
UserKey(email).PageToken(groupsList.NextPageToken).Do()
if err != nil {
return nil, fmt.Errorf("could not list groups: %v", err)
}
for _, group := range groupsList.Groups {
if _, exists := checkedGroups[group.Email]; exists {
continue
var userGroups []string
resp := &cloudidentity.SearchDirectGroupsResponse{}
checkedGroups := make(map[string]struct{})
stack := []string{email}
for len(stack) > 0 {
n := len(stack) - 1
curEmail := stack[n]
stack = stack[:n]
query := fmt.Sprintf("member_key_id=='%s'", curEmail)
for {
resp, err = svc.Groups.Memberships.SearchDirectGroups("groups/-").
Context(ctx).Query(query).PageToken(resp.NextPageToken).Do()
if err != nil {
return nil, fmt.Errorf("could not list groups: %v", err)
}
checkedGroups[group.Email] = struct{}{}
// TODO (joelspeed): Make desired group key configurable
userGroups = append(userGroups, group.Email)
for _, m := range resp.Memberships {
group := m.GroupKey.Id
if _, exists := checkedGroups[group]; exists {
continue
}
checkedGroups[group] = struct{}{}
userGroups = append(userGroups, group)
if !fetchTransitiveGroupMembership {
continue
}
stack = append(stack, group)
}
if !fetchTransitiveGroupMembership {
continue
if resp.NextPageToken == "" {
break
}
}
}
return userGroups, nil
}
// createCloudIdentityService is a small helper useful for testing by allowing to override
// the cloud identity service.
func (c *googleConnector) createCloudIdentityService(ctx context.Context, token *oauth2.Token) (*cloudidentity.Service, error) {
if c.cloudIdentitySrv != nil {
return c.cloudIdentitySrv, nil
}
return cloudidentity.NewService(ctx, option.WithHTTPClient(c.oauth2Config.Client(ctx, token)))
}
// getGroupsUsingAdmin uses the Admin SDK API to retrieve groups the provided email is a member of.
//
// This method requires the authenticated client to have been granted domain-wide delegation.
func (c *googleConnector) getGroupsUsingAdmin(ctx context.Context, email string, fetchTransitiveGroupMembership bool) ([]string, error) {
domain := c.extractDomainFromEmail(email)
adminSrv, err := c.findAdminService(domain)
if err != nil {
return nil, err
}
// getGroups takes a user's email/alias as well as a group's email/alias
transitiveGroups, err := c.getGroups(group.Email, fetchTransitiveGroupMembership, checkedGroups)
var userGroups []string
resp := &admin.Groups{}
checkedGroups := make(map[string]struct{})
stack := []string{email}
for len(stack) > 0 {
n := len(stack) - 1
curEmail := stack[n]
stack = stack[:n]
for {
resp, err = adminSrv.Groups.List().
Context(ctx).UserKey(curEmail).PageToken(resp.NextPageToken).Do()
if err != nil {
return nil, fmt.Errorf("could not list transitive groups: %v", err)
return nil, fmt.Errorf("could not list groups: %v", err)
}
userGroups = append(userGroups, transitiveGroups...)
}
for _, g := range resp.Groups {
group := g.Email
if _, exists := checkedGroups[group]; exists {
continue
}
checkedGroups[group] = struct{}{}
// TODO (joelspeed): Make desired group key configurable
userGroups = append(userGroups, group)
if !fetchTransitiveGroupMembership {
continue
}
stack = append(stack, group)
}
if groupsList.NextPageToken == "" {
break
if resp.NextPageToken == "" {
break
}
}
}

150
connector/google/google_test.go

@ -1,7 +1,6 @@
package google
import (
"context"
"encoding/json"
"fmt"
"log/slog"
@ -9,11 +8,14 @@ import (
"net/http/httptest"
"net/url"
"os"
"regexp"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/oauth2"
admin "google.golang.org/api/admin/directory/v1"
cloudidentity "google.golang.org/api/cloudidentity/v1"
"google.golang.org/api/option"
"github.com/dexidp/dex/connector"
@ -25,26 +27,65 @@ var (
// groups_2 groups_1
// │ ├────────┐
// └── user_1 user_2
testGroups = map[string][]*admin.Group{
adminTestGroups = map[string][]*admin.Group{
"user_1@dexidp.com": {{Email: "groups_2@dexidp.com"}, {Email: "groups_1@dexidp.com"}},
"user_2@dexidp.com": {{Email: "groups_1@dexidp.com"}},
"groups_1@dexidp.com": {{Email: "groups_0@dexidp.com"}},
"groups_2@dexidp.com": {{Email: "groups_0@dexidp.com"}},
"groups_0@dexidp.com": {},
}
callCounter = make(map[string]int)
adminCallCounter = make(map[string]int)
cloudIdentityTestGroups = adminGroupsToMemberships(adminTestGroups)
cloudIdentityCallCounter = make(map[string]int)
cloudIdentityMemberKeyIdRE = regexp.MustCompile("member_key_id=='([^']+)'")
)
func adminGroupsToMemberships(testGroups map[string][]*admin.Group) map[string][]*cloudidentity.MembershipRelation {
mapped := make(map[string][]*cloudidentity.MembershipRelation)
for key, groups := range testGroups {
mapped[key] = make([]*cloudidentity.MembershipRelation, 0, len(groups))
for _, g := range groups {
mapped[key] = append(mapped[key], &cloudidentity.MembershipRelation{
GroupKey: &cloudidentity.EntityKey{Id: g.Email},
})
}
}
return mapped
}
func testSetup() *httptest.Server {
mux := http.NewServeMux()
// https://developers.google.com/workspace/admin/directory/reference/rest
mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
resp := admin.Groups{}
userKey := r.URL.Query().Get("userKey")
if groups, ok := testGroups[userKey]; ok {
json.NewEncoder(w).Encode(admin.Groups{Groups: groups})
callCounter[userKey]++
adminCallCounter[userKey]++
if groups, ok := adminTestGroups[userKey]; ok {
resp.Groups = groups
}
w.Header().Add("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
})
// https://cloud.google.com/identity/docs/reference/rest
mux.HandleFunc("/v1/groups/-/memberships:searchDirectGroups", func(w http.ResponseWriter, r *http.Request) {
resp := cloudidentity.SearchDirectGroupsResponse{}
query := r.URL.Query().Get("query")
memberKeyIdValue := cloudIdentityMemberKeyIdRE.FindSubmatch([]byte(query))
if len(memberKeyIdValue) > 1 {
userKey := string(memberKeyIdValue[1])
cloudIdentityCallCounter[userKey]++
if memberships, ok := cloudIdentityTestGroups[userKey]; ok {
resp.Memberships = memberships
}
}
w.Header().Add("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
})
return httptest.NewServer(mux)
@ -171,7 +212,65 @@ func TestOpen(t *testing.T) {
}
}
func TestGetGroups(t *testing.T) {
func TestGetGroupsUsingUser(t *testing.T) {
ts := testSetup()
defer ts.Close()
conn, err := newConnector(&Config{
ClientID: "testClient",
ClientSecret: "testSecret",
RedirectURI: ts.URL + "/callback",
Scopes: []string{"openid", "https://www.googleapis.com/auth/cloud-identity.groups.readonly"},
})
assert.Nil(t, err)
conn.cloudIdentitySrv, err = cloudidentity.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
assert.Nil(t, err)
type testCase struct {
userKey string
fetchTransitiveGroupMembership bool
expectedGroups []string
}
for name, testCase := range map[string]testCase{
"user1_non_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: false,
expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user1_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: true,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user2_non_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: false,
expectedGroups: []string{"groups_1@dexidp.com"},
},
"user2_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: true,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"},
},
"user_not_found": {
userKey: "user_3@dexidp.com",
},
} {
testCase := testCase
cloudIdentityCallCounter = map[string]int{}
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
groups, err := conn.getGroupsUsingUser(t.Context(), &oauth2.Token{}, testCase.userKey, testCase.fetchTransitiveGroupMembership)
assert.Nil(err)
assert.ElementsMatch(testCase.expectedGroups, groups)
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), cloudIdentityCallCounter)
})
}
}
func TestGetGroupsUsingAdmin(t *testing.T) {
ts := testSetup()
defer ts.Close()
@ -188,12 +287,11 @@ func TestGetGroups(t *testing.T) {
})
assert.Nil(t, err)
conn.adminSrv[wildcardDomainToAdminEmail], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
conn.adminSrv[wildcardDomainToAdminEmail], err = admin.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
assert.Nil(t, err)
type testCase struct {
userKey string
fetchTransitiveGroupMembership bool
shouldErr bool
expectedGroups []string
}
@ -201,42 +299,33 @@ func TestGetGroups(t *testing.T) {
"user1_non_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user1_transitive_lookup": {
userKey: "user_1@dexidp.com",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"},
},
"user2_non_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: false,
shouldErr: false,
expectedGroups: []string{"groups_1@dexidp.com"},
},
"user2_transitive_lookup": {
userKey: "user_2@dexidp.com",
fetchTransitiveGroupMembership: true,
shouldErr: false,
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"},
},
} {
testCase := testCase
callCounter = map[string]int{}
adminCallCounter = map[string]int{}
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})
groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup)
if testCase.shouldErr {
assert.NotNil(err)
} else {
assert.Nil(err)
}
groups, err := conn.getGroupsUsingAdmin(t.Context(), testCase.userKey, testCase.fetchTransitiveGroupMembership)
assert.Nil(err)
assert.ElementsMatch(testCase.expectedGroups, groups)
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), adminCallCounter)
})
}
}
@ -258,7 +347,7 @@ func TestDomainToAdminEmailConfig(t *testing.T) {
})
assert.Nil(t, err)
conn.adminSrv["dexidp.com"], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
conn.adminSrv["dexidp.com"], err = admin.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
assert.Nil(t, err)
type testCase struct {
userKey string
@ -280,18 +369,17 @@ func TestDomainToAdminEmailConfig(t *testing.T) {
},
} {
testCase := testCase
callCounter = map[string]int{}
adminCallCounter = map[string]int{}
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})
_, err := conn.getGroups(testCase.userKey, true, lookup)
_, err := conn.getGroupsUsingAdmin(t.Context(), testCase.userKey, true)
if testCase.expectedErr != "" {
assert.ErrorContains(err, testCase.expectedErr)
} else {
assert.Nil(err)
}
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter)
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), adminCallCounter)
})
}
}
@ -337,6 +425,7 @@ func TestGCEWorkloadIdentity(t *testing.T) {
os.Setenv("GCE_METADATA_HOST", metadataServerHost)
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", "")
os.Setenv("HOME", "/tmp")
os.Setenv("APPDATA", "/tmp")
gceMetadataFlags["failOnEmailRequest"] = true
_, err := newConnector(&Config{
@ -358,7 +447,7 @@ func TestGCEWorkloadIdentity(t *testing.T) {
})
assert.Nil(t, err)
conn.adminSrv["dexidp.com"], err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
conn.adminSrv["dexidp.com"], err = admin.NewService(t.Context(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL))
assert.Nil(t, err)
type testCase struct {
userKey string
@ -381,9 +470,8 @@ func TestGCEWorkloadIdentity(t *testing.T) {
} {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
lookup := make(map[string]struct{})
_, err := conn.getGroups(testCase.userKey, true, lookup)
_, err := conn.getGroupsUsingAdmin(t.Context(), testCase.userKey, true)
if testCase.expectedErr != "" {
assert.ErrorContains(err, testCase.expectedErr)
} else {

2
go.mod

@ -33,7 +33,7 @@ require (
go.etcd.io/etcd/client/pkg/v3 v3.6.7
go.etcd.io/etcd/client/v3 v3.6.7
golang.org/x/crypto v0.46.0
golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6
golang.org/x/net v0.48.0
golang.org/x/oauth2 v0.34.0
google.golang.org/api v0.257.0

4
go.sum

@ -233,8 +233,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741 h1:fGZugkZk2UgYBxtpKmvub51Yno1LJDeEsRp2xGD+0gY=
golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI=
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=

Loading…
Cancel
Save