mirror of https://github.com/dexidp/dex.git
Browse Source
When GetAuthRequest returns ErrNotFound in handleApproval, render a 400 "User session error." instead of logging + rendering a 500 "Database error.". Covers the double-submit race where sendCodeResponse deletes the auth request on first approval and the second request finds nothing. --- Signed-off-by: Mark Liu <mark@prove.com.au> Signed-off-by: mark-liu <mark-liu@users.noreply.github.com>pull/4622/head
2 changed files with 121 additions and 0 deletions
@ -0,0 +1,117 @@
|
||||
package server |
||||
|
||||
import ( |
||||
"context" |
||||
"crypto/hmac" |
||||
"crypto/sha256" |
||||
"encoding/base64" |
||||
"errors" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"net/url" |
||||
"strings" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/stretchr/testify/require" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
type getAuthRequestErrorStorage struct { |
||||
storage.Storage |
||||
err error |
||||
} |
||||
|
||||
func (s *getAuthRequestErrorStorage) GetAuthRequest(context.Context, string) (storage.AuthRequest, error) { |
||||
return storage.AuthRequest{}, s.err |
||||
} |
||||
|
||||
func TestHandleApprovalGetAuthRequestErrorGET(t *testing.T) { |
||||
httpServer, server := newTestServer(t, func(c *Config) { |
||||
c.Storage = &getAuthRequestErrorStorage{Storage: c.Storage, err: errors.New("storage unavailable")} |
||||
}) |
||||
defer httpServer.Close() |
||||
|
||||
rr := httptest.NewRecorder() |
||||
req := httptest.NewRequest(http.MethodGet, "/approval?req=any&hmac=AQ", nil) |
||||
|
||||
server.ServeHTTP(rr, req) |
||||
|
||||
require.Equal(t, http.StatusInternalServerError, rr.Code) |
||||
require.Contains(t, rr.Body.String(), "Database error.") |
||||
} |
||||
|
||||
func TestHandleApprovalGetAuthRequestNotFoundGET(t *testing.T) { |
||||
httpServer, server := newTestServer(t, nil) |
||||
defer httpServer.Close() |
||||
|
||||
rr := httptest.NewRecorder() |
||||
req := httptest.NewRequest(http.MethodGet, "/approval?req=does-not-exist&hmac=AQ", nil) |
||||
|
||||
server.ServeHTTP(rr, req) |
||||
|
||||
require.Equal(t, http.StatusBadRequest, rr.Code) |
||||
require.Contains(t, rr.Body.String(), "User session error.") |
||||
require.NotContains(t, rr.Body.String(), "Database error.") |
||||
} |
||||
|
||||
func TestHandleApprovalGetAuthRequestNotFoundPOST(t *testing.T) { |
||||
httpServer, server := newTestServer(t, nil) |
||||
defer httpServer.Close() |
||||
|
||||
body := strings.NewReader("approval=approve&req=does-not-exist&hmac=AQ") |
||||
rr := httptest.NewRecorder() |
||||
req := httptest.NewRequest(http.MethodPost, "/approval", body) |
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
||||
|
||||
server.ServeHTTP(rr, req) |
||||
|
||||
require.Equal(t, http.StatusBadRequest, rr.Code) |
||||
require.Contains(t, rr.Body.String(), "User session error.") |
||||
require.NotContains(t, rr.Body.String(), "Database error.") |
||||
} |
||||
|
||||
func TestHandleApprovalDoubleSubmitPOST(t *testing.T) { |
||||
ctx := t.Context() |
||||
httpServer, server := newTestServer(t, nil) |
||||
defer httpServer.Close() |
||||
|
||||
authReq := storage.AuthRequest{ |
||||
ID: "approval-double-submit", |
||||
ClientID: "test", |
||||
ResponseTypes: []string{responseTypeCode}, |
||||
RedirectURI: "https://client.example/callback", |
||||
Expiry: time.Now().Add(time.Minute), |
||||
LoggedIn: true, |
||||
HMACKey: []byte("approval-double-submit-key"), |
||||
} |
||||
require.NoError(t, server.storage.CreateAuthRequest(ctx, authReq)) |
||||
|
||||
h := hmac.New(sha256.New, authReq.HMACKey) |
||||
h.Write([]byte(authReq.ID)) |
||||
mac := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) |
||||
|
||||
form := url.Values{ |
||||
"approval": {"approve"}, |
||||
"req": {authReq.ID}, |
||||
"hmac": {mac}, |
||||
} |
||||
|
||||
firstRR := httptest.NewRecorder() |
||||
firstReq := httptest.NewRequest(http.MethodPost, "/approval", strings.NewReader(form.Encode())) |
||||
firstReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
||||
server.ServeHTTP(firstRR, firstReq) |
||||
|
||||
require.Equal(t, http.StatusSeeOther, firstRR.Code) |
||||
require.Contains(t, firstRR.Header().Get("Location"), "https://client.example/callback") |
||||
|
||||
secondRR := httptest.NewRecorder() |
||||
secondReq := httptest.NewRequest(http.MethodPost, "/approval", strings.NewReader(form.Encode())) |
||||
secondReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
||||
server.ServeHTTP(secondRR, secondReq) |
||||
|
||||
require.Equal(t, http.StatusBadRequest, secondRR.Code) |
||||
require.Contains(t, secondRR.Body.String(), "User session error.") |
||||
require.NotContains(t, secondRR.Body.String(), "Database error.") |
||||
} |
||||
Loading…
Reference in new issue