diff --git a/connector/microsoft/microsoft.go b/connector/microsoft/microsoft.go index bbc3d6c6..a3275a33 100644 --- a/connector/microsoft/microsoft.go +++ b/connector/microsoft/microsoft.go @@ -63,6 +63,8 @@ type Config struct { DomainHint string `json:"domainHint"` Scopes []string `json:"scopes"` // defaults to scopeUser (user.read) + + AllowedDomains []string `json:"allowedDomains"` } // Open returns a strategy for logging in through Microsoft. @@ -83,6 +85,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) promptType: c.PromptType, domainHint: c.DomainHint, scopes: c.Scopes, + allowedDomains: c.AllowedDomains, } if m.apiURL == "" { @@ -138,6 +141,7 @@ type microsoftConnector struct { promptType string domainHint string scopes []string + allowedDomains []string } func (c *microsoftConnector) isOrgTenant() bool { @@ -217,6 +221,11 @@ func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) user.Email = strings.ToLower(user.Email) } + // Check if the email's domain is in the allowed list + if !c.isAllowedDomain(user.Email) { + return identity, fmt.Errorf("email (%s) domain not allowed", user.Email) + } + identity = connector.Identity{ UserID: user.ID, Username: user.Name, @@ -531,3 +540,22 @@ func (e *oauth2Error) Error() string { } return e.error + ": " + e.errorDescription } + +func (c *microsoftConnector) isAllowedDomain(email string) bool { + + if len(c.allowedDomains) == 0 { + return true + } + + parts := strings.Split(email, "@") + if len(parts) != 2 { + return false + } + domain := parts[1] + for _, d := range c.allowedDomains { + if d == domain { + return true + } + } + return false +} diff --git a/connector/microsoft/microsoft_test.go b/connector/microsoft/microsoft_test.go index 67be660f..3384efe6 100644 --- a/connector/microsoft/microsoft_test.go +++ b/connector/microsoft/microsoft_test.go @@ -3,6 +3,7 @@ package microsoft import ( "encoding/json" "fmt" + "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "net/url" @@ -119,6 +120,66 @@ func TestUserGroupsFromGraphAPI(t *testing.T) { expectEquals(t, identity.Groups, []string{"a", "b"}) } +func TestDomainNotAllowed(t *testing.T) { + s := newTestServer(map[string]testResponse{ + "/v1.0/me?$select=id,displayName,userPrincipalName": { + data: user{ID: "S56767889", Name: "Jane Doe", Email: "jane.doe@example.com"}, + }, + "/" + tenant + "/oauth2/v2.0/token": dummyToken, + }) + defer s.Close() + + req, _ := http.NewRequest("GET", s.URL, nil) + + c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant, allowedDomains: []string{"dcode.tech"}} + identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req) + + assert.Error(t, err, "email (jane.doe@example.com) domain not allowed") + assert.Equal(t, connector.Identity{}, identity) +} + +func TestDomainListAllowed(t *testing.T) { + testCases := []struct { + email string + allowed bool + domain string + }{ + {"jane.doe@dcode.tech", true, "dcode.tech"}, // Allowed domain + {"joe.bloggs@example.com", true, "example.com"}, // Allowed domain + {"john.smith@otherdomain.com", false, "otherdomain.com"}, // Not allowed domain + } + + for _, tc := range testCases { + s := newTestServer(map[string]testResponse{ + "/v1.0/me?$select=id,displayName,userPrincipalName": { + data: user{ID: "S56767889", Name: "John Doe", Email: tc.email}, + }, + "/" + tenant + "/oauth2/v2.0/token": dummyToken, + }) + defer s.Close() + + req, _ := http.NewRequest("GET", s.URL, nil) + + // Setup the microsoftConnector with allowed domains + c := microsoftConnector{ + apiURL: s.URL, + graphURL: s.URL, + tenant: tenant, + allowedDomains: []string{"dcode.tech", "example.com"}, + } + + identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req) + + if tc.allowed { + assert.NoError(t, err, "Expected no error for allowed domain: "+tc.domain) + assert.NotEqual(t, connector.Identity{}, identity, "Expected a non-empty identity struct") + } else { + assert.Error(t, err, "Expected error for non-allowed domain: "+tc.email) + assert.Equal(t, connector.Identity{}, identity, "Expected an empty identity struct") + } + } +} + func newTestServer(responses map[string]testResponse) *httptest.Server { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { response, found := responses[r.RequestURI]