diff --git a/connector/microsoft/microsoft.go b/connector/microsoft/microsoft.go index 2fcf6a75..c1a802ad 100644 --- a/connector/microsoft/microsoft.go +++ b/connector/microsoft/microsoft.go @@ -32,11 +32,11 @@ const ( ) const ( - // Microsoft requires this scope to access user's profile - scopeUser = "user.read" - // Microsoft requires this scope to list groups the user is a member of - // and resolve their ids to groups names. - scopeGroups = "directory.read.all" + // Microsoft requires the scopes to start with openid + scopeOpenID = "openid" + // Get the permissions configured on the application registration + // see https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-permissions-and-consent#the-default-scope + scopeDefault = "https://graph.microsoft.com/.default" // Microsoft requires this scope to return a refresh token // see https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-permissions-and-consent#offline_access scopeOfflineAccess = "offline_access" @@ -62,7 +62,7 @@ type Config struct { PromptType string `json:"promptType"` DomainHint string `json:"domainHint"` - Scopes []string `json:"scopes"` // defaults to scopeUser (user.read) + Scopes []string `json:"scopes"` // defaults to scopeOpenID (openid) } // Open returns a strategy for logging in through Microsoft. @@ -153,11 +153,9 @@ func (c *microsoftConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi if len(c.scopes) > 0 { microsoftScopes = c.scopes } else { - microsoftScopes = append(microsoftScopes, scopeUser) - } - if c.groupsRequired(scopes.Groups) { - microsoftScopes = append(microsoftScopes, scopeGroups) + microsoftScopes = append(microsoftScopes, scopeOpenID) } + microsoftScopes = append(microsoftScopes, scopeDefault) if scopes.OfflineAccess { microsoftScopes = append(microsoftScopes, scopeOfflineAccess) @@ -386,21 +384,15 @@ func (c *microsoftConnector) user(ctx context.Context, client *http.Client) (u u // Supports $filter and $orderby. type group struct { Name string `json:"displayName"` + Id string `json:"id,omitempty"` } func (c *microsoftConnector) getGroups(ctx context.Context, client *http.Client, userID string) ([]string, error) { - userGroups, err := c.getGroupIDs(ctx, client) + userGroups, err := c.queryGroups(ctx, client) if err != nil { return nil, err } - if c.groupNameFormat == GroupName { - userGroups, err = c.getGroupNames(ctx, client, userGroups) - if err != nil { - return nil, err - } - } - // ensure that the user is in at least one required group filteredGroups := groups_pkg.Filter(userGroups, c.groups) if len(c.groups) > 0 && len(filteredGroups) == 0 { @@ -412,51 +404,26 @@ func (c *microsoftConnector) getGroups(ctx context.Context, client *http.Client, return userGroups, nil } -func (c *microsoftConnector) getGroupIDs(ctx context.Context, client *http.Client) (ids []string, err error) { - // https://developer.microsoft.com/en-us/graph/docs/api-reference/v1.0/api/user_getmembergroups - in := &struct { - SecurityEnabledOnly bool `json:"securityEnabledOnly"` - }{c.onlySecurityGroups} - reqURL := c.graphURL + "/v1.0/me/getMemberGroups" - for { - var out []string - var next string - - next, err = c.post(ctx, client, reqURL, in, &out) - if err != nil { - return ids, err - } - - ids = append(ids, out...) - if next == "" { - return - } - reqURL = next - } -} - -func (c *microsoftConnector) getGroupNames(ctx context.Context, client *http.Client, ids []string) (groups []string, err error) { - if len(ids) == 0 { - return - } - - // https://developer.microsoft.com/en-us/graph/docs/api-reference/v1.0/api/directoryobject_getbyids - in := &struct { - IDs []string `json:"ids"` - Types []string `json:"types"` - }{ids, []string{"group"}} - reqURL := c.graphURL + "/v1.0/directoryObjects/getByIds" +func (c *microsoftConnector) queryGroups(ctx context.Context, client *http.Client) (groups []string, err error) { + reqURL := c.graphURL + "/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id" for { var out []group var next string - next, err = c.post(ctx, client, reqURL, in, &out) + next, err = c.get(ctx, client, reqURL, &out) if err != nil { + c.logger.Info("resolved groups", "groups", groups, "error", err.Error()) return groups, err } for _, g := range out { - groups = append(groups, g.Name) + if c.groupNameFormat == GroupName { + c.logger.Info("resolved another group", "name", g.Name) + groups = append(groups, g.Name) + } else { + c.logger.Info("resolved another group", "id", g.Id) + groups = append(groups, g.Id) + } } if next == "" { return @@ -466,6 +433,7 @@ func (c *microsoftConnector) getGroupNames(ctx context.Context, client *http.Cli } func (c *microsoftConnector) post(ctx context.Context, client *http.Client, reqURL string, in interface{}, out interface{}) (string, error) { + c.logger.Info("post url", "url", reqURL) var payload bytes.Buffer err := json.NewEncoder(&payload).Encode(in) @@ -500,6 +468,36 @@ func (c *microsoftConnector) post(ctx context.Context, client *http.Client, reqU return next, nil } +func (c *microsoftConnector) get(ctx context.Context, client *http.Client, reqURL string, out interface{}) (string, error) { + c.logger.Info("get url", "url", reqURL) + + req, err := http.NewRequest("GET", reqURL, nil) + if err != nil { + return "", fmt.Errorf("new req: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + return "", fmt.Errorf("get URL %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", newGraphError(resp.Body) + } + + var next string + if err = json.NewDecoder(resp.Body).Decode(&struct { + NextLink *string `json:"@odata.nextLink"` + Value interface{} `json:"value"` + }{&next, out}); err != nil { + return "", fmt.Errorf("JSON decode: %v", err) + } + + return next, nil +} + type graphError struct { Code string `json:"code"` Message string `json:"message"` diff --git a/connector/microsoft/microsoft_test.go b/connector/microsoft/microsoft_test.go index 67be660f..28b6753d 100644 --- a/connector/microsoft/microsoft_test.go +++ b/connector/microsoft/microsoft_test.go @@ -3,6 +3,7 @@ package microsoft import ( "encoding/json" "fmt" + "log/slog" "net/http" "net/http/httptest" "net/url" @@ -48,7 +49,7 @@ func TestLoginURL(t *testing.T) { expectEquals(t, queryParams.Get("client_id"), clientID) expectEquals(t, queryParams.Get("redirect_uri"), testURL) expectEquals(t, queryParams.Get("response_type"), "code") - expectEquals(t, queryParams.Get("scope"), "user.read") + expectEquals(t, queryParams.Get("scope"), "openid https://graph.microsoft.com/.default") expectEquals(t, queryParams.Get("state"), testState) expectEquals(t, queryParams.Get("prompt"), "") expectEquals(t, queryParams.Get("domain_hint"), "") @@ -104,8 +105,8 @@ func TestUserIdentityFromGraphAPI(t *testing.T) { func TestUserGroupsFromGraphAPI(t *testing.T) { s := newTestServer(map[string]testResponse{ "/v1.0/me?$select=id,displayName,userPrincipalName": {data: user{}}, - "/v1.0/me/getMemberGroups": {data: map[string]interface{}{ - "value": []string{"a", "b"}, + "/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id": {data: map[string]interface{}{ + "value": []group{{Name: "a", Id: "1"}, {Name: "b", Id: "2"}}, }}, "/" + tenant + "/oauth2/v2.0/token": dummyToken, }) @@ -113,12 +114,153 @@ func TestUserGroupsFromGraphAPI(t *testing.T) { req, _ := http.NewRequest("GET", s.URL, nil) - c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant} + c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant, logger: slog.Default(), groupNameFormat: GroupName} identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) expectNil(t, err) expectEquals(t, identity.Groups, []string{"a", "b"}) } +func TestUserGroupsWithGroupIDFormat(t *testing.T) { + s := newTestServer(map[string]testResponse{ + "/v1.0/me?$select=id,displayName,userPrincipalName": {data: user{}}, + "/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id": {data: map[string]interface{}{ + "value": []group{{Name: "GroupA", Id: "id-1"}, {Name: "GroupB", Id: "id-2"}}, + }}, + "/" + tenant + "/oauth2/v2.0/token": dummyToken, + }) + defer s.Close() + + req, _ := http.NewRequest("GET", s.URL, nil) + + c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant, logger: slog.Default(), groupNameFormat: GroupID} + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + expectNil(t, err) + expectEquals(t, identity.Groups, []string{"id-1", "id-2"}) +} + +func TestLoginURLWithCustomScopes(t *testing.T) { + testURL := "https://test.com" + testState := "some-state" + customScopes := []string{"custom.scope1", "custom.scope2"} + + conn := microsoftConnector{ + apiURL: testURL, + graphURL: testURL, + redirectURI: testURL, + clientID: clientID, + tenant: tenant, + scopes: customScopes, + } + + loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState) + + parsedLoginURL, _ := url.Parse(loginURL) + queryParams := parsedLoginURL.Query() + + // Custom scopes should be used, plus the default scope is always appended + expectEquals(t, queryParams.Get("scope"), "custom.scope1 custom.scope2 https://graph.microsoft.com/.default") +} + +func TestLoginURLWithOfflineAccess(t *testing.T) { + testURL := "https://test.com" + testState := "some-state" + + conn := microsoftConnector{ + apiURL: testURL, + graphURL: testURL, + redirectURI: testURL, + clientID: clientID, + tenant: tenant, + } + + loginURL, _ := conn.LoginURL(connector.Scopes{OfflineAccess: true}, conn.redirectURI, testState) + + parsedLoginURL, _ := url.Parse(loginURL) + queryParams := parsedLoginURL.Query() + + expectEquals(t, queryParams.Get("scope"), "openid https://graph.microsoft.com/.default offline_access") +} + +func TestUserGroupsWithWhitelist(t *testing.T) { + s := newTestServer(map[string]testResponse{ + "/v1.0/me?$select=id,displayName,userPrincipalName": {data: user{ID: "user123"}}, + "/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id": {data: map[string]interface{}{ + "value": []group{{Name: "allowed-group", Id: "1"}, {Name: "other-group", Id: "2"}}, + }}, + "/" + tenant + "/oauth2/v2.0/token": dummyToken, + }) + defer s.Close() + + req, _ := http.NewRequest("GET", s.URL, nil) + + c := microsoftConnector{ + apiURL: s.URL, + graphURL: s.URL, + tenant: tenant, + logger: slog.Default(), + groupNameFormat: GroupName, + groups: []string{"allowed-group"}, + useGroupsAsWhitelist: true, + } + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + expectNil(t, err) + // Only the whitelisted group should be returned + expectEquals(t, identity.Groups, []string{"allowed-group"}) +} + +func TestUserGroupsNotInRequiredGroups(t *testing.T) { + s := newTestServer(map[string]testResponse{ + "/v1.0/me?$select=id,displayName,userPrincipalName": {data: user{ID: "user123"}}, + "/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id": {data: map[string]interface{}{ + "value": []group{{Name: "some-group", Id: "1"}}, + }}, + "/" + tenant + "/oauth2/v2.0/token": dummyToken, + }) + defer s.Close() + + req, _ := http.NewRequest("GET", s.URL, nil) + + c := microsoftConnector{ + apiURL: s.URL, + graphURL: s.URL, + tenant: tenant, + logger: slog.Default(), + groupNameFormat: GroupName, + groups: []string{"required-group"}, // User is not in this group + } + _, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + // Should fail because user is not in required group + if err == nil { + t.Error("Expected error when user is not in required groups") + } +} + +func TestUserGroupsInRequiredGroups(t *testing.T) { + s := newTestServer(map[string]testResponse{ + "/v1.0/me?$select=id,displayName,userPrincipalName": {data: user{ID: "user123"}}, + "/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id": {data: map[string]interface{}{ + "value": []group{{Name: "required-group", Id: "1"}, {Name: "other-group", Id: "2"}}, + }}, + "/" + tenant + "/oauth2/v2.0/token": dummyToken, + }) + defer s.Close() + + req, _ := http.NewRequest("GET", s.URL, nil) + + c := microsoftConnector{ + apiURL: s.URL, + graphURL: s.URL, + tenant: tenant, + logger: slog.Default(), + groupNameFormat: GroupName, + groups: []string{"required-group"}, + } + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + expectNil(t, err) + // All groups should be returned (not filtered) when useGroupsAsWhitelist is false + expectEquals(t, identity.Groups, []string{"required-group", "other-group"}) +} + func newTestServer(responses map[string]testResponse) *httptest.Server { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { response, found := responses[r.RequestURI]