|
|
|
|
@ -1497,143 +1497,164 @@ func TestOAuth2DeviceFlow(t *testing.T) {
|
|
|
|
|
var conn *mock.Callback |
|
|
|
|
idTokensValidFor := time.Second * 30 |
|
|
|
|
|
|
|
|
|
for _, tc := range makeOAuth2Tests(clientID, clientSecret, now).tests { |
|
|
|
|
func() { |
|
|
|
|
ctx, cancel := context.WithCancel(context.Background()) |
|
|
|
|
defer cancel() |
|
|
|
|
tests := makeOAuth2Tests(clientID, clientSecret, now) |
|
|
|
|
testCases := []struct { |
|
|
|
|
name string |
|
|
|
|
tokenEndpoint string |
|
|
|
|
oauth2Tests oauth2Tests |
|
|
|
|
}{ |
|
|
|
|
{ |
|
|
|
|
name: "Actual token endpoint for devices", |
|
|
|
|
tokenEndpoint: "/token", |
|
|
|
|
oauth2Tests: tests, |
|
|
|
|
}, |
|
|
|
|
// TODO(nabokihms): delete temporary tests after removing the deprecated token endpoint support
|
|
|
|
|
{ |
|
|
|
|
name: "Deprecated token endpoint for devices", |
|
|
|
|
tokenEndpoint: "/device/token", |
|
|
|
|
oauth2Tests: tests, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Setup a dex server.
|
|
|
|
|
httpServer, s := newTestServer(ctx, t, func(c *Config) { |
|
|
|
|
c.Issuer += "/non-root-path" |
|
|
|
|
c.Now = now |
|
|
|
|
c.IDTokensValidFor = idTokensValidFor |
|
|
|
|
}) |
|
|
|
|
defer httpServer.Close() |
|
|
|
|
for _, testCase := range testCases { |
|
|
|
|
for _, tc := range testCase.oauth2Tests.tests { |
|
|
|
|
t.Run(tc.name, func(t *testing.T) { |
|
|
|
|
ctx, cancel := context.WithCancel(context.Background()) |
|
|
|
|
defer cancel() |
|
|
|
|
|
|
|
|
|
mockConn := s.connectors["mock"] |
|
|
|
|
conn = mockConn.Connector.(*mock.Callback) |
|
|
|
|
// Setup a dex server.
|
|
|
|
|
httpServer, s := newTestServer(ctx, t, func(c *Config) { |
|
|
|
|
c.Issuer += "/non-root-path" |
|
|
|
|
c.Now = now |
|
|
|
|
c.IDTokensValidFor = idTokensValidFor |
|
|
|
|
}) |
|
|
|
|
defer httpServer.Close() |
|
|
|
|
|
|
|
|
|
p, err := oidc.NewProvider(ctx, httpServer.URL) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatalf("failed to get provider: %v", err) |
|
|
|
|
} |
|
|
|
|
mockConn := s.connectors["mock"] |
|
|
|
|
conn = mockConn.Connector.(*mock.Callback) |
|
|
|
|
|
|
|
|
|
// Add the Clients to the test server
|
|
|
|
|
client := storage.Client{ |
|
|
|
|
ID: clientID, |
|
|
|
|
RedirectURIs: []string{deviceCallbackURI}, |
|
|
|
|
Public: true, |
|
|
|
|
} |
|
|
|
|
if err := s.storage.CreateClient(client); err != nil { |
|
|
|
|
t.Fatalf("failed to create client: %v", err) |
|
|
|
|
} |
|
|
|
|
p, err := oidc.NewProvider(ctx, httpServer.URL) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatalf("failed to get provider: %v", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Grab the issuer that we'll reuse for the different endpoints to hit
|
|
|
|
|
issuer, err := url.Parse(s.issuerURL.String()) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could not parse issuer URL %v", err) |
|
|
|
|
} |
|
|
|
|
// Add the Clients to the test server
|
|
|
|
|
client := storage.Client{ |
|
|
|
|
ID: clientID, |
|
|
|
|
RedirectURIs: []string{deviceCallbackURI}, |
|
|
|
|
Public: true, |
|
|
|
|
} |
|
|
|
|
if err := s.storage.CreateClient(client); err != nil { |
|
|
|
|
t.Fatalf("failed to create client: %v", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Send a new Device Request
|
|
|
|
|
codeURL, _ := url.Parse(issuer.String()) |
|
|
|
|
codeURL.Path = path.Join(codeURL.Path, "device/code") |
|
|
|
|
// Grab the issuer that we'll reuse for the different endpoints to hit
|
|
|
|
|
issuer, err := url.Parse(s.issuerURL.String()) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could not parse issuer URL %v", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
data := url.Values{} |
|
|
|
|
data.Set("client_id", clientID) |
|
|
|
|
data.Add("scope", strings.Join(requestedScopes, " ")) |
|
|
|
|
resp, err := http.PostForm(codeURL.String(), data) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could not request device code: %v", err) |
|
|
|
|
} |
|
|
|
|
defer resp.Body.Close() |
|
|
|
|
responseBody, err := ioutil.ReadAll(resp.Body) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could read device code response %v", err) |
|
|
|
|
} |
|
|
|
|
if resp.StatusCode != http.StatusOK { |
|
|
|
|
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) |
|
|
|
|
} |
|
|
|
|
if resp.Header.Get("Cache-Control") != "no-store" { |
|
|
|
|
t.Errorf("Cache-Control header doesn't exist in Device Code Response") |
|
|
|
|
} |
|
|
|
|
// Send a new Device Request
|
|
|
|
|
codeURL, _ := url.Parse(issuer.String()) |
|
|
|
|
codeURL.Path = path.Join(codeURL.Path, "device/code") |
|
|
|
|
|
|
|
|
|
// Parse the code response
|
|
|
|
|
var deviceCode deviceCodeResponse |
|
|
|
|
if err := json.Unmarshal(responseBody, &deviceCode); err != nil { |
|
|
|
|
t.Errorf("Unexpected Device Code Response Format %v", string(responseBody)) |
|
|
|
|
} |
|
|
|
|
data := url.Values{} |
|
|
|
|
data.Set("client_id", clientID) |
|
|
|
|
data.Add("scope", strings.Join(requestedScopes, " ")) |
|
|
|
|
resp, err := http.PostForm(codeURL.String(), data) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could not request device code: %v", err) |
|
|
|
|
} |
|
|
|
|
defer resp.Body.Close() |
|
|
|
|
responseBody, err := ioutil.ReadAll(resp.Body) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could read device code response %v", err) |
|
|
|
|
} |
|
|
|
|
if resp.StatusCode != http.StatusOK { |
|
|
|
|
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) |
|
|
|
|
} |
|
|
|
|
if resp.Header.Get("Cache-Control") != "no-store" { |
|
|
|
|
t.Errorf("Cache-Control header doesn't exist in Device Code Response") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Mock the user hitting the verification URI and posting the form
|
|
|
|
|
verifyURL, _ := url.Parse(issuer.String()) |
|
|
|
|
verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code") |
|
|
|
|
urlData := url.Values{} |
|
|
|
|
urlData.Set("user_code", deviceCode.UserCode) |
|
|
|
|
resp, err = http.PostForm(verifyURL.String(), urlData) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Error Posting Form: %v", err) |
|
|
|
|
} |
|
|
|
|
defer resp.Body.Close() |
|
|
|
|
responseBody, err = ioutil.ReadAll(resp.Body) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could read verification response %v", err) |
|
|
|
|
} |
|
|
|
|
if resp.StatusCode != http.StatusOK { |
|
|
|
|
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) |
|
|
|
|
} |
|
|
|
|
// Parse the code response
|
|
|
|
|
var deviceCode deviceCodeResponse |
|
|
|
|
if err := json.Unmarshal(responseBody, &deviceCode); err != nil { |
|
|
|
|
t.Errorf("Unexpected Device Code Response Format %v", string(responseBody)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Hit the Token Endpoint, and try and get an access token
|
|
|
|
|
tokenURL, _ := url.Parse(issuer.String()) |
|
|
|
|
tokenURL.Path = path.Join(tokenURL.Path, "/token") |
|
|
|
|
v := url.Values{} |
|
|
|
|
v.Add("grant_type", grantTypeDeviceCode) |
|
|
|
|
v.Add("device_code", deviceCode.DeviceCode) |
|
|
|
|
resp, err = http.PostForm(tokenURL.String(), v) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could not request device token: %v", err) |
|
|
|
|
} |
|
|
|
|
defer resp.Body.Close() |
|
|
|
|
responseBody, err = ioutil.ReadAll(resp.Body) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could read device token response %v", err) |
|
|
|
|
} |
|
|
|
|
if resp.StatusCode != http.StatusOK { |
|
|
|
|
t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) |
|
|
|
|
} |
|
|
|
|
// Mock the user hitting the verification URI and posting the form
|
|
|
|
|
verifyURL, _ := url.Parse(issuer.String()) |
|
|
|
|
verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code") |
|
|
|
|
urlData := url.Values{} |
|
|
|
|
urlData.Set("user_code", deviceCode.UserCode) |
|
|
|
|
resp, err = http.PostForm(verifyURL.String(), urlData) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Error Posting Form: %v", err) |
|
|
|
|
} |
|
|
|
|
defer resp.Body.Close() |
|
|
|
|
responseBody, err = ioutil.ReadAll(resp.Body) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could read verification response %v", err) |
|
|
|
|
} |
|
|
|
|
if resp.StatusCode != http.StatusOK { |
|
|
|
|
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Parse the response
|
|
|
|
|
var tokenRes accessTokenResponse |
|
|
|
|
if err := json.Unmarshal(responseBody, &tokenRes); err != nil { |
|
|
|
|
t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody)) |
|
|
|
|
} |
|
|
|
|
// Hit the Token Endpoint, and try and get an access token
|
|
|
|
|
tokenURL, _ := url.Parse(issuer.String()) |
|
|
|
|
tokenURL.Path = path.Join(tokenURL.Path, testCase.tokenEndpoint) |
|
|
|
|
v := url.Values{} |
|
|
|
|
v.Add("grant_type", grantTypeDeviceCode) |
|
|
|
|
v.Add("device_code", deviceCode.DeviceCode) |
|
|
|
|
resp, err = http.PostForm(tokenURL.String(), v) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could not request device token: %v", err) |
|
|
|
|
} |
|
|
|
|
defer resp.Body.Close() |
|
|
|
|
responseBody, err = ioutil.ReadAll(resp.Body) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("Could read device token response %v", err) |
|
|
|
|
} |
|
|
|
|
if resp.StatusCode != http.StatusOK { |
|
|
|
|
t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
token := &oauth2.Token{ |
|
|
|
|
AccessToken: tokenRes.AccessToken, |
|
|
|
|
TokenType: tokenRes.TokenType, |
|
|
|
|
RefreshToken: tokenRes.RefreshToken, |
|
|
|
|
} |
|
|
|
|
raw := make(map[string]interface{}) |
|
|
|
|
json.Unmarshal(responseBody, &raw) // no error checks for optional fields
|
|
|
|
|
token = token.WithExtra(raw) |
|
|
|
|
if secs := tokenRes.ExpiresIn; secs > 0 { |
|
|
|
|
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) |
|
|
|
|
} |
|
|
|
|
// Parse the response
|
|
|
|
|
var tokenRes accessTokenResponse |
|
|
|
|
if err := json.Unmarshal(responseBody, &tokenRes); err != nil { |
|
|
|
|
t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Run token tests to validate info is correct
|
|
|
|
|
// Create the OAuth2 config.
|
|
|
|
|
oauth2Config := &oauth2.Config{ |
|
|
|
|
ClientID: client.ID, |
|
|
|
|
ClientSecret: client.Secret, |
|
|
|
|
Endpoint: p.Endpoint(), |
|
|
|
|
Scopes: requestedScopes, |
|
|
|
|
RedirectURL: deviceCallbackURI, |
|
|
|
|
} |
|
|
|
|
if len(tc.scopes) != 0 { |
|
|
|
|
oauth2Config.Scopes = tc.scopes |
|
|
|
|
} |
|
|
|
|
err = tc.handleToken(ctx, p, oauth2Config, token, conn) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("%s: %v", tc.name, err) |
|
|
|
|
} |
|
|
|
|
}() |
|
|
|
|
token := &oauth2.Token{ |
|
|
|
|
AccessToken: tokenRes.AccessToken, |
|
|
|
|
TokenType: tokenRes.TokenType, |
|
|
|
|
RefreshToken: tokenRes.RefreshToken, |
|
|
|
|
} |
|
|
|
|
raw := make(map[string]interface{}) |
|
|
|
|
json.Unmarshal(responseBody, &raw) // no error checks for optional fields
|
|
|
|
|
token = token.WithExtra(raw) |
|
|
|
|
if secs := tokenRes.ExpiresIn; secs > 0 { |
|
|
|
|
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Run token tests to validate info is correct
|
|
|
|
|
// Create the OAuth2 config.
|
|
|
|
|
oauth2Config := &oauth2.Config{ |
|
|
|
|
ClientID: client.ID, |
|
|
|
|
ClientSecret: client.Secret, |
|
|
|
|
Endpoint: p.Endpoint(), |
|
|
|
|
Scopes: requestedScopes, |
|
|
|
|
RedirectURL: deviceCallbackURI, |
|
|
|
|
} |
|
|
|
|
if len(tc.scopes) != 0 { |
|
|
|
|
oauth2Config.Scopes = tc.scopes |
|
|
|
|
} |
|
|
|
|
err = tc.handleToken(ctx, p, oauth2Config, token, conn) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("%s: %v", tc.name, err) |
|
|
|
|
} |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|