|
|
|
|
@ -60,6 +60,7 @@ func TestHandleDiscovery(t *testing.T) {
|
|
|
|
|
UserInfo: fmt.Sprintf("%s/userinfo", httpServer.URL), |
|
|
|
|
DeviceEndpoint: fmt.Sprintf("%s/device/code", httpServer.URL), |
|
|
|
|
Introspect: fmt.Sprintf("%s/token/introspect", httpServer.URL), |
|
|
|
|
Registration: fmt.Sprintf("%s/register", httpServer.URL), |
|
|
|
|
GrantTypes: []string{ |
|
|
|
|
"authorization_code", |
|
|
|
|
"refresh_token", |
|
|
|
|
@ -892,3 +893,253 @@ func setNonEmpty(vals url.Values, key, value string) {
|
|
|
|
|
vals.Set(key, value) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestHandleClientRegistration(t *testing.T) { |
|
|
|
|
tests := []struct { |
|
|
|
|
name string |
|
|
|
|
requestBody clientRegistrationRequest |
|
|
|
|
expectedStatusCode int |
|
|
|
|
validateResponse func(t *testing.T, resp clientRegistrationResponse) |
|
|
|
|
}{ |
|
|
|
|
{ |
|
|
|
|
name: "successful registration with minimal fields", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
RedirectURIs: []string{"https://example.com/callback"}, |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusCreated, |
|
|
|
|
validateResponse: func(t *testing.T, resp clientRegistrationResponse) { |
|
|
|
|
require.NotEmpty(t, resp.ClientID) |
|
|
|
|
require.NotEmpty(t, resp.ClientSecret) |
|
|
|
|
require.Equal(t, int64(0), resp.ClientSecretExpiresAt) |
|
|
|
|
require.Equal(t, []string{"https://example.com/callback"}, resp.RedirectURIs) |
|
|
|
|
require.Equal(t, "client_secret_basic", resp.TokenEndpointAuthMethod) |
|
|
|
|
require.Equal(t, []string{"authorization_code", "refresh_token"}, resp.GrantTypes) |
|
|
|
|
require.Equal(t, []string{"code"}, resp.ResponseTypes) |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "successful registration with all fields", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
RedirectURIs: []string{"https://example.com/callback", "https://example.com/callback2"}, |
|
|
|
|
ClientName: "Test Client", |
|
|
|
|
TokenEndpointAuthMethod: "client_secret_post", |
|
|
|
|
GrantTypes: []string{"authorization_code"}, |
|
|
|
|
ResponseTypes: []string{"code"}, |
|
|
|
|
Scope: "openid email profile", |
|
|
|
|
LogoURI: "https://example.com/logo.png", |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusCreated, |
|
|
|
|
validateResponse: func(t *testing.T, resp clientRegistrationResponse) { |
|
|
|
|
require.NotEmpty(t, resp.ClientID) |
|
|
|
|
require.NotEmpty(t, resp.ClientSecret) |
|
|
|
|
require.Equal(t, "Test Client", resp.ClientName) |
|
|
|
|
require.Equal(t, "client_secret_post", resp.TokenEndpointAuthMethod) |
|
|
|
|
require.Equal(t, []string{"authorization_code"}, resp.GrantTypes) |
|
|
|
|
require.Equal(t, []string{"code"}, resp.ResponseTypes) |
|
|
|
|
require.Equal(t, "openid email profile", resp.Scope) |
|
|
|
|
require.Equal(t, "https://example.com/logo.png", resp.LogoURI) |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "public client (no secret)", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
RedirectURIs: []string{"https://example.com/callback"}, |
|
|
|
|
TokenEndpointAuthMethod: "none", |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusCreated, |
|
|
|
|
validateResponse: func(t *testing.T, resp clientRegistrationResponse) { |
|
|
|
|
require.NotEmpty(t, resp.ClientID) |
|
|
|
|
require.Empty(t, resp.ClientSecret) |
|
|
|
|
require.Equal(t, "none", resp.TokenEndpointAuthMethod) |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "missing redirect_uris", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
ClientName: "Test Client", |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusBadRequest, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "unsupported token_endpoint_auth_method", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
RedirectURIs: []string{"https://example.com/callback"}, |
|
|
|
|
TokenEndpointAuthMethod: "invalid_method", |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusBadRequest, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "unsupported grant_type", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
RedirectURIs: []string{"https://example.com/callback"}, |
|
|
|
|
GrantTypes: []string{"invalid_grant"}, |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusBadRequest, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "unsupported response_type", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
RedirectURIs: []string{"https://example.com/callback"}, |
|
|
|
|
ResponseTypes: []string{"invalid_response"}, |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusBadRequest, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for _, tc := range tests { |
|
|
|
|
t.Run(tc.name, func(t *testing.T) { |
|
|
|
|
httpServer, s := newTestServer(t, nil) |
|
|
|
|
defer httpServer.Close() |
|
|
|
|
|
|
|
|
|
body, err := json.Marshal(tc.requestBody) |
|
|
|
|
require.NoError(t, err) |
|
|
|
|
|
|
|
|
|
rr := httptest.NewRecorder() |
|
|
|
|
req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/register", bytes.NewReader(body)) |
|
|
|
|
req.Header.Set("Content-Type", "application/json") |
|
|
|
|
|
|
|
|
|
s.handleClientRegistration(rr, req) |
|
|
|
|
|
|
|
|
|
require.Equal(t, tc.expectedStatusCode, rr.Code, rr.Body.String()) |
|
|
|
|
|
|
|
|
|
if tc.expectedStatusCode == http.StatusCreated { |
|
|
|
|
var resp clientRegistrationResponse |
|
|
|
|
err := json.NewDecoder(rr.Result().Body).Decode(&resp) |
|
|
|
|
require.NoError(t, err) |
|
|
|
|
tc.validateResponse(t, resp) |
|
|
|
|
|
|
|
|
|
// Verify the client was actually created in storage
|
|
|
|
|
ctx := context.Background() |
|
|
|
|
client, err := s.storage.GetClient(ctx, resp.ClientID) |
|
|
|
|
require.NoError(t, err) |
|
|
|
|
require.Equal(t, resp.ClientID, client.ID) |
|
|
|
|
require.Equal(t, resp.RedirectURIs, client.RedirectURIs) |
|
|
|
|
} |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestHandleClientRegistrationMethodNotAllowed(t *testing.T) { |
|
|
|
|
httpServer, s := newTestServer(t, nil) |
|
|
|
|
defer httpServer.Close() |
|
|
|
|
|
|
|
|
|
rr := httptest.NewRecorder() |
|
|
|
|
req := httptest.NewRequest(http.MethodGet, httpServer.URL+"/register", nil) |
|
|
|
|
|
|
|
|
|
s.handleClientRegistration(rr, req) |
|
|
|
|
|
|
|
|
|
require.Equal(t, http.StatusMethodNotAllowed, rr.Code) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestHandleClientRegistrationInvalidJSON(t *testing.T) { |
|
|
|
|
httpServer, s := newTestServer(t, nil) |
|
|
|
|
defer httpServer.Close() |
|
|
|
|
|
|
|
|
|
rr := httptest.NewRecorder() |
|
|
|
|
req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/register", strings.NewReader("invalid json")) |
|
|
|
|
req.Header.Set("Content-Type", "application/json") |
|
|
|
|
|
|
|
|
|
s.handleClientRegistration(rr, req) |
|
|
|
|
|
|
|
|
|
require.Equal(t, http.StatusBadRequest, rr.Code) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestHandleClientRegistrationWithAuth(t *testing.T) { |
|
|
|
|
tests := []struct { |
|
|
|
|
name string |
|
|
|
|
registrationToken string |
|
|
|
|
authHeader string |
|
|
|
|
requestBody clientRegistrationRequest |
|
|
|
|
expectedStatusCode int |
|
|
|
|
}{ |
|
|
|
|
{ |
|
|
|
|
name: "successful registration with valid token", |
|
|
|
|
registrationToken: "secret-token-123", |
|
|
|
|
authHeader: "Bearer secret-token-123", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
RedirectURIs: []string{"https://example.com/callback"}, |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusCreated, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "missing auth header when token required", |
|
|
|
|
registrationToken: "secret-token-123", |
|
|
|
|
authHeader: "", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
RedirectURIs: []string{"https://example.com/callback"}, |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusUnauthorized, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "invalid token", |
|
|
|
|
registrationToken: "secret-token-123", |
|
|
|
|
authHeader: "Bearer wrong-token", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
RedirectURIs: []string{"https://example.com/callback"}, |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusUnauthorized, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "malformed auth header", |
|
|
|
|
registrationToken: "secret-token-123", |
|
|
|
|
authHeader: "Basic secret-token-123", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
RedirectURIs: []string{"https://example.com/callback"}, |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusUnauthorized, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "open registration (no token configured)", |
|
|
|
|
registrationToken: "", |
|
|
|
|
authHeader: "", |
|
|
|
|
requestBody: clientRegistrationRequest{ |
|
|
|
|
RedirectURIs: []string{"https://example.com/callback"}, |
|
|
|
|
}, |
|
|
|
|
expectedStatusCode: http.StatusCreated, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for _, tc := range tests { |
|
|
|
|
t.Run(tc.name, func(t *testing.T) { |
|
|
|
|
ctx := context.Background() |
|
|
|
|
httpServer, s := newTestServer(t, func(c *Config) { |
|
|
|
|
c.RegistrationToken = tc.registrationToken |
|
|
|
|
}) |
|
|
|
|
defer httpServer.Close() |
|
|
|
|
|
|
|
|
|
body, err := json.Marshal(tc.requestBody) |
|
|
|
|
require.NoError(t, err) |
|
|
|
|
|
|
|
|
|
rr := httptest.NewRecorder() |
|
|
|
|
req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/register", bytes.NewReader(body)) |
|
|
|
|
req.Header.Set("Content-Type", "application/json") |
|
|
|
|
if tc.authHeader != "" { |
|
|
|
|
req.Header.Set("Authorization", tc.authHeader) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
s.handleClientRegistration(rr, req) |
|
|
|
|
|
|
|
|
|
require.Equal(t, tc.expectedStatusCode, rr.Code, rr.Body.String()) |
|
|
|
|
|
|
|
|
|
if tc.expectedStatusCode == http.StatusCreated { |
|
|
|
|
var resp clientRegistrationResponse |
|
|
|
|
err := json.NewDecoder(rr.Result().Body).Decode(&resp) |
|
|
|
|
require.NoError(t, err) |
|
|
|
|
require.NotEmpty(t, resp.ClientID) |
|
|
|
|
require.NotEmpty(t, resp.ClientSecret) |
|
|
|
|
|
|
|
|
|
// Verify the client was actually created in storage
|
|
|
|
|
client, err := s.storage.GetClient(ctx, resp.ClientID) |
|
|
|
|
require.NoError(t, err) |
|
|
|
|
require.Equal(t, resp.ClientID, client.ID) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Check WWW-Authenticate header on 401
|
|
|
|
|
if tc.expectedStatusCode == http.StatusUnauthorized { |
|
|
|
|
wwwAuth := rr.Header().Get("WWW-Authenticate") |
|
|
|
|
require.NotEmpty(t, wwwAuth) |
|
|
|
|
require.Contains(t, wwwAuth, "Bearer") |
|
|
|
|
} |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|