From 29c7b6f4e3c44fb4afd135ead636b4bb1edc89bd Mon Sep 17 00:00:00 2001 From: Maksim Nabokikh Date: Wed, 18 Feb 2026 10:04:51 +0100 Subject: [PATCH] feat: validate redirect URIs and safely append parameters (#4559) Signed-off-by: maksim.nabokikh --- examples/example-app/main.go | 14 ++++- server/oauth2.go | 24 ++++++-- server/oauth2_test.go | 107 +++++++++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+), 8 deletions(-) diff --git a/examples/example-app/main.go b/examples/example-app/main.go index af566704..22a7b2bd 100644 --- a/examples/example-app/main.go +++ b/examples/example-app/main.go @@ -274,11 +274,21 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { scopes = append(scopes, "offline_access") } authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState, authCodeOptions...) + + // Parse the auth code URL and safely add connector_id parameter if provided + u, err := url.Parse(authCodeURL) + if err != nil { + http.Error(w, "Failed to parse auth URL", http.StatusInternalServerError) + return + } + if connectorID != "" { - authCodeURL = authCodeURL + "&connector_id=" + connectorID + query := u.Query() + query.Set("connector_id", connectorID) + u.RawQuery = query.Encode() } - http.Redirect(w, r, authCodeURL, http.StatusSeeOther) + http.Redirect(w, r, u.String(), http.StatusSeeOther) } func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { diff --git a/server/oauth2.go b/server/oauth2.go index d5415986..5821f0ff 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -64,13 +64,25 @@ func (err *redirectedAuthErr) Handler() http.Handler { if err.Description != "" { v.Add("error_description", err.Description) } - var redirectURI string - if strings.Contains(err.RedirectURI, "?") { - redirectURI = err.RedirectURI + "&" + v.Encode() - } else { - redirectURI = err.RedirectURI + "?" + v.Encode() + + // Parse the redirect URI to ensure it's valid before redirecting + u, parseErr := url.Parse(err.RedirectURI) + if parseErr != nil { + // If URI parsing fails, respond with an error instead of redirecting + http.Error(w, "Invalid redirect URI", http.StatusBadRequest) + return } - http.Redirect(w, r, redirectURI, http.StatusSeeOther) + + // Add error parameters to the URL + query := u.Query() + for key, values := range v { + for _, value := range values { + query.Add(key, value) + } + } + u.RawQuery = query.Encode() + + http.Redirect(w, r, u.String(), http.StatusSeeOther) } return http.HandlerFunc(hf) } diff --git a/server/oauth2_test.go b/server/oauth2_test.go index d5fbd42e..ea930cb3 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -686,3 +686,110 @@ func TestSignerKeySet(t *testing.T) { }) } } + +func TestRedirectedAuthErrHandler(t *testing.T) { + tests := []struct { + name string + redirectURI string + state string + errType string + description string + wantStatus int + wantErr bool + }{ + { + name: "valid redirect uri with error parameters", + redirectURI: "https://example.com/callback", + state: "state123", + errType: errInvalidRequest, + description: "Invalid request parameter", + wantStatus: http.StatusSeeOther, + wantErr: false, + }, + { + name: "valid redirect uri with query params", + redirectURI: "https://example.com/callback?existing=param&another=value", + state: "state456", + errType: errAccessDenied, + description: "User denied access", + wantStatus: http.StatusSeeOther, + wantErr: false, + }, + { + name: "valid redirect uri without description", + redirectURI: "https://example.com/callback", + state: "state789", + errType: errServerError, + description: "", + wantStatus: http.StatusSeeOther, + wantErr: false, + }, + { + name: "invalid redirect uri", + redirectURI: "not a valid url ://", + state: "state", + errType: errInvalidRequest, + description: "Test error", + wantStatus: http.StatusBadRequest, + wantErr: true, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := &redirectedAuthErr{ + State: tc.state, + RedirectURI: tc.redirectURI, + Type: tc.errType, + Description: tc.description, + } + + handler := err.Handler() + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + handler.ServeHTTP(w, r) + + if w.Code != tc.wantStatus { + t.Errorf("expected status %d, got %d", tc.wantStatus, w.Code) + } + + if tc.wantStatus == http.StatusSeeOther { + // Verify the redirect location is a valid URL + location := w.Header().Get("Location") + if location == "" { + t.Fatalf("expected Location header, got empty string") + } + + // Parse the redirect URL to verify it's valid + redirectURL, parseErr := url.Parse(location) + if parseErr != nil { + t.Fatalf("invalid redirect URL: %v", parseErr) + } + + // Verify error parameters are present in the query string + query := redirectURL.Query() + if query.Get("state") != tc.state { + t.Errorf("expected state %q, got %q", tc.state, query.Get("state")) + } + if query.Get("error") != tc.errType { + t.Errorf("expected error type %q, got %q", tc.errType, query.Get("error")) + } + if tc.description != "" && query.Get("error_description") != tc.description { + t.Errorf("expected error_description %q, got %q", tc.description, query.Get("error_description")) + } + + // Verify that existing query parameters are preserved + if tc.name == "valid redirect uri with query params" { + if query.Get("existing") != "param" { + t.Errorf("expected existing parameter 'param', got %q", query.Get("existing")) + } + if query.Get("another") != "value" { + t.Errorf("expected another parameter 'value', got %q", query.Get("another")) + } + } + } + }) + } +}