@ -3,6 +3,7 @@ package microsoft
import (
import (
"encoding/json"
"encoding/json"
"fmt"
"fmt"
"log/slog"
"net/http"
"net/http"
"net/http/httptest"
"net/http/httptest"
"net/url"
"net/url"
@ -48,7 +49,7 @@ func TestLoginURL(t *testing.T) {
expectEquals ( t , queryParams . Get ( "client_id" ) , clientID )
expectEquals ( t , queryParams . Get ( "client_id" ) , clientID )
expectEquals ( t , queryParams . Get ( "redirect_uri" ) , testURL )
expectEquals ( t , queryParams . Get ( "redirect_uri" ) , testURL )
expectEquals ( t , queryParams . Get ( "response_type" ) , "code" )
expectEquals ( t , queryParams . Get ( "response_type" ) , "code" )
expectEquals ( t , queryParams . Get ( "scope" ) , "user.read " )
expectEquals ( t , queryParams . Get ( "scope" ) , "openid https://graph.microsoft.com/.default " )
expectEquals ( t , queryParams . Get ( "state" ) , testState )
expectEquals ( t , queryParams . Get ( "state" ) , testState )
expectEquals ( t , queryParams . Get ( "prompt" ) , "" )
expectEquals ( t , queryParams . Get ( "prompt" ) , "" )
expectEquals ( t , queryParams . Get ( "domain_hint" ) , "" )
expectEquals ( t , queryParams . Get ( "domain_hint" ) , "" )
@ -104,8 +105,8 @@ func TestUserIdentityFromGraphAPI(t *testing.T) {
func TestUserGroupsFromGraphAPI ( t * testing . T ) {
func TestUserGroupsFromGraphAPI ( t * testing . T ) {
s := newTestServer ( map [ string ] testResponse {
s := newTestServer ( map [ string ] testResponse {
"/v1.0/me?$select=id,displayName,userPrincipalName" : { data : user { } } ,
"/v1.0/me?$select=id,displayName,userPrincipalName" : { data : user { } } ,
"/v1.0/me/getMemberGroups " : { data : map [ string ] interface { } {
"/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id " : { data : map [ string ] interface { } {
"value" : [ ] string { "a" , "b" } ,
"value" : [ ] group { { Name : "a" , Id : "1" } , { Name : "b" , Id : "2" } } ,
} } ,
} } ,
"/" + tenant + "/oauth2/v2.0/token" : dummyToken ,
"/" + tenant + "/oauth2/v2.0/token" : dummyToken ,
} )
} )
@ -113,12 +114,153 @@ func TestUserGroupsFromGraphAPI(t *testing.T) {
req , _ := http . NewRequest ( "GET" , s . URL , nil )
req , _ := http . NewRequest ( "GET" , s . URL , nil )
c := microsoftConnector { apiURL : s . URL , graphURL : s . URL , tenant : tenant }
c := microsoftConnector { apiURL : s . URL , graphURL : s . URL , tenant : tenant , logger : slog . Default ( ) , groupNameFormat : GroupName }
identity , err := c . HandleCallback ( connector . Scopes { Groups : true } , req )
identity , err := c . HandleCallback ( connector . Scopes { Groups : true } , req )
expectNil ( t , err )
expectNil ( t , err )
expectEquals ( t , identity . Groups , [ ] string { "a" , "b" } )
expectEquals ( t , identity . Groups , [ ] string { "a" , "b" } )
}
}
func TestUserGroupsWithGroupIDFormat ( t * testing . T ) {
s := newTestServer ( map [ string ] testResponse {
"/v1.0/me?$select=id,displayName,userPrincipalName" : { data : user { } } ,
"/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id" : { data : map [ string ] interface { } {
"value" : [ ] group { { Name : "GroupA" , Id : "id-1" } , { Name : "GroupB" , Id : "id-2" } } ,
} } ,
"/" + tenant + "/oauth2/v2.0/token" : dummyToken ,
} )
defer s . Close ( )
req , _ := http . NewRequest ( "GET" , s . URL , nil )
c := microsoftConnector { apiURL : s . URL , graphURL : s . URL , tenant : tenant , logger : slog . Default ( ) , groupNameFormat : GroupID }
identity , err := c . HandleCallback ( connector . Scopes { Groups : true } , req )
expectNil ( t , err )
expectEquals ( t , identity . Groups , [ ] string { "id-1" , "id-2" } )
}
func TestLoginURLWithCustomScopes ( t * testing . T ) {
testURL := "https://test.com"
testState := "some-state"
customScopes := [ ] string { "custom.scope1" , "custom.scope2" }
conn := microsoftConnector {
apiURL : testURL ,
graphURL : testURL ,
redirectURI : testURL ,
clientID : clientID ,
tenant : tenant ,
scopes : customScopes ,
}
loginURL , _ := conn . LoginURL ( connector . Scopes { } , conn . redirectURI , testState )
parsedLoginURL , _ := url . Parse ( loginURL )
queryParams := parsedLoginURL . Query ( )
// Custom scopes should be used, plus the default scope is always appended
expectEquals ( t , queryParams . Get ( "scope" ) , "custom.scope1 custom.scope2 https://graph.microsoft.com/.default" )
}
func TestLoginURLWithOfflineAccess ( t * testing . T ) {
testURL := "https://test.com"
testState := "some-state"
conn := microsoftConnector {
apiURL : testURL ,
graphURL : testURL ,
redirectURI : testURL ,
clientID : clientID ,
tenant : tenant ,
}
loginURL , _ := conn . LoginURL ( connector . Scopes { OfflineAccess : true } , conn . redirectURI , testState )
parsedLoginURL , _ := url . Parse ( loginURL )
queryParams := parsedLoginURL . Query ( )
expectEquals ( t , queryParams . Get ( "scope" ) , "openid https://graph.microsoft.com/.default offline_access" )
}
func TestUserGroupsWithWhitelist ( t * testing . T ) {
s := newTestServer ( map [ string ] testResponse {
"/v1.0/me?$select=id,displayName,userPrincipalName" : { data : user { ID : "user123" } } ,
"/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id" : { data : map [ string ] interface { } {
"value" : [ ] group { { Name : "allowed-group" , Id : "1" } , { Name : "other-group" , Id : "2" } } ,
} } ,
"/" + tenant + "/oauth2/v2.0/token" : dummyToken ,
} )
defer s . Close ( )
req , _ := http . NewRequest ( "GET" , s . URL , nil )
c := microsoftConnector {
apiURL : s . URL ,
graphURL : s . URL ,
tenant : tenant ,
logger : slog . Default ( ) ,
groupNameFormat : GroupName ,
groups : [ ] string { "allowed-group" } ,
useGroupsAsWhitelist : true ,
}
identity , err := c . HandleCallback ( connector . Scopes { Groups : true } , req )
expectNil ( t , err )
// Only the whitelisted group should be returned
expectEquals ( t , identity . Groups , [ ] string { "allowed-group" } )
}
func TestUserGroupsNotInRequiredGroups ( t * testing . T ) {
s := newTestServer ( map [ string ] testResponse {
"/v1.0/me?$select=id,displayName,userPrincipalName" : { data : user { ID : "user123" } } ,
"/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id" : { data : map [ string ] interface { } {
"value" : [ ] group { { Name : "some-group" , Id : "1" } } ,
} } ,
"/" + tenant + "/oauth2/v2.0/token" : dummyToken ,
} )
defer s . Close ( )
req , _ := http . NewRequest ( "GET" , s . URL , nil )
c := microsoftConnector {
apiURL : s . URL ,
graphURL : s . URL ,
tenant : tenant ,
logger : slog . Default ( ) ,
groupNameFormat : GroupName ,
groups : [ ] string { "required-group" } , // User is not in this group
}
_ , err := c . HandleCallback ( connector . Scopes { Groups : true } , req )
// Should fail because user is not in required group
if err == nil {
t . Error ( "Expected error when user is not in required groups" )
}
}
func TestUserGroupsInRequiredGroups ( t * testing . T ) {
s := newTestServer ( map [ string ] testResponse {
"/v1.0/me?$select=id,displayName,userPrincipalName" : { data : user { ID : "user123" } } ,
"/v1.0/me/memberOf/microsoft.graph.group?$select=displayName,id" : { data : map [ string ] interface { } {
"value" : [ ] group { { Name : "required-group" , Id : "1" } , { Name : "other-group" , Id : "2" } } ,
} } ,
"/" + tenant + "/oauth2/v2.0/token" : dummyToken ,
} )
defer s . Close ( )
req , _ := http . NewRequest ( "GET" , s . URL , nil )
c := microsoftConnector {
apiURL : s . URL ,
graphURL : s . URL ,
tenant : tenant ,
logger : slog . Default ( ) ,
groupNameFormat : GroupName ,
groups : [ ] string { "required-group" } ,
}
identity , err := c . HandleCallback ( connector . Scopes { Groups : true } , req )
expectNil ( t , err )
// All groups should be returned (not filtered) when useGroupsAsWhitelist is false
expectEquals ( t , identity . Groups , [ ] string { "required-group" , "other-group" } )
}
func newTestServer ( responses map [ string ] testResponse ) * httptest . Server {
func newTestServer ( responses map [ string ] testResponse ) * httptest . Server {
s := httptest . NewServer ( http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
s := httptest . NewServer ( http . HandlerFunc ( func ( w http . ResponseWriter , r * http . Request ) {
response , found := responses [ r . RequestURI ]
response , found := responses [ r . RequestURI ]