diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index ec5fb52b..b9fb652a 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -431,7 +431,7 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) { } // Redirect to Dex Auth Endpoint - authURL := path.Join(s.issuerURL.Path, "/auth") + authURL := s.absURL("/auth") u, err := url.Parse(authURL) if err != nil { s.renderError(r, w, http.StatusInternalServerError, "Invalid auth URI.") @@ -442,7 +442,7 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) { q.Set("client_secret", deviceRequest.ClientSecret) q.Set("state", deviceRequest.UserCode) q.Set("response_type", "code") - q.Set("redirect_uri", "/device/callback") + q.Set("redirect_uri", s.absPath(deviceCallbackURI)) q.Set("scope", strings.Join(deviceRequest.Scopes, " ")) u.RawQuery = q.Encode() diff --git a/server/deviceflowhandlers_test.go b/server/deviceflowhandlers_test.go index ec7bf29d..1cbd60f7 100644 --- a/server/deviceflowhandlers_test.go +++ b/server/deviceflowhandlers_test.go @@ -364,7 +364,7 @@ func TestDeviceCallback(t *testing.T) { // Setup a dex server. httpServer, s := newTestServer(t, func(c *Config) { - // c.Issuer = c.Issuer + "/non-root-path" + c.Issuer = c.Issuer + "/non-root-path" c.Now = now }) defer httpServer.Close() @@ -752,7 +752,8 @@ func TestVerifyCodeResponse(t *testing.T) { testDeviceRequest storage.DeviceRequest userCode string expectedResponseCode int - expectedRedirectPath string + expectedAuthPath string + shouldRedirectToAuth bool }{ { testName: "Unknown user code", @@ -765,7 +766,6 @@ func TestVerifyCodeResponse(t *testing.T) { }, userCode: "CODE-TEST", expectedResponseCode: http.StatusBadRequest, - expectedRedirectPath: "", }, { testName: "Expired user code", @@ -778,7 +778,6 @@ func TestVerifyCodeResponse(t *testing.T) { }, userCode: "ABCD-WXYZ", expectedResponseCode: http.StatusBadRequest, - expectedRedirectPath: "", }, { testName: "No user code", @@ -791,10 +790,9 @@ func TestVerifyCodeResponse(t *testing.T) { }, userCode: "", expectedResponseCode: http.StatusBadRequest, - expectedRedirectPath: "", }, { - testName: "Valid user code, expect redirect to auth endpoint", + testName: "Valid user code, expect redirect to auth endpoint with device callback", testDeviceRequest: storage.DeviceRequest{ UserCode: "ABCD-WXYZ", DeviceCode: "f00bar", @@ -804,7 +802,8 @@ func TestVerifyCodeResponse(t *testing.T) { }, userCode: "ABCD-WXYZ", expectedResponseCode: http.StatusFound, - expectedRedirectPath: "/auth", + expectedAuthPath: "/auth", + shouldRedirectToAuth: true, }, } for _, tc := range tests { @@ -839,15 +838,24 @@ func TestVerifyCodeResponse(t *testing.T) { t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code) } - u, err = url.Parse(s.issuerURL.String()) - if err != nil { - t.Errorf("Could not parse issuer URL %v", err) - } - u.Path = path.Join(u.Path, tc.expectedRedirectPath) - location := rr.Header().Get("Location") - if rr.Code == http.StatusFound && !strings.HasPrefix(location, u.Path) { - t.Errorf("Invalid Redirect. Expected %v got %v", u.Path, location) + if rr.Code == http.StatusFound && tc.shouldRedirectToAuth { + // Parse the redirect location + redirectURL, err := url.Parse(location) + if err != nil { + t.Errorf("Could not parse redirect URL: %v", err) + return + } + + // Check that the redirect path contains /auth + if !strings.Contains(redirectURL.Path, tc.expectedAuthPath) { + t.Errorf("Invalid Redirect Path. Expected to contain %q got %q", tc.expectedAuthPath, redirectURL.Path) + } + + // Check that redirect_uri parameter contains /device/callback + if !strings.Contains(location, "redirect_uri=%2Fnon-root-path%2Fdevice%2Fcallback") { + t.Errorf("Invalid redirect_uri parameter. Expected to contain /device/callback (URL encoded), got %v", location) + } } }) } diff --git a/server/server_test.go b/server/server_test.go index 5a735f1d..e61e21ab 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1640,7 +1640,7 @@ func TestOAuth2DeviceFlow(t *testing.T) { // Add the Clients to the test server client := storage.Client{ ID: clientID, - RedirectURIs: []string{deviceCallbackURI}, + RedirectURIs: []string{s.absPath(deviceCallbackURI)}, Public: true, } if err := s.storage.CreateClient(ctx, client); err != nil { @@ -1751,7 +1751,7 @@ func TestOAuth2DeviceFlow(t *testing.T) { ClientSecret: client.Secret, Endpoint: p.Endpoint(), Scopes: requestedScopes, - RedirectURL: deviceCallbackURI, + RedirectURL: s.absURL(deviceCallbackURI), } if len(tc.scopes) != 0 { oauth2Config.Scopes = tc.scopes