From 9a5868eb96dc6722516067c6e4182be593206f39 Mon Sep 17 00:00:00 2001 From: Giovanni Vella Date: Fri, 12 Dec 2025 14:07:45 +0100 Subject: [PATCH] feat: add oauth-authorization-server discovery endpoint Signed-off-by: Giovanni Vella --- connector/oidc/oidc_test.go | 11 +++++++ server/api.go | 2 +- server/handlers.go | 63 ++++++++++++++++++++++++++++++++++--- server/handlers_test.go | 51 ++++++++++++++++++++++++++++-- server/server.go | 10 ++++-- 5 files changed, 126 insertions(+), 11 deletions(-) diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 9a3d7126..7b289485 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -862,6 +862,17 @@ func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Ser }) }) + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { + url := fmt.Sprintf("http://%s", r.Host) + + json.NewEncoder(w).Encode(&map[string]string{ + "issuer": url, + "token_endpoint": fmt.Sprintf("%s/token", url), + "authorization_endpoint": fmt.Sprintf("%s/authorize", url), + "jwks_uri": fmt.Sprintf("%s/keys", url), + }) + }) + return httptest.NewServer(mux), nil } diff --git a/server/api.go b/server/api.go index 4fceae96..e9cb3d6f 100644 --- a/server/api.go +++ b/server/api.go @@ -279,7 +279,7 @@ func (d dexAPI) GetVersion(ctx context.Context, req *api.VersionReq) (*api.Versi } func (d dexAPI) GetDiscovery(ctx context.Context, req *api.DiscoveryReq) (*api.DiscoveryResp, error) { - discoveryDoc := d.server.constructDiscovery() + discoveryDoc := d.server.constructDiscoveryOIDC() data, err := json.Marshal(discoveryDoc) if err != nil { return nil, fmt.Errorf("failed to marshal discovery data: %v", err) diff --git a/server/handlers.go b/server/handlers.go index f8d0ed64..8ee9b132 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -72,7 +72,7 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { w.Write(data) } -type discovery struct { +type discoveryOIDC struct { Issuer string `json:"issuer"` Auth string `json:"authorization_endpoint"` Token string `json:"token_endpoint"` @@ -90,8 +90,36 @@ type discovery struct { Claims []string `json:"claims_supported"` } -func (s *Server) discoveryHandler() (http.HandlerFunc, error) { - d := s.constructDiscovery() +type discoveryOAuth2 struct { + Issuer string `json:"issuer"` + Auth string `json:"authorization_endpoint"` + Token string `json:"token_endpoint"` + Keys string `json:"jwks_uri"` + DeviceEndpoint string `json:"device_authorization_endpoint,omitempty"` + Introspect string `json:"introspection_endpoint,omitempty"` + GrantTypes []string `json:"grant_types_supported"` + ResponseTypes []string `json:"response_types_supported"` + CodeChallengeAlgs []string `json:"code_challenge_methods_supported,omitempty"` + Scopes []string `json:"scopes_supported,omitempty"` + AuthMethods []string `json:"token_endpoint_auth_methods_supported,omitempty"` +} + +type DiscoveryType int + +const ( + DiscoveryOIDC DiscoveryType = iota + DiscoveryOAuth2 +) + +func (s *Server) discoveryHandler(t DiscoveryType) (http.HandlerFunc, error) { + var d interface{} + + switch t { + case DiscoveryOAuth2: + d = s.constructDiscoveryOAuth2() + default: + d = s.constructDiscoveryOIDC() + } data, err := json.MarshalIndent(d, "", " ") if err != nil { @@ -105,8 +133,8 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) { }), nil } -func (s *Server) constructDiscovery() discovery { - d := discovery{ +func (s *Server) constructDiscoveryOIDC() discoveryOIDC { + d := discoveryOIDC{ Issuer: s.issuerURL.String(), Auth: s.absURL("/auth"), Token: s.absURL("/token"), @@ -134,6 +162,31 @@ func (s *Server) constructDiscovery() discovery { return d } +func (s *Server) constructDiscoveryOAuth2() discoveryOAuth2 { + d := discoveryOAuth2{ + Issuer: s.issuerURL.String(), + Auth: s.absURL("/auth"), + Token: s.absURL("/token"), + Keys: s.absURL("/keys"), + DeviceEndpoint: s.absURL("/device/code"), + Introspect: s.absURL("/token/introspect"), + CodeChallengeAlgs: []string{codeChallengeMethodS256, codeChallengeMethodPlain}, + Scopes: []string{"offline_access"}, + AuthMethods: []string{"client_secret_basic", "client_secret_post"}, + } + + // response_types_supported + for responseType := range s.supportedResponseTypes { + d.ResponseTypes = append(d.ResponseTypes, responseType) + } + sort.Strings(d.ResponseTypes) + + // grant_types_supported + d.GrantTypes = s.supportedGrantTypes + + return d +} + // handleAuthorization handles the OAuth2 auth endpoint. func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() diff --git a/server/handlers_test.go b/server/handlers_test.go index 114712ba..ac40ea80 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -34,7 +34,7 @@ func TestHandleHealth(t *testing.T) { } } -func TestHandleDiscovery(t *testing.T) { +func TestHandleDiscoveryOIDC(t *testing.T) { httpServer, server := newTestServer(t, nil) defer httpServer.Close() @@ -44,10 +44,10 @@ func TestHandleDiscovery(t *testing.T) { t.Errorf("expected 200 got %d", rr.Code) } - var res discovery + var res discoveryOIDC err := json.NewDecoder(rr.Result().Body).Decode(&res) require.NoError(t, err) - require.Equal(t, discovery{ + require.Equal(t, discoveryOIDC{ Issuer: httpServer.URL, Auth: fmt.Sprintf("%s/auth", httpServer.URL), Token: fmt.Sprintf("%s/token", httpServer.URL), @@ -101,6 +101,51 @@ func TestHandleDiscovery(t *testing.T) { }, res) } +func TestHandleDiscoveryOAuth2(t *testing.T) { + httpServer, server := newTestServer(t, nil) + defer httpServer.Close() + + rr := httptest.NewRecorder() + server.ServeHTTP(rr, httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil)) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200 got %d", rr.Code) + } + + var res discoveryOAuth2 + err := json.NewDecoder(rr.Result().Body).Decode(&res) + require.NoError(t, err) + + require.Equal(t, discoveryOAuth2{ + Issuer: httpServer.URL, + Auth: fmt.Sprintf("%s/auth", httpServer.URL), + Token: fmt.Sprintf("%s/token", httpServer.URL), + Keys: fmt.Sprintf("%s/keys", httpServer.URL), + DeviceEndpoint: fmt.Sprintf("%s/device/code", httpServer.URL), + Introspect: fmt.Sprintf("%s/token/introspect", httpServer.URL), + GrantTypes: []string{ + "authorization_code", + "refresh_token", + "urn:ietf:params:oauth:grant-type:device_code", + "urn:ietf:params:oauth:grant-type:token-exchange", + }, + ResponseTypes: []string{ + "code", + }, + CodeChallengeAlgs: []string{ + "S256", + "plain", + }, + Scopes: []string{ + "offline_access", + }, + AuthMethods: []string{ + "client_secret_basic", + "client_secret_post", + }, + }, res) +} + func TestHandleHealthFailure(t *testing.T) { httpServer, server := newTestServer(t, func(c *Config) { c.HealthChecker = gosundheit.New() diff --git a/server/server.go b/server/server.go index 70e8ae75..564df7cc 100644 --- a/server/server.go +++ b/server/server.go @@ -452,11 +452,17 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) } r.NotFoundHandler = http.NotFoundHandler() - discoveryHandler, err := s.discoveryHandler() + oidcHandler, err := s.discoveryHandler(DiscoveryOIDC) if err != nil { return nil, err } - handleWithCORS("/.well-known/openid-configuration", discoveryHandler) + handleWithCORS("/.well-known/openid-configuration", oidcHandler) + + oauthHandler, err := s.discoveryHandler(DiscoveryOAuth2) + if err != nil { + return nil, err + } + handleWithCORS("/.well-known/oauth-authorization-server", oauthHandler) // Handle the root path for the better user experience. handleWithCORS("/", func(w http.ResponseWriter, r *http.Request) { _, err := fmt.Fprintf(w, `