Browse Source

Add allowedDomains to Microsoft connector.

The existing Microsoft connector does not support allowedDomains like Google connector or other connectors allows.
pull/3515/head
Maor Davidov 2 years ago
parent
commit
4ccb43e8d8
  1. 28
      connector/microsoft/microsoft.go
  2. 61
      connector/microsoft/microsoft_test.go

28
connector/microsoft/microsoft.go

@ -63,6 +63,8 @@ type Config struct {
DomainHint string `json:"domainHint"`
Scopes []string `json:"scopes"` // defaults to scopeUser (user.read)
AllowedDomains []string `json:"allowedDomains"`
}
// Open returns a strategy for logging in through Microsoft.
@ -83,6 +85,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
promptType: c.PromptType,
domainHint: c.DomainHint,
scopes: c.Scopes,
allowedDomains: c.AllowedDomains,
}
if m.apiURL == "" {
@ -138,6 +141,7 @@ type microsoftConnector struct {
promptType string
domainHint string
scopes []string
allowedDomains []string
}
func (c *microsoftConnector) isOrgTenant() bool {
@ -217,6 +221,11 @@ func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request)
user.Email = strings.ToLower(user.Email)
}
// Check if the email's domain is in the allowed list
if !c.isAllowedDomain(user.Email) {
return identity, fmt.Errorf("email (%s) domain not allowed", user.Email)
}
identity = connector.Identity{
UserID: user.ID,
Username: user.Name,
@ -531,3 +540,22 @@ func (e *oauth2Error) Error() string {
}
return e.error + ": " + e.errorDescription
}
func (c *microsoftConnector) isAllowedDomain(email string) bool {
if len(c.allowedDomains) == 0 {
return true
}
parts := strings.Split(email, "@")
if len(parts) != 2 {
return false
}
domain := parts[1]
for _, d := range c.allowedDomains {
if d == domain {
return true
}
}
return false
}

61
connector/microsoft/microsoft_test.go

@ -3,6 +3,7 @@ package microsoft
import (
"encoding/json"
"fmt"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
@ -119,6 +120,66 @@ func TestUserGroupsFromGraphAPI(t *testing.T) {
expectEquals(t, identity.Groups, []string{"a", "b"})
}
func TestDomainNotAllowed(t *testing.T) {
s := newTestServer(map[string]testResponse{
"/v1.0/me?$select=id,displayName,userPrincipalName": {
data: user{ID: "S56767889", Name: "Jane Doe", Email: "jane.doe@example.com"},
},
"/" + tenant + "/oauth2/v2.0/token": dummyToken,
})
defer s.Close()
req, _ := http.NewRequest("GET", s.URL, nil)
c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant, allowedDomains: []string{"dcode.tech"}}
identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req)
assert.Error(t, err, "email (jane.doe@example.com) domain not allowed")
assert.Equal(t, connector.Identity{}, identity)
}
func TestDomainListAllowed(t *testing.T) {
testCases := []struct {
email string
allowed bool
domain string
}{
{"jane.doe@dcode.tech", true, "dcode.tech"}, // Allowed domain
{"joe.bloggs@example.com", true, "example.com"}, // Allowed domain
{"john.smith@otherdomain.com", false, "otherdomain.com"}, // Not allowed domain
}
for _, tc := range testCases {
s := newTestServer(map[string]testResponse{
"/v1.0/me?$select=id,displayName,userPrincipalName": {
data: user{ID: "S56767889", Name: "John Doe", Email: tc.email},
},
"/" + tenant + "/oauth2/v2.0/token": dummyToken,
})
defer s.Close()
req, _ := http.NewRequest("GET", s.URL, nil)
// Setup the microsoftConnector with allowed domains
c := microsoftConnector{
apiURL: s.URL,
graphURL: s.URL,
tenant: tenant,
allowedDomains: []string{"dcode.tech", "example.com"},
}
identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req)
if tc.allowed {
assert.NoError(t, err, "Expected no error for allowed domain: "+tc.domain)
assert.NotEqual(t, connector.Identity{}, identity, "Expected a non-empty identity struct")
} else {
assert.Error(t, err, "Expected error for non-allowed domain: "+tc.email)
assert.Equal(t, connector.Identity{}, identity, "Expected an empty identity struct")
}
}
}
func newTestServer(responses map[string]testResponse) *httptest.Server {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response, found := responses[r.RequestURI]

Loading…
Cancel
Save