diff --git a/connector/connector.go b/connector/connector.go index b1e069c3..1e554789 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -3,9 +3,22 @@ package connector import ( "context" + "fmt" "net/http" ) +// UserNotInRequiredGroupsError is returned by a connector when a user +// successfully authenticates but is not a member of any of the required groups. +// The server will respond with HTTP 403 Forbidden instead of 500. +type UserNotInRequiredGroupsError struct { + UserID string + Groups []string +} + +func (e *UserNotInRequiredGroupsError) Error() string { + return fmt.Sprintf("user %q is not in any of the required groups %v", e.UserID, e.Groups) +} + // Connector is a mechanism for federating login to a remote identity service. // // Implementations are expected to implement either the PasswordConnector or diff --git a/connector/microsoft/microsoft.go b/connector/microsoft/microsoft.go index 1db16942..ca6e025d 100644 --- a/connector/microsoft/microsoft.go +++ b/connector/microsoft/microsoft.go @@ -227,7 +227,7 @@ func (c *microsoftConnector) HandleCallback(s connector.Scopes, connData []byte, if c.groupsRequired(s.Groups) { groups, err := c.getGroups(ctx, client, user.ID) if err != nil { - return identity, fmt.Errorf("microsoft: get groups: %v", err) + return identity, fmt.Errorf("microsoft: get groups: %w", err) } identity.Groups = groups } @@ -318,7 +318,7 @@ func (c *microsoftConnector) Refresh(ctx context.Context, s connector.Scopes, id if c.groupsRequired(s.Groups) { groups, err := c.getGroups(ctx, client, user.ID) if err != nil { - return identity, fmt.Errorf("microsoft: get groups: %v", err) + return identity, fmt.Errorf("microsoft: get groups: %w", err) } identity.Groups = groups } @@ -404,7 +404,7 @@ func (c *microsoftConnector) getGroups(ctx context.Context, client *http.Client, // ensure that the user is in at least one required group filteredGroups := groups_pkg.Filter(userGroups, c.groups) if len(c.groups) > 0 && len(filteredGroups) == 0 { - return nil, fmt.Errorf("microsoft: user %v not in any of the required groups", userID) + return nil, &connector.UserNotInRequiredGroupsError{UserID: userID, Groups: c.groups} } else if c.useGroupsAsWhitelist { return filteredGroups, nil } diff --git a/connector/microsoft/microsoft_test.go b/connector/microsoft/microsoft_test.go index 1fa2f3dc..e7443bf4 100644 --- a/connector/microsoft/microsoft_test.go +++ b/connector/microsoft/microsoft_test.go @@ -2,6 +2,7 @@ package microsoft import ( "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" @@ -119,6 +120,39 @@ func TestUserGroupsFromGraphAPI(t *testing.T) { expectEquals(t, identity.Groups, []string{"a", "b"}) } +func TestUserNotInRequiredGroupFromGraphAPI(t *testing.T) { + s := newTestServer(map[string]testResponse{ + "/v1.0/me?$select=id,displayName,userPrincipalName": { + data: user{ID: "user-id-123", Name: "Jane Doe", Email: "jane.doe@example.com"}, + }, + // The user is a member of groups "c" and "d", but the connector only + // allows group "a" — so the user should be denied. + "/v1.0/me/getMemberGroups": {data: map[string]interface{}{ + "value": []string{"c", "d"}, + }}, + "/" + tenant + "/oauth2/v2.0/token": dummyToken, + }) + defer s.Close() + + req, _ := http.NewRequest("GET", s.URL, nil) + + c := microsoftConnector{ + apiURL: s.URL, + graphURL: s.URL, + tenant: tenant, + groups: []string{"a"}, + } + _, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + if err == nil { + t.Fatal("expected error when user is not in any required group, got nil") + } + + var groupsErr *connector.UserNotInRequiredGroupsError + if !errors.As(err, &groupsErr) { + t.Errorf("expected *connector.UserNotInRequiredGroupsError, got %T: %v", err, err) + } +} + func newTestServer(responses map[string]testResponse) *httptest.Server { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { response, found := responses[r.RequestURI] diff --git a/server/errors.go b/server/errors.go index c0b9d425..ec9146e0 100644 --- a/server/errors.go +++ b/server/errors.go @@ -23,4 +23,8 @@ const ( // ErrMsgMethodNotAllowed is shown when an unsupported HTTP method is used. ErrMsgMethodNotAllowed = "Method not allowed." + + // ErrMsgNotInRequiredGroups is shown when a user authenticates successfully + // but is not a member of any of the groups required by the connector. + ErrMsgNotInRequiredGroups = "You are not a member of any of the required groups to authenticate." ) diff --git a/server/handlers.go b/server/handlers.go index b8e11ce4..62f1650c 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -7,6 +7,7 @@ import ( "crypto/subtle" "encoding/base64" "encoding/json" + "errors" "fmt" "html/template" "net/http" @@ -499,7 +500,12 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) if err != nil { s.logger.ErrorContext(r.Context(), "failed to authenticate", "err", err) - s.renderError(r, w, http.StatusInternalServerError, ErrMsgAuthenticationFailed) + var groupsErr *connector.UserNotInRequiredGroupsError + if errors.As(err, &groupsErr) { + s.renderError(r, w, http.StatusForbidden, ErrMsgNotInRequiredGroups) + } else { + s.renderError(r, w, http.StatusInternalServerError, ErrMsgAuthenticationFailed) + } return }