Browse Source

Merge fcedca4429 into a6962a8ba4

pull/4383/merge
ByteBaker 3 weeks ago committed by GitHub
parent
commit
de7b03ad48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 173
      server/handlers.go
  2. 251
      server/handlers_test.go
  3. 10
      server/server.go

173
server/handlers.go

@ -80,6 +80,7 @@ type discovery struct {
UserInfo string `json:"userinfo_endpoint"`
DeviceEndpoint string `json:"device_authorization_endpoint"`
Introspect string `json:"introspection_endpoint"`
Registration string `json:"registration_endpoint"`
GrantTypes []string `json:"grant_types_supported"`
ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"`
@ -114,6 +115,7 @@ func (s *Server) constructDiscovery(ctx context.Context) discovery {
UserInfo: s.absURL("/userinfo"),
DeviceEndpoint: s.absURL("/device/code"),
Introspect: s.absURL("/token/introspect"),
Registration: s.absURL("/register"),
Subjects: []string{"public"},
IDTokenAlgs: []string{string(jose.RS256)},
CodeChallengeAlgs: []string{codeChallengeMethodS256, codeChallengeMethodPlain},
@ -1518,3 +1520,174 @@ func usernamePrompt(conn connector.PasswordConnector) string {
}
return "Username"
}
// clientRegistrationRequest represents an RFC 7591 client registration request
type clientRegistrationRequest struct {
RedirectURIs []string `json:"redirect_uris"`
ClientName string `json:"client_name,omitempty"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
Scope string `json:"scope,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
}
// clientRegistrationResponse represents an RFC 7591 client registration response
type clientRegistrationResponse struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret,omitempty"`
ClientSecretExpiresAt int64 `json:"client_secret_expires_at"`
ClientName string `json:"client_name,omitempty"`
RedirectURIs []string `json:"redirect_uris"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
Scope string `json:"scope,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
}
// handleClientRegistration implements RFC 7591 OAuth 2.0 Dynamic Client Registration Protocol
func (s *Server) handleClientRegistration(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Only POST method is allowed
if r.Method != http.MethodPost {
s.registrationErrHelper(w, errInvalidRequest, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Check Initial Access Token if configured (RFC 7591 Section 3.1)
if s.registrationToken != "" {
authHeader := r.Header.Get("Authorization")
const bearerPrefix = "Bearer "
if authHeader == "" || !strings.HasPrefix(authHeader, bearerPrefix) {
w.Header().Set("WWW-Authenticate", "Bearer")
s.registrationErrHelper(w, errInvalidRequest, "Initial access token required", http.StatusUnauthorized)
return
}
providedToken := strings.TrimPrefix(authHeader, bearerPrefix)
if providedToken != s.registrationToken {
w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"")
s.registrationErrHelper(w, errInvalidRequest, "Invalid initial access token", http.StatusUnauthorized)
return
}
s.logger.InfoContext(ctx, "client registration authenticated with initial access token")
} else {
s.logger.WarnContext(ctx, "client registration endpoint is open - no authentication required. Set registrationToken in config for production use.")
}
// Parse the request body
var req clientRegistrationRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.logger.ErrorContext(ctx, "failed to parse registration request", "err", err)
s.registrationErrHelper(w, errInvalidRequest, "Invalid JSON request body", http.StatusBadRequest)
return
}
// Validate required fields
if len(req.RedirectURIs) == 0 {
s.registrationErrHelper(w, errInvalidRequest, "redirect_uris is required", http.StatusBadRequest)
return
}
// Apply default values
if req.TokenEndpointAuthMethod == "" {
req.TokenEndpointAuthMethod = "client_secret_basic"
}
if len(req.GrantTypes) == 0 {
req.GrantTypes = []string{grantTypeAuthorizationCode, grantTypeRefreshToken}
}
if len(req.ResponseTypes) == 0 {
req.ResponseTypes = []string{responseTypeCode}
}
// Validate token_endpoint_auth_method
if req.TokenEndpointAuthMethod != "client_secret_basic" && req.TokenEndpointAuthMethod != "client_secret_post" && req.TokenEndpointAuthMethod != "none" {
s.registrationErrHelper(w, errInvalidRequest, "Unsupported token_endpoint_auth_method", http.StatusBadRequest)
return
}
// Validate grant_types
for _, gt := range req.GrantTypes {
if !contains(s.supportedGrantTypes, gt) {
s.registrationErrHelper(w, errInvalidRequest, fmt.Sprintf("Unsupported grant_type: %s", gt), http.StatusBadRequest)
return
}
}
// Validate response_types
for _, rt := range req.ResponseTypes {
if !s.supportedResponseTypes[rt] {
s.registrationErrHelper(w, errInvalidRequest, fmt.Sprintf("Unsupported response_type: %s", rt), http.StatusBadRequest)
return
}
}
// Generate client_id and client_secret
// Following the same pattern as the gRPC API (api.go:CreateClient)
clientID := storage.NewID()
// Determine if this is a public client
isPublic := req.TokenEndpointAuthMethod == "none"
// Only generate secret for confidential clients
var clientSecret string
if !isPublic {
clientSecret = storage.NewID() + storage.NewID() // Double NewID for longer secret
}
// Create the client in storage
client := storage.Client{
ID: clientID,
Secret: clientSecret,
RedirectURIs: req.RedirectURIs,
Name: req.ClientName,
LogoURL: req.LogoURI,
Public: isPublic,
}
if err := s.storage.CreateClient(ctx, client); err != nil {
s.logger.ErrorContext(ctx, "failed to create client", "err", err)
if err == storage.ErrAlreadyExists {
s.registrationErrHelper(w, errInvalidRequest, "Client ID already exists", http.StatusBadRequest)
} else {
s.registrationErrHelper(w, errServerError, "Failed to register client", http.StatusInternalServerError)
}
return
}
// Build the response
resp := clientRegistrationResponse{
ClientID: clientID,
ClientSecret: clientSecret,
ClientSecretExpiresAt: 0, // 0 indicates the secret never expires
ClientName: req.ClientName,
RedirectURIs: req.RedirectURIs,
TokenEndpointAuthMethod: req.TokenEndpointAuthMethod,
GrantTypes: req.GrantTypes,
ResponseTypes: req.ResponseTypes,
Scope: req.Scope,
LogoURI: req.LogoURI,
}
// For public clients, don't return the secret
if isPublic {
resp.ClientSecret = ""
}
// Return HTTP 201 Created
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.ErrorContext(ctx, "failed to encode registration response", "err", err)
}
}
func (s *Server) registrationErrHelper(w http.ResponseWriter, typ, description string, statusCode int) {
if err := tokenErr(w, typ, description, statusCode); err != nil {
s.logger.Error("registration error response", "err", err)
}
}

251
server/handlers_test.go

@ -60,6 +60,7 @@ func TestHandleDiscovery(t *testing.T) {
UserInfo: fmt.Sprintf("%s/userinfo", httpServer.URL),
DeviceEndpoint: fmt.Sprintf("%s/device/code", httpServer.URL),
Introspect: fmt.Sprintf("%s/token/introspect", httpServer.URL),
Registration: fmt.Sprintf("%s/register", httpServer.URL),
GrantTypes: []string{
"authorization_code",
"refresh_token",
@ -892,3 +893,253 @@ func setNonEmpty(vals url.Values, key, value string) {
vals.Set(key, value)
}
}
func TestHandleClientRegistration(t *testing.T) {
tests := []struct {
name string
requestBody clientRegistrationRequest
expectedStatusCode int
validateResponse func(t *testing.T, resp clientRegistrationResponse)
}{
{
name: "successful registration with minimal fields",
requestBody: clientRegistrationRequest{
RedirectURIs: []string{"https://example.com/callback"},
},
expectedStatusCode: http.StatusCreated,
validateResponse: func(t *testing.T, resp clientRegistrationResponse) {
require.NotEmpty(t, resp.ClientID)
require.NotEmpty(t, resp.ClientSecret)
require.Equal(t, int64(0), resp.ClientSecretExpiresAt)
require.Equal(t, []string{"https://example.com/callback"}, resp.RedirectURIs)
require.Equal(t, "client_secret_basic", resp.TokenEndpointAuthMethod)
require.Equal(t, []string{"authorization_code", "refresh_token"}, resp.GrantTypes)
require.Equal(t, []string{"code"}, resp.ResponseTypes)
},
},
{
name: "successful registration with all fields",
requestBody: clientRegistrationRequest{
RedirectURIs: []string{"https://example.com/callback", "https://example.com/callback2"},
ClientName: "Test Client",
TokenEndpointAuthMethod: "client_secret_post",
GrantTypes: []string{"authorization_code"},
ResponseTypes: []string{"code"},
Scope: "openid email profile",
LogoURI: "https://example.com/logo.png",
},
expectedStatusCode: http.StatusCreated,
validateResponse: func(t *testing.T, resp clientRegistrationResponse) {
require.NotEmpty(t, resp.ClientID)
require.NotEmpty(t, resp.ClientSecret)
require.Equal(t, "Test Client", resp.ClientName)
require.Equal(t, "client_secret_post", resp.TokenEndpointAuthMethod)
require.Equal(t, []string{"authorization_code"}, resp.GrantTypes)
require.Equal(t, []string{"code"}, resp.ResponseTypes)
require.Equal(t, "openid email profile", resp.Scope)
require.Equal(t, "https://example.com/logo.png", resp.LogoURI)
},
},
{
name: "public client (no secret)",
requestBody: clientRegistrationRequest{
RedirectURIs: []string{"https://example.com/callback"},
TokenEndpointAuthMethod: "none",
},
expectedStatusCode: http.StatusCreated,
validateResponse: func(t *testing.T, resp clientRegistrationResponse) {
require.NotEmpty(t, resp.ClientID)
require.Empty(t, resp.ClientSecret)
require.Equal(t, "none", resp.TokenEndpointAuthMethod)
},
},
{
name: "missing redirect_uris",
requestBody: clientRegistrationRequest{
ClientName: "Test Client",
},
expectedStatusCode: http.StatusBadRequest,
},
{
name: "unsupported token_endpoint_auth_method",
requestBody: clientRegistrationRequest{
RedirectURIs: []string{"https://example.com/callback"},
TokenEndpointAuthMethod: "invalid_method",
},
expectedStatusCode: http.StatusBadRequest,
},
{
name: "unsupported grant_type",
requestBody: clientRegistrationRequest{
RedirectURIs: []string{"https://example.com/callback"},
GrantTypes: []string{"invalid_grant"},
},
expectedStatusCode: http.StatusBadRequest,
},
{
name: "unsupported response_type",
requestBody: clientRegistrationRequest{
RedirectURIs: []string{"https://example.com/callback"},
ResponseTypes: []string{"invalid_response"},
},
expectedStatusCode: http.StatusBadRequest,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
httpServer, s := newTestServer(t, nil)
defer httpServer.Close()
body, err := json.Marshal(tc.requestBody)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
s.handleClientRegistration(rr, req)
require.Equal(t, tc.expectedStatusCode, rr.Code, rr.Body.String())
if tc.expectedStatusCode == http.StatusCreated {
var resp clientRegistrationResponse
err := json.NewDecoder(rr.Result().Body).Decode(&resp)
require.NoError(t, err)
tc.validateResponse(t, resp)
// Verify the client was actually created in storage
ctx := context.Background()
client, err := s.storage.GetClient(ctx, resp.ClientID)
require.NoError(t, err)
require.Equal(t, resp.ClientID, client.ID)
require.Equal(t, resp.RedirectURIs, client.RedirectURIs)
}
})
}
}
func TestHandleClientRegistrationMethodNotAllowed(t *testing.T) {
httpServer, s := newTestServer(t, nil)
defer httpServer.Close()
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, httpServer.URL+"/register", nil)
s.handleClientRegistration(rr, req)
require.Equal(t, http.StatusMethodNotAllowed, rr.Code)
}
func TestHandleClientRegistrationInvalidJSON(t *testing.T) {
httpServer, s := newTestServer(t, nil)
defer httpServer.Close()
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/register", strings.NewReader("invalid json"))
req.Header.Set("Content-Type", "application/json")
s.handleClientRegistration(rr, req)
require.Equal(t, http.StatusBadRequest, rr.Code)
}
func TestHandleClientRegistrationWithAuth(t *testing.T) {
tests := []struct {
name string
registrationToken string
authHeader string
requestBody clientRegistrationRequest
expectedStatusCode int
}{
{
name: "successful registration with valid token",
registrationToken: "secret-token-123",
authHeader: "Bearer secret-token-123",
requestBody: clientRegistrationRequest{
RedirectURIs: []string{"https://example.com/callback"},
},
expectedStatusCode: http.StatusCreated,
},
{
name: "missing auth header when token required",
registrationToken: "secret-token-123",
authHeader: "",
requestBody: clientRegistrationRequest{
RedirectURIs: []string{"https://example.com/callback"},
},
expectedStatusCode: http.StatusUnauthorized,
},
{
name: "invalid token",
registrationToken: "secret-token-123",
authHeader: "Bearer wrong-token",
requestBody: clientRegistrationRequest{
RedirectURIs: []string{"https://example.com/callback"},
},
expectedStatusCode: http.StatusUnauthorized,
},
{
name: "malformed auth header",
registrationToken: "secret-token-123",
authHeader: "Basic secret-token-123",
requestBody: clientRegistrationRequest{
RedirectURIs: []string{"https://example.com/callback"},
},
expectedStatusCode: http.StatusUnauthorized,
},
{
name: "open registration (no token configured)",
registrationToken: "",
authHeader: "",
requestBody: clientRegistrationRequest{
RedirectURIs: []string{"https://example.com/callback"},
},
expectedStatusCode: http.StatusCreated,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
httpServer, s := newTestServer(t, func(c *Config) {
c.RegistrationToken = tc.registrationToken
})
defer httpServer.Close()
body, err := json.Marshal(tc.requestBody)
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
if tc.authHeader != "" {
req.Header.Set("Authorization", tc.authHeader)
}
s.handleClientRegistration(rr, req)
require.Equal(t, tc.expectedStatusCode, rr.Code, rr.Body.String())
if tc.expectedStatusCode == http.StatusCreated {
var resp clientRegistrationResponse
err := json.NewDecoder(rr.Result().Body).Decode(&resp)
require.NoError(t, err)
require.NotEmpty(t, resp.ClientID)
require.NotEmpty(t, resp.ClientSecret)
// Verify the client was actually created in storage
client, err := s.storage.GetClient(ctx, resp.ClientID)
require.NoError(t, err)
require.Equal(t, resp.ClientID, client.ID)
}
// Check WWW-Authenticate header on 401
if tc.expectedStatusCode == http.StatusUnauthorized {
wwwAuth := rr.Header().Get("WWW-Authenticate")
require.NotEmpty(t, wwwAuth)
require.Contains(t, wwwAuth, "Bearer")
}
})
}
}

10
server/server.go

@ -108,6 +108,11 @@ type Config struct {
GCFrequency time.Duration // Defaults to 5 minutes
// RegistrationToken is an optional bearer token required for dynamic client registration.
// If empty, the /register endpoint allows open registration (not recommended for production).
// Set this to restrict registration to authorized parties only.
RegistrationToken string
// If specified, the server will use this function for determining time.
Now func() time.Time
@ -189,6 +194,9 @@ type Server struct {
// Used for password grant
passwordConnector string
// Optional bearer token for client registration endpoint
registrationToken string
supportedResponseTypes map[string]bool
supportedGrantTypes []string
@ -309,6 +317,7 @@ func newServer(ctx context.Context, c Config) (*Server, error) {
now: now,
templates: tmpls,
passwordConnector: c.PasswordConnector,
registrationToken: c.RegistrationToken,
logger: c.Logger,
signer: c.Signer,
}
@ -472,6 +481,7 @@ func newServer(ctx context.Context, c Config) (*Server, error) {
handleWithCORS("/keys", s.handlePublicKeys)
handleWithCORS("/userinfo", s.handleUserInfo)
handleWithCORS("/token/introspect", s.handleIntrospect)
handleWithCORS("/register", s.handleClientRegistration)
handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/auth/{connector}/login", s.handlePasswordLogin)

Loading…
Cancel
Save