|
|
|
|
@ -25,6 +25,7 @@ import (
|
|
|
|
|
"github.com/kylelemons/godebug/pretty" |
|
|
|
|
"github.com/prometheus/client_golang/prometheus" |
|
|
|
|
"github.com/sirupsen/logrus" |
|
|
|
|
"github.com/stretchr/testify/require" |
|
|
|
|
"golang.org/x/crypto/bcrypt" |
|
|
|
|
"golang.org/x/oauth2" |
|
|
|
|
jose "gopkg.in/square/go-jose.v2" |
|
|
|
|
@ -223,6 +224,9 @@ type test struct {
|
|
|
|
|
// extra parameters to pass when retrieving id token
|
|
|
|
|
retrieveTokenOptions []oauth2.AuthCodeOption |
|
|
|
|
|
|
|
|
|
// define an error response, when the test expects an error on the auth endpoint
|
|
|
|
|
authError *OAuth2ErrorResponse |
|
|
|
|
|
|
|
|
|
// define an error response, when the test expects an error on the token endpoint
|
|
|
|
|
tokenError ErrorResponse |
|
|
|
|
} |
|
|
|
|
@ -619,6 +623,19 @@ func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time)
|
|
|
|
|
StatusCode: http.StatusBadRequest, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "Request parameter in authorization query", |
|
|
|
|
authCodeOptions: []oauth2.AuthCodeOption{ |
|
|
|
|
oauth2.SetAuthURLParam("request", "anything"), |
|
|
|
|
}, |
|
|
|
|
authError: &OAuth2ErrorResponse{ |
|
|
|
|
Error: errRequestNotSupported, |
|
|
|
|
ErrorDescription: "Server does not support request parameter.", |
|
|
|
|
}, |
|
|
|
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { |
|
|
|
|
return nil |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
@ -677,7 +694,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|
|
|
|
state = "a_state" |
|
|
|
|
) |
|
|
|
|
defer func() { |
|
|
|
|
if !gotCode { |
|
|
|
|
if !gotCode && tc.authError == nil { |
|
|
|
|
t.Errorf("never got a code in callback\n%s\n%s", reqDump, respDump) |
|
|
|
|
} |
|
|
|
|
}() |
|
|
|
|
@ -696,12 +713,18 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|
|
|
|
|
|
|
|
|
// Did dex return an error?
|
|
|
|
|
if errType := q.Get("error"); errType != "" { |
|
|
|
|
if desc := q.Get("error_description"); desc != "" { |
|
|
|
|
t.Errorf("got error from server %s: %s", errType, desc) |
|
|
|
|
} else { |
|
|
|
|
t.Errorf("got error from server %s", errType) |
|
|
|
|
description := q.Get("error_description") |
|
|
|
|
|
|
|
|
|
if tc.authError == nil { |
|
|
|
|
if description != "" { |
|
|
|
|
t.Errorf("got error from server %s: %s", errType, description) |
|
|
|
|
} else { |
|
|
|
|
t.Errorf("got error from server %s", errType) |
|
|
|
|
} |
|
|
|
|
w.WriteHeader(http.StatusInternalServerError) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
w.WriteHeader(http.StatusInternalServerError) |
|
|
|
|
require.Equal(t, *tc.authError, OAuth2ErrorResponse{Error: errType, ErrorDescription: description}) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|