Browse Source

Merge 9a5868eb96 into 4c94d8a140

pull/4446/merge
Giovanni Vella 1 month ago committed by GitHub
parent
commit
f3ef6450d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 11
      connector/oidc/oidc_test.go
  2. 2
      server/api.go
  3. 63
      server/handlers.go
  4. 51
      server/handlers_test.go
  5. 10
      server/server.go

11
connector/oidc/oidc_test.go

@ -862,6 +862,17 @@ func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Ser
})
})
mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) {
url := fmt.Sprintf("http://%s", r.Host)
json.NewEncoder(w).Encode(&map[string]string{
"issuer": url,
"token_endpoint": fmt.Sprintf("%s/token", url),
"authorization_endpoint": fmt.Sprintf("%s/authorize", url),
"jwks_uri": fmt.Sprintf("%s/keys", url),
})
})
return httptest.NewServer(mux), nil
}

2
server/api.go

@ -279,7 +279,7 @@ func (d dexAPI) GetVersion(ctx context.Context, req *api.VersionReq) (*api.Versi
}
func (d dexAPI) GetDiscovery(ctx context.Context, req *api.DiscoveryReq) (*api.DiscoveryResp, error) {
discoveryDoc := d.server.constructDiscovery()
discoveryDoc := d.server.constructDiscoveryOIDC()
data, err := json.Marshal(discoveryDoc)
if err != nil {
return nil, fmt.Errorf("failed to marshal discovery data: %v", err)

63
server/handlers.go

@ -72,7 +72,7 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
w.Write(data)
}
type discovery struct {
type discoveryOIDC struct {
Issuer string `json:"issuer"`
Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"`
@ -90,8 +90,36 @@ type discovery struct {
Claims []string `json:"claims_supported"`
}
func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
d := s.constructDiscovery()
type discoveryOAuth2 struct {
Issuer string `json:"issuer"`
Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"`
Keys string `json:"jwks_uri"`
DeviceEndpoint string `json:"device_authorization_endpoint,omitempty"`
Introspect string `json:"introspection_endpoint,omitempty"`
GrantTypes []string `json:"grant_types_supported"`
ResponseTypes []string `json:"response_types_supported"`
CodeChallengeAlgs []string `json:"code_challenge_methods_supported,omitempty"`
Scopes []string `json:"scopes_supported,omitempty"`
AuthMethods []string `json:"token_endpoint_auth_methods_supported,omitempty"`
}
type DiscoveryType int
const (
DiscoveryOIDC DiscoveryType = iota
DiscoveryOAuth2
)
func (s *Server) discoveryHandler(t DiscoveryType) (http.HandlerFunc, error) {
var d interface{}
switch t {
case DiscoveryOAuth2:
d = s.constructDiscoveryOAuth2()
default:
d = s.constructDiscoveryOIDC()
}
data, err := json.MarshalIndent(d, "", " ")
if err != nil {
@ -105,8 +133,8 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
}), nil
}
func (s *Server) constructDiscovery() discovery {
d := discovery{
func (s *Server) constructDiscoveryOIDC() discoveryOIDC {
d := discoveryOIDC{
Issuer: s.issuerURL.String(),
Auth: s.absURL("/auth"),
Token: s.absURL("/token"),
@ -134,6 +162,31 @@ func (s *Server) constructDiscovery() discovery {
return d
}
func (s *Server) constructDiscoveryOAuth2() discoveryOAuth2 {
d := discoveryOAuth2{
Issuer: s.issuerURL.String(),
Auth: s.absURL("/auth"),
Token: s.absURL("/token"),
Keys: s.absURL("/keys"),
DeviceEndpoint: s.absURL("/device/code"),
Introspect: s.absURL("/token/introspect"),
CodeChallengeAlgs: []string{codeChallengeMethodS256, codeChallengeMethodPlain},
Scopes: []string{"offline_access"},
AuthMethods: []string{"client_secret_basic", "client_secret_post"},
}
// response_types_supported
for responseType := range s.supportedResponseTypes {
d.ResponseTypes = append(d.ResponseTypes, responseType)
}
sort.Strings(d.ResponseTypes)
// grant_types_supported
d.GrantTypes = s.supportedGrantTypes
return d
}
// handleAuthorization handles the OAuth2 auth endpoint.
func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

51
server/handlers_test.go

@ -35,7 +35,7 @@ func TestHandleHealth(t *testing.T) {
}
}
func TestHandleDiscovery(t *testing.T) {
func TestHandleDiscoveryOIDC(t *testing.T) {
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()
@ -45,10 +45,10 @@ func TestHandleDiscovery(t *testing.T) {
t.Errorf("expected 200 got %d", rr.Code)
}
var res discovery
var res discoveryOIDC
err := json.NewDecoder(rr.Result().Body).Decode(&res)
require.NoError(t, err)
require.Equal(t, discovery{
require.Equal(t, discoveryOIDC{
Issuer: httpServer.URL,
Auth: fmt.Sprintf("%s/auth", httpServer.URL),
Token: fmt.Sprintf("%s/token", httpServer.URL),
@ -102,6 +102,51 @@ func TestHandleDiscovery(t *testing.T) {
}, res)
}
func TestHandleDiscoveryOAuth2(t *testing.T) {
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()
rr := httptest.NewRecorder()
server.ServeHTTP(rr, httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil))
if rr.Code != http.StatusOK {
t.Errorf("expected 200 got %d", rr.Code)
}
var res discoveryOAuth2
err := json.NewDecoder(rr.Result().Body).Decode(&res)
require.NoError(t, err)
require.Equal(t, discoveryOAuth2{
Issuer: httpServer.URL,
Auth: fmt.Sprintf("%s/auth", httpServer.URL),
Token: fmt.Sprintf("%s/token", httpServer.URL),
Keys: fmt.Sprintf("%s/keys", httpServer.URL),
DeviceEndpoint: fmt.Sprintf("%s/device/code", httpServer.URL),
Introspect: fmt.Sprintf("%s/token/introspect", httpServer.URL),
GrantTypes: []string{
"authorization_code",
"refresh_token",
"urn:ietf:params:oauth:grant-type:device_code",
"urn:ietf:params:oauth:grant-type:token-exchange",
},
ResponseTypes: []string{
"code",
},
CodeChallengeAlgs: []string{
"S256",
"plain",
},
Scopes: []string{
"offline_access",
},
AuthMethods: []string{
"client_secret_basic",
"client_secret_post",
},
}, res)
}
func TestHandleHealthFailure(t *testing.T) {
httpServer, server := newTestServer(t, func(c *Config) {
c.HealthChecker = gosundheit.New()

10
server/server.go

@ -452,11 +452,17 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
}
r.NotFoundHandler = http.NotFoundHandler()
discoveryHandler, err := s.discoveryHandler()
oidcHandler, err := s.discoveryHandler(DiscoveryOIDC)
if err != nil {
return nil, err
}
handleWithCORS("/.well-known/openid-configuration", discoveryHandler)
handleWithCORS("/.well-known/openid-configuration", oidcHandler)
oauthHandler, err := s.discoveryHandler(DiscoveryOAuth2)
if err != nil {
return nil, err
}
handleWithCORS("/.well-known/oauth-authorization-server", oauthHandler)
// Handle the root path for the better user experience.
handleWithCORS("/", func(w http.ResponseWriter, r *http.Request) {
_, err := fmt.Fprintf(w, `<!DOCTYPE html>

Loading…
Cancel
Save