|
|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
package google |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"context" |
|
|
|
|
"encoding/json" |
|
|
|
|
"fmt" |
|
|
|
|
"net/http" |
|
|
|
|
@ -10,17 +11,38 @@ import (
|
|
|
|
|
|
|
|
|
|
"github.com/sirupsen/logrus" |
|
|
|
|
"github.com/stretchr/testify/assert" |
|
|
|
|
admin "google.golang.org/api/admin/directory/v1" |
|
|
|
|
"google.golang.org/api/option" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
var ( |
|
|
|
|
// groups_0
|
|
|
|
|
// ┌───────┤
|
|
|
|
|
// groups_2 groups_1
|
|
|
|
|
// │ ├────────┐
|
|
|
|
|
// └── user_1 user_2
|
|
|
|
|
testGroups = map[string][]*admin.Group{ |
|
|
|
|
"user_1@dexidp.com": {{Email: "groups_2@dexidp.com"}, {Email: "groups_1@dexidp.com"}}, |
|
|
|
|
"user_2@dexidp.com": {{Email: "groups_1@dexidp.com"}}, |
|
|
|
|
"groups_1@dexidp.com": {{Email: "groups_0@dexidp.com"}}, |
|
|
|
|
"groups_2@dexidp.com": {{Email: "groups_0@dexidp.com"}}, |
|
|
|
|
"groups_0@dexidp.com": {}, |
|
|
|
|
} |
|
|
|
|
callCounter = make(map[string]int) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
func testSetup(t *testing.T) *httptest.Server { |
|
|
|
|
mux := http.NewServeMux() |
|
|
|
|
// TODO: mock calls
|
|
|
|
|
// mux.HandleFunc("/admin/directory/v1/groups", func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
// w.Header().Add("Content-Type", "application/json")
|
|
|
|
|
// json.NewEncoder(w).Encode(&admin.Groups{
|
|
|
|
|
// Groups: []*admin.Group{},
|
|
|
|
|
// })
|
|
|
|
|
// })
|
|
|
|
|
|
|
|
|
|
mux.HandleFunc("/admin/directory/v1/groups/", func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
w.Header().Add("Content-Type", "application/json") |
|
|
|
|
userKey := r.URL.Query().Get("userKey") |
|
|
|
|
if groups, ok := testGroups[userKey]; ok { |
|
|
|
|
json.NewEncoder(w).Encode(admin.Groups{Groups: groups}) |
|
|
|
|
callCounter[userKey]++ |
|
|
|
|
} |
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
return httptest.NewServer(mux) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@ -144,3 +166,73 @@ func TestOpen(t *testing.T) {
|
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestGetGroups(t *testing.T) { |
|
|
|
|
ts := testSetup(t) |
|
|
|
|
defer ts.Close() |
|
|
|
|
|
|
|
|
|
serviceAccountFilePath, err := tempServiceAccountKey() |
|
|
|
|
assert.Nil(t, err) |
|
|
|
|
|
|
|
|
|
os.Setenv("GOOGLE_APPLICATION_CREDENTIALS", serviceAccountFilePath) |
|
|
|
|
conn, err := newConnector(&Config{ |
|
|
|
|
ClientID: "testClient", |
|
|
|
|
ClientSecret: "testSecret", |
|
|
|
|
RedirectURI: ts.URL + "/callback", |
|
|
|
|
Scopes: []string{"openid", "groups"}, |
|
|
|
|
AdminEmail: "admin@dexidp.com", |
|
|
|
|
}, ts.URL) |
|
|
|
|
assert.Nil(t, err) |
|
|
|
|
|
|
|
|
|
conn.adminSrv, err = admin.NewService(context.Background(), option.WithoutAuthentication(), option.WithEndpoint(ts.URL)) |
|
|
|
|
assert.Nil(t, err) |
|
|
|
|
type testCase struct { |
|
|
|
|
userKey string |
|
|
|
|
fetchTransitiveGroupMembership bool |
|
|
|
|
shouldErr bool |
|
|
|
|
expectedGroups []string |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for name, testCase := range map[string]testCase{ |
|
|
|
|
"user1_non_transitive_lookup": { |
|
|
|
|
userKey: "user_1@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: false, |
|
|
|
|
shouldErr: false, |
|
|
|
|
expectedGroups: []string{"groups_1@dexidp.com", "groups_2@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
"user1_transitive_lookup": { |
|
|
|
|
userKey: "user_1@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: true, |
|
|
|
|
shouldErr: false, |
|
|
|
|
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com", "groups_2@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
"user2_non_transitive_lookup": { |
|
|
|
|
userKey: "user_2@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: false, |
|
|
|
|
shouldErr: false, |
|
|
|
|
expectedGroups: []string{"groups_1@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
"user2_transitive_lookup": { |
|
|
|
|
userKey: "user_2@dexidp.com", |
|
|
|
|
fetchTransitiveGroupMembership: true, |
|
|
|
|
shouldErr: false, |
|
|
|
|
expectedGroups: []string{"groups_0@dexidp.com", "groups_1@dexidp.com"}, |
|
|
|
|
}, |
|
|
|
|
} { |
|
|
|
|
testCase := testCase |
|
|
|
|
callCounter = map[string]int{} |
|
|
|
|
t.Run(name, func(t *testing.T) { |
|
|
|
|
assert := assert.New(t) |
|
|
|
|
lookup := make(map[string]struct{}) |
|
|
|
|
|
|
|
|
|
groups, err := conn.getGroups(testCase.userKey, testCase.fetchTransitiveGroupMembership, lookup) |
|
|
|
|
if testCase.shouldErr { |
|
|
|
|
assert.NotNil(err) |
|
|
|
|
} else { |
|
|
|
|
assert.Nil(err) |
|
|
|
|
} |
|
|
|
|
assert.ElementsMatch(testCase.expectedGroups, groups) |
|
|
|
|
t.Logf("[%s] Amount of API calls per userKey: %+v\n", t.Name(), callCounter) |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|