diff --git a/server/handlers.go b/server/handlers.go index f8d0ed64..855ce2db 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -80,6 +80,7 @@ type discovery struct { UserInfo string `json:"userinfo_endpoint"` DeviceEndpoint string `json:"device_authorization_endpoint"` Introspect string `json:"introspection_endpoint"` + Registration string `json:"registration_endpoint"` GrantTypes []string `json:"grant_types_supported"` ResponseTypes []string `json:"response_types_supported"` Subjects []string `json:"subject_types_supported"` @@ -114,6 +115,7 @@ func (s *Server) constructDiscovery() discovery { UserInfo: s.absURL("/userinfo"), DeviceEndpoint: s.absURL("/device/code"), Introspect: s.absURL("/token/introspect"), + Registration: s.absURL("/register"), Subjects: []string{"public"}, IDTokenAlgs: []string{string(jose.RS256)}, CodeChallengeAlgs: []string{codeChallengeMethodS256, codeChallengeMethodPlain}, @@ -1490,3 +1492,174 @@ func usernamePrompt(conn connector.PasswordConnector) string { } return "Username" } + +// clientRegistrationRequest represents an RFC 7591 client registration request +type clientRegistrationRequest struct { + RedirectURIs []string `json:"redirect_uris"` + ClientName string `json:"client_name,omitempty"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + GrantTypes []string `json:"grant_types,omitempty"` + ResponseTypes []string `json:"response_types,omitempty"` + Scope string `json:"scope,omitempty"` + LogoURI string `json:"logo_uri,omitempty"` +} + +// clientRegistrationResponse represents an RFC 7591 client registration response +type clientRegistrationResponse struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at"` + ClientName string `json:"client_name,omitempty"` + RedirectURIs []string `json:"redirect_uris"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + GrantTypes []string `json:"grant_types,omitempty"` + ResponseTypes []string `json:"response_types,omitempty"` + Scope string `json:"scope,omitempty"` + LogoURI string `json:"logo_uri,omitempty"` +} + +// handleClientRegistration implements RFC 7591 OAuth 2.0 Dynamic Client Registration Protocol +func (s *Server) handleClientRegistration(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Only POST method is allowed + if r.Method != http.MethodPost { + s.registrationErrHelper(w, errInvalidRequest, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Check Initial Access Token if configured (RFC 7591 Section 3.1) + if s.registrationToken != "" { + authHeader := r.Header.Get("Authorization") + const bearerPrefix = "Bearer " + + if authHeader == "" || !strings.HasPrefix(authHeader, bearerPrefix) { + w.Header().Set("WWW-Authenticate", "Bearer") + s.registrationErrHelper(w, errInvalidRequest, "Initial access token required", http.StatusUnauthorized) + return + } + + providedToken := strings.TrimPrefix(authHeader, bearerPrefix) + if providedToken != s.registrationToken { + w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"") + s.registrationErrHelper(w, errInvalidRequest, "Invalid initial access token", http.StatusUnauthorized) + return + } + + s.logger.InfoContext(ctx, "client registration authenticated with initial access token") + } else { + s.logger.WarnContext(ctx, "client registration endpoint is open - no authentication required. Set registrationToken in config for production use.") + } + + // Parse the request body + var req clientRegistrationRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + s.logger.ErrorContext(ctx, "failed to parse registration request", "err", err) + s.registrationErrHelper(w, errInvalidRequest, "Invalid JSON request body", http.StatusBadRequest) + return + } + + // Validate required fields + if len(req.RedirectURIs) == 0 { + s.registrationErrHelper(w, errInvalidRequest, "redirect_uris is required", http.StatusBadRequest) + return + } + + // Apply default values + if req.TokenEndpointAuthMethod == "" { + req.TokenEndpointAuthMethod = "client_secret_basic" + } + if len(req.GrantTypes) == 0 { + req.GrantTypes = []string{grantTypeAuthorizationCode, grantTypeRefreshToken} + } + if len(req.ResponseTypes) == 0 { + req.ResponseTypes = []string{responseTypeCode} + } + + // Validate token_endpoint_auth_method + if req.TokenEndpointAuthMethod != "client_secret_basic" && req.TokenEndpointAuthMethod != "client_secret_post" && req.TokenEndpointAuthMethod != "none" { + s.registrationErrHelper(w, errInvalidRequest, "Unsupported token_endpoint_auth_method", http.StatusBadRequest) + return + } + + // Validate grant_types + for _, gt := range req.GrantTypes { + if !contains(s.supportedGrantTypes, gt) { + s.registrationErrHelper(w, errInvalidRequest, fmt.Sprintf("Unsupported grant_type: %s", gt), http.StatusBadRequest) + return + } + } + + // Validate response_types + for _, rt := range req.ResponseTypes { + if !s.supportedResponseTypes[rt] { + s.registrationErrHelper(w, errInvalidRequest, fmt.Sprintf("Unsupported response_type: %s", rt), http.StatusBadRequest) + return + } + } + + // Generate client_id and client_secret + // Following the same pattern as the gRPC API (api.go:CreateClient) + clientID := storage.NewID() + + // Determine if this is a public client + isPublic := req.TokenEndpointAuthMethod == "none" + + // Only generate secret for confidential clients + var clientSecret string + if !isPublic { + clientSecret = storage.NewID() + storage.NewID() // Double NewID for longer secret + } + + // Create the client in storage + client := storage.Client{ + ID: clientID, + Secret: clientSecret, + RedirectURIs: req.RedirectURIs, + Name: req.ClientName, + LogoURL: req.LogoURI, + Public: isPublic, + } + + if err := s.storage.CreateClient(ctx, client); err != nil { + s.logger.ErrorContext(ctx, "failed to create client", "err", err) + if err == storage.ErrAlreadyExists { + s.registrationErrHelper(w, errInvalidRequest, "Client ID already exists", http.StatusBadRequest) + } else { + s.registrationErrHelper(w, errServerError, "Failed to register client", http.StatusInternalServerError) + } + return + } + + // Build the response + resp := clientRegistrationResponse{ + ClientID: clientID, + ClientSecret: clientSecret, + ClientSecretExpiresAt: 0, // 0 indicates the secret never expires + ClientName: req.ClientName, + RedirectURIs: req.RedirectURIs, + TokenEndpointAuthMethod: req.TokenEndpointAuthMethod, + GrantTypes: req.GrantTypes, + ResponseTypes: req.ResponseTypes, + Scope: req.Scope, + LogoURI: req.LogoURI, + } + + // For public clients, don't return the secret + if isPublic { + resp.ClientSecret = "" + } + + // Return HTTP 201 Created + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + if err := json.NewEncoder(w).Encode(resp); err != nil { + s.logger.ErrorContext(ctx, "failed to encode registration response", "err", err) + } +} + +func (s *Server) registrationErrHelper(w http.ResponseWriter, typ, description string, statusCode int) { + if err := tokenErr(w, typ, description, statusCode); err != nil { + s.logger.Error("registration error response", "err", err) + } +} diff --git a/server/handlers_test.go b/server/handlers_test.go index 114712ba..e2912551 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -55,6 +55,7 @@ func TestHandleDiscovery(t *testing.T) { UserInfo: fmt.Sprintf("%s/userinfo", httpServer.URL), DeviceEndpoint: fmt.Sprintf("%s/device/code", httpServer.URL), Introspect: fmt.Sprintf("%s/token/introspect", httpServer.URL), + Registration: fmt.Sprintf("%s/register", httpServer.URL), GrantTypes: []string{ "authorization_code", "refresh_token", @@ -802,3 +803,253 @@ func setNonEmpty(vals url.Values, key, value string) { vals.Set(key, value) } } + +func TestHandleClientRegistration(t *testing.T) { + tests := []struct { + name string + requestBody clientRegistrationRequest + expectedStatusCode int + validateResponse func(t *testing.T, resp clientRegistrationResponse) + }{ + { + name: "successful registration with minimal fields", + requestBody: clientRegistrationRequest{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + expectedStatusCode: http.StatusCreated, + validateResponse: func(t *testing.T, resp clientRegistrationResponse) { + require.NotEmpty(t, resp.ClientID) + require.NotEmpty(t, resp.ClientSecret) + require.Equal(t, int64(0), resp.ClientSecretExpiresAt) + require.Equal(t, []string{"https://example.com/callback"}, resp.RedirectURIs) + require.Equal(t, "client_secret_basic", resp.TokenEndpointAuthMethod) + require.Equal(t, []string{"authorization_code", "refresh_token"}, resp.GrantTypes) + require.Equal(t, []string{"code"}, resp.ResponseTypes) + }, + }, + { + name: "successful registration with all fields", + requestBody: clientRegistrationRequest{ + RedirectURIs: []string{"https://example.com/callback", "https://example.com/callback2"}, + ClientName: "Test Client", + TokenEndpointAuthMethod: "client_secret_post", + GrantTypes: []string{"authorization_code"}, + ResponseTypes: []string{"code"}, + Scope: "openid email profile", + LogoURI: "https://example.com/logo.png", + }, + expectedStatusCode: http.StatusCreated, + validateResponse: func(t *testing.T, resp clientRegistrationResponse) { + require.NotEmpty(t, resp.ClientID) + require.NotEmpty(t, resp.ClientSecret) + require.Equal(t, "Test Client", resp.ClientName) + require.Equal(t, "client_secret_post", resp.TokenEndpointAuthMethod) + require.Equal(t, []string{"authorization_code"}, resp.GrantTypes) + require.Equal(t, []string{"code"}, resp.ResponseTypes) + require.Equal(t, "openid email profile", resp.Scope) + require.Equal(t, "https://example.com/logo.png", resp.LogoURI) + }, + }, + { + name: "public client (no secret)", + requestBody: clientRegistrationRequest{ + RedirectURIs: []string{"https://example.com/callback"}, + TokenEndpointAuthMethod: "none", + }, + expectedStatusCode: http.StatusCreated, + validateResponse: func(t *testing.T, resp clientRegistrationResponse) { + require.NotEmpty(t, resp.ClientID) + require.Empty(t, resp.ClientSecret) + require.Equal(t, "none", resp.TokenEndpointAuthMethod) + }, + }, + { + name: "missing redirect_uris", + requestBody: clientRegistrationRequest{ + ClientName: "Test Client", + }, + expectedStatusCode: http.StatusBadRequest, + }, + { + name: "unsupported token_endpoint_auth_method", + requestBody: clientRegistrationRequest{ + RedirectURIs: []string{"https://example.com/callback"}, + TokenEndpointAuthMethod: "invalid_method", + }, + expectedStatusCode: http.StatusBadRequest, + }, + { + name: "unsupported grant_type", + requestBody: clientRegistrationRequest{ + RedirectURIs: []string{"https://example.com/callback"}, + GrantTypes: []string{"invalid_grant"}, + }, + expectedStatusCode: http.StatusBadRequest, + }, + { + name: "unsupported response_type", + requestBody: clientRegistrationRequest{ + RedirectURIs: []string{"https://example.com/callback"}, + ResponseTypes: []string{"invalid_response"}, + }, + expectedStatusCode: http.StatusBadRequest, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + httpServer, s := newTestServer(t, nil) + defer httpServer.Close() + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/register", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + s.handleClientRegistration(rr, req) + + require.Equal(t, tc.expectedStatusCode, rr.Code, rr.Body.String()) + + if tc.expectedStatusCode == http.StatusCreated { + var resp clientRegistrationResponse + err := json.NewDecoder(rr.Result().Body).Decode(&resp) + require.NoError(t, err) + tc.validateResponse(t, resp) + + // Verify the client was actually created in storage + ctx := context.Background() + client, err := s.storage.GetClient(ctx, resp.ClientID) + require.NoError(t, err) + require.Equal(t, resp.ClientID, client.ID) + require.Equal(t, resp.RedirectURIs, client.RedirectURIs) + } + }) + } +} + +func TestHandleClientRegistrationMethodNotAllowed(t *testing.T) { + httpServer, s := newTestServer(t, nil) + defer httpServer.Close() + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, httpServer.URL+"/register", nil) + + s.handleClientRegistration(rr, req) + + require.Equal(t, http.StatusMethodNotAllowed, rr.Code) +} + +func TestHandleClientRegistrationInvalidJSON(t *testing.T) { + httpServer, s := newTestServer(t, nil) + defer httpServer.Close() + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/register", strings.NewReader("invalid json")) + req.Header.Set("Content-Type", "application/json") + + s.handleClientRegistration(rr, req) + + require.Equal(t, http.StatusBadRequest, rr.Code) +} + +func TestHandleClientRegistrationWithAuth(t *testing.T) { + tests := []struct { + name string + registrationToken string + authHeader string + requestBody clientRegistrationRequest + expectedStatusCode int + }{ + { + name: "successful registration with valid token", + registrationToken: "secret-token-123", + authHeader: "Bearer secret-token-123", + requestBody: clientRegistrationRequest{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + expectedStatusCode: http.StatusCreated, + }, + { + name: "missing auth header when token required", + registrationToken: "secret-token-123", + authHeader: "", + requestBody: clientRegistrationRequest{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + expectedStatusCode: http.StatusUnauthorized, + }, + { + name: "invalid token", + registrationToken: "secret-token-123", + authHeader: "Bearer wrong-token", + requestBody: clientRegistrationRequest{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + expectedStatusCode: http.StatusUnauthorized, + }, + { + name: "malformed auth header", + registrationToken: "secret-token-123", + authHeader: "Basic secret-token-123", + requestBody: clientRegistrationRequest{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + expectedStatusCode: http.StatusUnauthorized, + }, + { + name: "open registration (no token configured)", + registrationToken: "", + authHeader: "", + requestBody: clientRegistrationRequest{ + RedirectURIs: []string{"https://example.com/callback"}, + }, + expectedStatusCode: http.StatusCreated, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + httpServer, s := newTestServer(t, func(c *Config) { + c.RegistrationToken = tc.registrationToken + }) + defer httpServer.Close() + + body, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/register", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + if tc.authHeader != "" { + req.Header.Set("Authorization", tc.authHeader) + } + + s.handleClientRegistration(rr, req) + + require.Equal(t, tc.expectedStatusCode, rr.Code, rr.Body.String()) + + if tc.expectedStatusCode == http.StatusCreated { + var resp clientRegistrationResponse + err := json.NewDecoder(rr.Result().Body).Decode(&resp) + require.NoError(t, err) + require.NotEmpty(t, resp.ClientID) + require.NotEmpty(t, resp.ClientSecret) + + // Verify the client was actually created in storage + client, err := s.storage.GetClient(ctx, resp.ClientID) + require.NoError(t, err) + require.Equal(t, resp.ClientID, client.ID) + } + + // Check WWW-Authenticate header on 401 + if tc.expectedStatusCode == http.StatusUnauthorized { + wwwAuth := rr.Header().Get("WWW-Authenticate") + require.NotEmpty(t, wwwAuth) + require.Contains(t, wwwAuth, "Bearer") + } + }) + } +} diff --git a/server/server.go b/server/server.go index 70e8ae75..5590675e 100644 --- a/server/server.go +++ b/server/server.go @@ -109,6 +109,11 @@ type Config struct { GCFrequency time.Duration // Defaults to 5 minutes + // RegistrationToken is an optional bearer token required for dynamic client registration. + // If empty, the /register endpoint allows open registration (not recommended for production). + // Set this to restrict registration to authorized parties only. + RegistrationToken string + // If specified, the server will use this function for determining time. Now func() time.Time @@ -187,6 +192,9 @@ type Server struct { // Used for password grant passwordConnector string + // Optional bearer token for client registration endpoint + registrationToken string + supportedResponseTypes map[string]bool supportedGrantTypes []string @@ -315,6 +323,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) now: now, templates: tmpls, passwordConnector: c.PasswordConnector, + registrationToken: c.RegistrationToken, logger: c.Logger, } @@ -477,6 +486,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleWithCORS("/keys", s.handlePublicKeys) handleWithCORS("/userinfo", s.handleUserInfo) handleWithCORS("/token/introspect", s.handleIntrospect) + handleWithCORS("/register", s.handleClientRegistration) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/auth/{connector}/login", s.handlePasswordLogin)