package server import ( "bytes" "context" "encoding/json" "errors" "fmt" "net/http" "net/http/httptest" "net/url" "path" "strings" "testing" "time" gosundheit "github.com/AppsFlyer/go-sundheit" "github.com/AppsFlyer/go-sundheit/checks" "github.com/coreos/go-oidc/v3/oidc" "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" "golang.org/x/oauth2" "github.com/dexidp/dex/storage" ) func boolPtr(v bool) *bool { return &v } func TestHandleHealth(t *testing.T) { httpServer, server := newTestServer(t, nil) defer httpServer.Close() rr := httptest.NewRecorder() server.ServeHTTP(rr, httptest.NewRequest("GET", "/healthz", nil)) if rr.Code != http.StatusOK { t.Errorf("expected 200 got %d", rr.Code) } } func TestHandleDiscovery(t *testing.T) { httpServer, server := newTestServer(t, nil) defer httpServer.Close() rr := httptest.NewRecorder() server.ServeHTTP(rr, httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)) if rr.Code != http.StatusOK { t.Errorf("expected 200 got %d", rr.Code) } var res discovery err := json.NewDecoder(rr.Result().Body).Decode(&res) require.NoError(t, err) require.Equal(t, discovery{ Issuer: httpServer.URL, Auth: fmt.Sprintf("%s/auth", httpServer.URL), Token: fmt.Sprintf("%s/token", httpServer.URL), Keys: fmt.Sprintf("%s/keys", httpServer.URL), UserInfo: fmt.Sprintf("%s/userinfo", httpServer.URL), DeviceEndpoint: fmt.Sprintf("%s/device/code", httpServer.URL), Introspect: fmt.Sprintf("%s/token/introspect", httpServer.URL), GrantTypes: []string{ "authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code", "urn:ietf:params:oauth:grant-type:token-exchange", }, ResponseTypes: []string{ "code", }, Subjects: []string{ "public", }, IDTokenAlgs: []string{ "RS256", }, CodeChallengeAlgs: []string{ "S256", "plain", }, Scopes: []string{ "openid", "email", "groups", "profile", "offline_access", }, AuthMethods: []string{ "client_secret_basic", "client_secret_post", }, Claims: []string{ "iss", "sub", "aud", "iat", "exp", "email", "email_verified", "locale", "name", "preferred_username", "at_hash", }, }, res) } func TestHandleHealthFailure(t *testing.T) { httpServer, server := newTestServer(t, func(c *Config) { c.HealthChecker = gosundheit.New() c.HealthChecker.RegisterCheck( &checks.CustomCheck{ CheckName: "fail", CheckFunc: func(_ context.Context) (details interface{}, err error) { return nil, errors.New("error") }, }, gosundheit.InitiallyPassing(false), gosundheit.ExecutionPeriod(1*time.Second), ) }) defer httpServer.Close() rr := httptest.NewRecorder() server.ServeHTTP(rr, httptest.NewRequest("GET", "/healthz", nil)) if rr.Code != http.StatusInternalServerError { t.Errorf("expected 500 got %d", rr.Code) } } type emptyStorage struct { storage.Storage } func (*emptyStorage) GetAuthRequest(context.Context, string) (storage.AuthRequest, error) { return storage.AuthRequest{}, storage.ErrNotFound } func TestHandleInvalidOAuth2Callbacks(t *testing.T) { httpServer, server := newTestServer(t, func(c *Config) { c.Storage = &emptyStorage{c.Storage} }) defer httpServer.Close() tests := []struct { TargetURI string ExpectedCode int }{ {"/callback", http.StatusBadRequest}, {"/callback?code=&state=", http.StatusBadRequest}, {"/callback?code=AAAAAAA&state=BBBBBBB", http.StatusBadRequest}, } rr := httptest.NewRecorder() for i, r := range tests { server.ServeHTTP(rr, httptest.NewRequest("GET", r.TargetURI, nil)) if rr.Code != r.ExpectedCode { t.Fatalf("test %d expected %d, got %d", i, r.ExpectedCode, rr.Code) } } } func TestHandleInvalidSAMLCallbacks(t *testing.T) { httpServer, server := newTestServer(t, func(c *Config) { c.Storage = &emptyStorage{c.Storage} }) defer httpServer.Close() type requestForm struct { RelayState string } tests := []struct { RequestForm requestForm ExpectedCode int }{ {requestForm{}, http.StatusBadRequest}, {requestForm{RelayState: "AAAAAAA"}, http.StatusBadRequest}, } rr := httptest.NewRecorder() for i, r := range tests { jsonValue, err := json.Marshal(r.RequestForm) if err != nil { t.Fatal(err.Error()) } server.ServeHTTP(rr, httptest.NewRequest("POST", "/callback", bytes.NewBuffer(jsonValue))) if rr.Code != r.ExpectedCode { t.Fatalf("test %d expected %d, got %d", i, r.ExpectedCode, rr.Code) } } } // TestHandleAuthCode checks that it is forbidden to use same code twice func TestHandleAuthCode(t *testing.T) { tests := []struct { name string handleCode func(*testing.T, context.Context, *oauth2.Config, string) }{ { name: "Code Reuse should return invalid_grant", handleCode: func(t *testing.T, ctx context.Context, oauth2Config *oauth2.Config, code string) { _, err := oauth2Config.Exchange(ctx, code) require.NoError(t, err) _, err = oauth2Config.Exchange(ctx, code) require.Error(t, err) oauth2Err, ok := err.(*oauth2.RetrieveError) require.True(t, ok) var errResponse struct{ Error string } err = json.Unmarshal(oauth2Err.Body, &errResponse) require.NoError(t, err) // invalid_grant must be returned for invalid values // https://tools.ietf.org/html/rfc6749#section-5.2 require.Equal(t, errInvalidGrant, errResponse.Error) }, }, { name: "No Code should return invalid_request", handleCode: func(t *testing.T, ctx context.Context, oauth2Config *oauth2.Config, _ string) { _, err := oauth2Config.Exchange(ctx, "") require.Error(t, err) oauth2Err, ok := err.(*oauth2.RetrieveError) require.True(t, ok) var errResponse struct{ Error string } err = json.Unmarshal(oauth2Err.Body, &errResponse) require.NoError(t, err) require.Equal(t, errInvalidRequest, errResponse.Error) }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { ctx := t.Context() httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" }) defer httpServer.Close() p, err := oidc.NewProvider(ctx, httpServer.URL) require.NoError(t, err) var oauth2Client oauth2Client oauth2Client.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/callback" { http.Redirect(w, r, oauth2Client.config.AuthCodeURL(""), http.StatusSeeOther) return } q := r.URL.Query() require.Equal(t, q.Get("error"), "", q.Get("error_description")) code := q.Get("code") tc.handleCode(t, ctx, oauth2Client.config, code) w.WriteHeader(http.StatusOK) })) defer oauth2Client.server.Close() redirectURL := oauth2Client.server.URL + "/callback" client := storage.Client{ ID: "testclient", Secret: "testclientsecret", RedirectURIs: []string{redirectURL}, } err = s.storage.CreateClient(ctx, client) require.NoError(t, err) oauth2Client.config = &oauth2.Config{ ClientID: client.ID, ClientSecret: client.Secret, Endpoint: p.Endpoint(), Scopes: []string{oidc.ScopeOpenID, "email", "offline_access"}, RedirectURL: redirectURL, } resp, err := http.Get(oauth2Client.server.URL + "/login") require.NoError(t, err) resp.Body.Close() }) } } func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) { ctx := t.Context() c := storage.Client{ ID: "test", Secret: "barfoo", RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, Name: "dex client", LogoURL: "https://goo.gl/JIyzIC", } err := s.CreateClient(ctx, c) require.NoError(t, err) c1 := storage.Connector{ ID: "test", Type: "mockPassword", Name: "mockPassword", Config: []byte(`{ "username": "test", "password": "test" }`), } err = s.CreateConnector(ctx, c1) require.NoError(t, err) c2 := storage.Connector{ ID: "http://any.valid.url/", Type: "mock", Name: "mockURLID", } err = s.CreateConnector(ctx, c2) require.NoError(t, err) } func TestHandlePassword(t *testing.T) { ctx := t.Context() tests := []struct { name string scopes string offlineSessionCreated bool }{ { name: "Password login, request refresh token", scopes: "openid offline_access email", offlineSessionCreated: true, }, { name: "Password login", scopes: "openid email", offlineSessionCreated: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Setup a dex server. httpServer, s := newTestServer(t, func(c *Config) { c.PasswordConnector = "test" c.Now = time.Now }) defer httpServer.Close() mockConnectorDataTestStorage(t, s.storage) makeReq := func(username, password string) *httptest.ResponseRecorder { u, err := url.Parse(s.issuerURL.String()) require.NoError(t, err) u.Path = path.Join(u.Path, "/token") v := url.Values{} v.Add("scope", tc.scopes) v.Add("grant_type", "password") v.Add("username", username) v.Add("password", password) req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") req.SetBasicAuth("test", "barfoo") rr := httptest.NewRecorder() s.ServeHTTP(rr, req) return rr } // Check unauthorized error { rr := makeReq("test", "invalid") require.Equal(t, 401, rr.Code) } // Check that we received expected refresh token { rr := makeReq("test", "test") require.Equal(t, 200, rr.Code) var ref struct { Token string `json:"refresh_token"` } err := json.Unmarshal(rr.Body.Bytes(), &ref) require.NoError(t, err) newSess, err := s.storage.GetOfflineSessions(ctx, "0-385-28089-0", "test") if tc.offlineSessionCreated { require.NoError(t, err) require.Equal(t, `{"test": "true"}`, string(newSess.ConnectorData)) } else { require.Error(t, storage.ErrNotFound, err) } } }) } } func TestHandlePassword_LocalPasswordDBClaims(t *testing.T) { ctx := t.Context() // Setup a dex server. httpServer, s := newTestServer(t, func(c *Config) { c.PasswordConnector = "local" }) defer httpServer.Close() // Client credentials for password grant. client := storage.Client{ ID: "test", Secret: "barfoo", RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, } require.NoError(t, s.storage.CreateClient(ctx, client)) // Enable local connector. localConn := storage.Connector{ ID: "local", Type: LocalConnector, Name: "Email", ResourceVersion: "1", } require.NoError(t, s.storage.CreateConnector(ctx, localConn)) _, err := s.OpenConnector(localConn) require.NoError(t, err) // Create a user in the password DB with groups and preferred_username. pw := "secret" hash, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost) require.NoError(t, err) require.NoError(t, s.storage.CreatePassword(ctx, storage.Password{ Email: "user@example.com", Username: "user-login", Name: "User Full Name", EmailVerified: boolPtr(false), PreferredUsername: "user-public", UserID: "user-id", Groups: []string{"team-a", "team-a/admins"}, Hash: hash, })) u, err := url.Parse(s.issuerURL.String()) require.NoError(t, err) u.Path = path.Join(u.Path, "/token") v := url.Values{} v.Add("scope", "openid profile email groups") v.Add("grant_type", "password") v.Add("username", "user@example.com") v.Add("password", pw) req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.SetBasicAuth("test", "barfoo") rr := httptest.NewRecorder() s.ServeHTTP(rr, req) require.Equal(t, http.StatusOK, rr.Code) var tokenResponse struct { IDToken string `json:"id_token"` } require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &tokenResponse)) require.NotEmpty(t, tokenResponse.IDToken) p, err := oidc.NewProvider(ctx, httpServer.URL) require.NoError(t, err) idToken, err := p.Verifier(&oidc.Config{SkipClientIDCheck: true}).Verify(ctx, tokenResponse.IDToken) require.NoError(t, err) var claims struct { Name string `json:"name"` EmailVerified bool `json:"email_verified"` PreferredUsername string `json:"preferred_username"` Groups []string `json:"groups"` } require.NoError(t, idToken.Claims(&claims)) require.Equal(t, "User Full Name", claims.Name) require.False(t, claims.EmailVerified) require.Equal(t, "user-public", claims.PreferredUsername) require.Equal(t, []string{"team-a", "team-a/admins"}, claims.Groups) } func TestHandlePasswordLoginWithSkipApproval(t *testing.T) { ctx := t.Context() connID := "mockPw" authReqID := "test" expiry := time.Now().Add(100 * time.Second) resTypes := []string{responseTypeCode} tests := []struct { name string skipApproval bool authReq storage.AuthRequest expectedRes string offlineSessionCreated bool }{ { name: "Force approval", skipApproval: false, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: true, }, expectedRes: "/approval", offlineSessionCreated: false, }, { name: "Skip approval by server config", skipApproval: true, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: true, }, expectedRes: "/approval", offlineSessionCreated: false, }, { name: "No skip", skipApproval: false, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: false, }, expectedRes: "/approval", offlineSessionCreated: false, }, { name: "Skip approval", skipApproval: true, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: false, }, expectedRes: "/auth/mockPw/cb", offlineSessionCreated: false, }, { name: "Force approval, request refresh token", skipApproval: false, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: true, Scopes: []string{"offline_access"}, }, expectedRes: "/approval", offlineSessionCreated: true, }, { name: "Skip approval, request refresh token", skipApproval: true, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: false, Scopes: []string{"offline_access"}, }, expectedRes: "/auth/mockPw/cb", offlineSessionCreated: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { httpServer, s := newTestServer(t, func(c *Config) { c.SkipApprovalScreen = tc.skipApproval c.Now = time.Now }) defer httpServer.Close() sc := storage.Connector{ ID: connID, Type: "mockPassword", Name: "MockPassword", ResourceVersion: "1", Config: []byte("{\"username\": \"foo\", \"password\": \"password\"}"), } if err := s.storage.CreateConnector(ctx, sc); err != nil { t.Fatalf("create connector: %v", err) } if _, err := s.OpenConnector(sc); err != nil { t.Fatalf("open connector: %v", err) } if err := s.storage.CreateAuthRequest(ctx, tc.authReq); err != nil { t.Fatalf("failed to create AuthRequest: %v", err) } rr := httptest.NewRecorder() path := fmt.Sprintf("/auth/%s/login?state=%s&back=&login=foo&password=password", connID, authReqID) s.handlePasswordLogin(rr, httptest.NewRequest("POST", path, nil)) require.Equal(t, 303, rr.Code) resp := rr.Result() defer resp.Body.Close() cb, _ := url.Parse(resp.Header.Get("Location")) require.Equal(t, tc.expectedRes, cb.Path) offlineSession, err := s.storage.GetOfflineSessions(ctx, "0-385-28089-0", connID) if tc.offlineSessionCreated { require.NoError(t, err) require.NotEmpty(t, offlineSession) } else { require.Error(t, storage.ErrNotFound, err) } }) } } func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) { ctx := t.Context() connID := "mock" authReqID := "test" expiry := time.Now().Add(100 * time.Second) resTypes := []string{responseTypeCode} tests := []struct { name string skipApproval bool authReq storage.AuthRequest expectedRes string offlineSessionCreated bool }{ { name: "Force approval", skipApproval: false, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: true, }, expectedRes: "/approval", offlineSessionCreated: false, }, { name: "Skip approval by server config", skipApproval: true, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: true, }, expectedRes: "/approval", offlineSessionCreated: false, }, { name: "Skip approval by auth request", skipApproval: false, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: false, }, expectedRes: "/approval", offlineSessionCreated: false, }, { name: "Skip approval", skipApproval: true, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: false, }, expectedRes: "/callback/cb", offlineSessionCreated: false, }, { name: "Force approval, request refresh token", skipApproval: false, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: true, Scopes: []string{"offline_access"}, }, expectedRes: "/approval", offlineSessionCreated: true, }, { name: "Skip approval, request refresh token", skipApproval: true, authReq: storage.AuthRequest{ ID: authReqID, ConnectorID: connID, RedirectURI: "cb", Expiry: expiry, ResponseTypes: resTypes, ForceApprovalPrompt: false, Scopes: []string{"offline_access"}, }, expectedRes: "/callback/cb", offlineSessionCreated: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { httpServer, s := newTestServer(t, func(c *Config) { c.SkipApprovalScreen = tc.skipApproval c.Now = time.Now }) defer httpServer.Close() if err := s.storage.CreateAuthRequest(ctx, tc.authReq); err != nil { t.Fatalf("failed to create AuthRequest: %v", err) } rr := httptest.NewRecorder() path := fmt.Sprintf("/callback/%s?state=%s", connID, authReqID) s.handleConnectorCallback(rr, httptest.NewRequest("GET", path, nil)) require.Equal(t, 303, rr.Code) resp := rr.Result() defer resp.Body.Close() cb, _ := url.Parse(resp.Header.Get("Location")) require.Equal(t, tc.expectedRes, cb.Path) offlineSession, err := s.storage.GetOfflineSessions(ctx, "0-385-28089-0", connID) if tc.offlineSessionCreated { require.NoError(t, err) require.NotEmpty(t, offlineSession) } else { require.Error(t, storage.ErrNotFound, err) } }) } } func TestHandleTokenExchange(t *testing.T) { tests := []struct { name string scope string requestedTokenType string subjectTokenType string subjectToken string expectedCode int expectedTokenType string }{ { "id-for-acccess", "openid", tokenTypeAccess, tokenTypeID, "foobar", http.StatusOK, tokenTypeAccess, }, { "id-for-id", "openid", tokenTypeID, tokenTypeID, "foobar", http.StatusOK, tokenTypeID, }, { "id-for-default", "openid", "", tokenTypeID, "foobar", http.StatusOK, tokenTypeAccess, }, { "access-for-access", "openid", tokenTypeAccess, tokenTypeAccess, "foobar", http.StatusOK, tokenTypeAccess, }, { "missing-subject_token_type", "openid", tokenTypeAccess, "", "foobar", http.StatusBadRequest, "", }, { "missing-subject_token", "openid", tokenTypeAccess, tokenTypeAccess, "", http.StatusBadRequest, "", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { ctx := t.Context() httpServer, s := newTestServer(t, func(c *Config) { c.Storage.CreateClient(ctx, storage.Client{ ID: "client_1", Secret: "secret_1", }) }) defer httpServer.Close() vals := make(url.Values) vals.Set("grant_type", grantTypeTokenExchange) setNonEmpty(vals, "connector_id", "mock") setNonEmpty(vals, "scope", tc.scope) setNonEmpty(vals, "requested_token_type", tc.requestedTokenType) setNonEmpty(vals, "subject_token_type", tc.subjectTokenType) setNonEmpty(vals, "subject_token", tc.subjectToken) setNonEmpty(vals, "client_id", "client_1") setNonEmpty(vals, "client_secret", "secret_1") rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, httpServer.URL+"/token", strings.NewReader(vals.Encode())) req.Header.Set("content-type", "application/x-www-form-urlencoded") s.handleToken(rr, req) require.Equal(t, tc.expectedCode, rr.Code, rr.Body.String()) require.Equal(t, "application/json", rr.Result().Header.Get("content-type")) if tc.expectedCode == http.StatusOK { var res accessTokenResponse err := json.NewDecoder(rr.Result().Body).Decode(&res) require.NoError(t, err) require.Equal(t, tc.expectedTokenType, res.IssuedTokenType) } }) } } func setNonEmpty(vals url.Values, key, value string) { if value != "" { vals.Set(key, value) } }