package server import ( "bytes" "encoding/json" "io" "net/http" "net/http/httptest" "net/url" "path" "strings" "testing" "time" "github.com/dexidp/dex/storage" ) func TestDeviceVerificationURI(t *testing.T) { t0 := time.Now() now := func() time.Time { return t0 } // Setup a dex server. httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.Now = now }) defer httpServer.Close() u, err := url.Parse(s.issuerURL.String()) if err != nil { t.Fatalf("Could not parse issuer URL %v", err) } u.Path = path.Join(u.Path, "/device/auth/verify_code") uri := s.getDeviceVerificationURI() if uri != u.Path { t.Errorf("Invalid verification URI. Expected %v got %v", u.Path, uri) } } func TestHandleDeviceCode(t *testing.T) { t0 := time.Now() now := func() time.Time { return t0 } tests := []struct { testName string clientID string codeChallengeMethod string requestType string scopes []string expectedResponseCode int expectedContentType string expectedServerResponse string }{ { testName: "New Code", clientID: "test", requestType: "POST", scopes: []string{"openid", "profile", "email"}, expectedResponseCode: http.StatusOK, expectedContentType: "application/json", }, { testName: "Invalid request Type (GET)", clientID: "test", requestType: "GET", scopes: []string{"openid", "profile", "email"}, expectedResponseCode: http.StatusBadRequest, expectedContentType: "application/json", }, { testName: "New Code with valid PKCE", clientID: "test", requestType: "POST", scopes: []string{"openid", "profile", "email"}, codeChallengeMethod: "S256", expectedResponseCode: http.StatusOK, expectedContentType: "application/json", }, { testName: "Invalid code challenge method", clientID: "test", requestType: "POST", codeChallengeMethod: "invalid", scopes: []string{"openid", "profile", "email"}, expectedResponseCode: http.StatusBadRequest, expectedContentType: "application/json", }, { testName: "New Code without scope", clientID: "test", requestType: "POST", scopes: []string{}, expectedResponseCode: http.StatusOK, expectedContentType: "application/json", }, } for _, tc := range tests { t.Run(tc.testName, func(t *testing.T) { // Setup a dex server. httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.Now = now }) defer httpServer.Close() u, err := url.Parse(s.issuerURL.String()) if err != nil { t.Fatalf("Could not parse issuer URL %v", err) } u.Path = path.Join(u.Path, "device/code") data := url.Values{} data.Set("client_id", tc.clientID) data.Set("code_challenge_method", tc.codeChallengeMethod) for _, scope := range tc.scopes { data.Add("scope", scope) } req, _ := http.NewRequest(tc.requestType, u.String(), bytes.NewBufferString(data.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") rr := httptest.NewRecorder() s.ServeHTTP(rr, req) if rr.Code != tc.expectedResponseCode { t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code) } if rr.Header().Get("content-type") != tc.expectedContentType { t.Errorf("Unexpected Response Content Type. Expected %v got %v", tc.expectedContentType, rr.Header().Get("content-type")) } body, err := io.ReadAll(rr.Body) if err != nil { t.Errorf("Could read token response %v", err) } if tc.expectedResponseCode == http.StatusOK { var resp deviceCodeResponse if err := json.Unmarshal(body, &resp); err != nil { t.Errorf("Unexpected Device Code Response Format %v", string(body)) } } }) } } func TestDeviceCallback(t *testing.T) { t0 := time.Now() now := func() time.Time { return t0 } type formValues struct { state string code string error string } // Base "Control" test values baseFormValues := formValues{ state: "XXXX-XXXX", code: "somecode", } baseAuthCode := storage.AuthCode{ ID: "somecode", ClientID: "testclient", RedirectURI: deviceCallbackURI, Nonce: "", Scopes: []string{"openid", "profile", "email"}, ConnectorID: "mock", ConnectorData: nil, Claims: storage.Claims{}, Expiry: now().Add(5 * time.Minute), } baseDeviceRequest := storage.DeviceRequest{ UserCode: "XXXX-XXXX", DeviceCode: "devicecode", ClientID: "testclient", ClientSecret: "", Scopes: []string{"openid", "profile", "email"}, Expiry: now().Add(5 * time.Minute), } baseDeviceToken := storage.DeviceToken{ DeviceCode: "devicecode", Status: deviceTokenPending, Token: "", Expiry: now().Add(5 * time.Minute), LastRequestTime: time.Time{}, PollIntervalSeconds: 0, } tests := []struct { testName string expectedResponseCode int expectedServerResponse string values formValues testAuthCode storage.AuthCode testDeviceRequest storage.DeviceRequest testDeviceToken storage.DeviceToken }{ { testName: "Missing State", values: formValues{ state: "", code: "somecode", error: "", }, expectedResponseCode: http.StatusBadRequest, }, { testName: "Missing Code", values: formValues{ state: "XXXX-XXXX", code: "", error: "", }, expectedResponseCode: http.StatusBadRequest, }, { testName: "Error During Authorization", values: formValues{ state: "XXXX-XXXX", code: "somecode", error: "Error Condition", }, expectedResponseCode: http.StatusBadRequest, // Note: Error details should NOT be displayed to user anymore. // Instead, a safe generic message is shown. }, { testName: "Expired Auth Code", values: baseFormValues, testAuthCode: storage.AuthCode{ ID: "somecode", ClientID: "testclient", RedirectURI: deviceCallbackURI, Nonce: "", Scopes: []string{"openid", "profile", "email"}, ConnectorID: "pic", ConnectorData: nil, Claims: storage.Claims{}, Expiry: now().Add(-5 * time.Minute), }, expectedResponseCode: http.StatusBadRequest, }, { testName: "Invalid Auth Code", values: baseFormValues, testAuthCode: storage.AuthCode{ ID: "somecode", ClientID: "testclient", RedirectURI: deviceCallbackURI, Nonce: "", Scopes: []string{"openid", "profile", "email"}, ConnectorID: "pic", ConnectorData: nil, Claims: storage.Claims{}, Expiry: now().Add(5 * time.Minute), }, expectedResponseCode: http.StatusBadRequest, }, { testName: "Expired Device Request", values: baseFormValues, testAuthCode: baseAuthCode, testDeviceRequest: storage.DeviceRequest{ UserCode: "XXXX-XXXX", DeviceCode: "devicecode", ClientID: "testclient", Scopes: []string{"openid", "profile", "email"}, Expiry: now().Add(-5 * time.Minute), }, expectedResponseCode: http.StatusBadRequest, }, { testName: "Non-Existent User Code", values: baseFormValues, testAuthCode: baseAuthCode, testDeviceRequest: storage.DeviceRequest{ UserCode: "ZZZZ-ZZZZ", DeviceCode: "devicecode", Scopes: []string{"openid", "profile", "email"}, Expiry: now().Add(5 * time.Minute), }, expectedResponseCode: http.StatusBadRequest, }, { testName: "Bad Device Request Client", values: baseFormValues, testAuthCode: baseAuthCode, testDeviceRequest: storage.DeviceRequest{ UserCode: "XXXX-XXXX", DeviceCode: "devicecode", Scopes: []string{"openid", "profile", "email"}, Expiry: now().Add(5 * time.Minute), }, expectedResponseCode: http.StatusUnauthorized, }, { testName: "Bad Device Request Secret", values: baseFormValues, testAuthCode: baseAuthCode, testDeviceRequest: storage.DeviceRequest{ UserCode: "XXXX-XXXX", DeviceCode: "devicecode", ClientSecret: "foobar", Scopes: []string{"openid", "profile", "email"}, Expiry: now().Add(5 * time.Minute), }, expectedResponseCode: http.StatusUnauthorized, }, { testName: "Expired Device Token", values: baseFormValues, testAuthCode: baseAuthCode, testDeviceRequest: baseDeviceRequest, testDeviceToken: storage.DeviceToken{ DeviceCode: "devicecode", Status: deviceTokenPending, Token: "", Expiry: now().Add(-5 * time.Minute), LastRequestTime: time.Time{}, PollIntervalSeconds: 0, }, expectedResponseCode: http.StatusBadRequest, }, { testName: "Device Code Already Redeemed", values: baseFormValues, testAuthCode: baseAuthCode, testDeviceRequest: baseDeviceRequest, testDeviceToken: storage.DeviceToken{ DeviceCode: "devicecode", Status: deviceTokenComplete, Token: "", Expiry: now().Add(5 * time.Minute), LastRequestTime: time.Time{}, PollIntervalSeconds: 0, }, expectedResponseCode: http.StatusBadRequest, }, { testName: "Successful Exchange", values: baseFormValues, testAuthCode: baseAuthCode, testDeviceRequest: baseDeviceRequest, testDeviceToken: baseDeviceToken, expectedResponseCode: http.StatusOK, }, { testName: "Prevent cross-site scripting", values: formValues{ state: "XXXX-XXXX", code: "somecode", error: "", }, expectedResponseCode: http.StatusBadRequest, // Note: XSS data should NOT be displayed to user anymore. // Instead, a safe generic message is shown. }, } for _, tc := range tests { t.Run(tc.testName, func(t *testing.T) { ctx := t.Context() // Setup a dex server. httpServer, s := newTestServer(t, func(c *Config) { c.Issuer = c.Issuer + "/non-root-path" c.Now = now }) defer httpServer.Close() if err := s.storage.CreateAuthCode(ctx, tc.testAuthCode); err != nil { t.Fatalf("failed to create auth code: %v", err) } if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil { t.Fatalf("failed to create device request: %v", err) } if err := s.storage.CreateDeviceToken(ctx, tc.testDeviceToken); err != nil { t.Fatalf("failed to create device token: %v", err) } client := storage.Client{ ID: "testclient", Secret: "", RedirectURIs: []string{deviceCallbackURI}, } if err := s.storage.CreateClient(ctx, client); err != nil { t.Fatalf("failed to create client: %v", err) } u, err := url.Parse(s.issuerURL.String()) if err != nil { t.Fatalf("Could not parse issuer URL %v", err) } u.Path = path.Join(u.Path, "device/callback") q := u.Query() q.Set("state", tc.values.state) q.Set("code", tc.values.code) q.Set("error", tc.values.error) u.RawQuery = q.Encode() req, _ := http.NewRequest("GET", u.String(), nil) req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") rr := httptest.NewRecorder() s.ServeHTTP(rr, req) if rr.Code != tc.expectedResponseCode { t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code) } if len(tc.expectedServerResponse) > 0 { result, _ := io.ReadAll(rr.Body) if string(result) != tc.expectedServerResponse { t.Errorf("%s: Unexpected Response. Expected %q got %q", tc.testName, tc.expectedServerResponse, result) } } // Special check for error message safety tests if tc.testName == "Prevent cross-site scripting" || tc.testName == "Error During Authorization" { result, _ := io.ReadAll(rr.Body) responseBody := string(result) // Error details should NOT be present in the response (for security) if tc.testName == "Prevent cross-site scripting" { if strings.Contains(responseBody, "