Browse Source

Add context to storage's Create endpoints (#2935)

* Initial commit

Signed-off-by: PumpkinSeed <qwer.kocka@gmail.com>

* Finish the syntex fixes

Signed-off-by: PumpkinSeed <qwer.kocka@gmail.com>

* Add fixes after running the tests

Signed-off-by: PumpkinSeed <qwer.kocka@gmail.com>

* Change background context to request context

Signed-off-by: PumpkinSeed <qwer.kocka@gmail.com>

---------

Signed-off-by: PumpkinSeed <qwer.kocka@gmail.com>
v2.38.x
Ferenc Fabian 2 years ago committed by GitHub
parent
commit
2377b0a0cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 4
      server/api.go
  2. 4
      server/api_test.go
  3. 8
      server/deviceflowhandlers.go
  4. 14
      server/deviceflowhandlers_test.go
  5. 38
      server/handlers.go
  6. 17
      server/handlers_test.go
  7. 9
      server/refreshhandlers_test.go
  8. 29
      server/server_test.go
  9. 72
      storage/conformance/conformance.go
  10. 10
      storage/conformance/transactions.go
  11. 4
      storage/ent/client/authcode.go
  12. 4
      storage/ent/client/authrequest.go
  13. 4
      storage/ent/client/client.go
  14. 4
      storage/ent/client/connector.go
  15. 4
      storage/ent/client/devicerequest.go
  16. 4
      storage/ent/client/devicetoken.go
  17. 4
      storage/ent/client/offlinesession.go
  18. 4
      storage/ent/client/password.go
  19. 4
      storage/ent/client/refreshtoken.go
  20. 38
      storage/etcd/etcd.go
  21. 4
      storage/health.go
  22. 20
      storage/kubernetes/storage.go
  23. 3
      storage/kubernetes/storage_test.go
  24. 21
      storage/memory/memory.go
  25. 18
      storage/memory/static_test.go
  26. 23
      storage/sql/crud.go
  27. 13
      storage/static.go
  28. 19
      storage/storage.go

4
server/api.go

@ -85,7 +85,7 @@ func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*ap
Name: req.Client.Name,
LogoURL: req.Client.LogoUrl,
}
if err := d.s.CreateClient(c); err != nil {
if err := d.s.CreateClient(ctx, c); err != nil {
if err == storage.ErrAlreadyExists {
return &api.CreateClientResp{AlreadyExists: true}, nil
}
@ -177,7 +177,7 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq)
Username: req.Password.Username,
UserID: req.Password.UserId,
}
if err := d.s.CreatePassword(p); err != nil {
if err := d.s.CreatePassword(ctx, p); err != nil {
if err == storage.ErrAlreadyExists {
return &api.CreatePasswordResp{AlreadyExists: true}, nil
}

4
server/api_test.go

@ -262,7 +262,7 @@ func TestRefreshToken(t *testing.T) {
ConnectorData: []byte(`{"some":"data"}`),
}
if err := s.CreateRefresh(r); err != nil {
if err := s.CreateRefresh(ctx, r); err != nil {
t.Fatalf("create refresh token: %v", err)
}
@ -280,7 +280,7 @@ func TestRefreshToken(t *testing.T) {
}
session.Refresh[tokenRef.ClientID] = &tokenRef
if err := s.CreateOfflineSessions(session); err != nil {
if err := s.CreateOfflineSessions(ctx, session); err != nil {
t.Fatalf("create offline session: %v", err)
}

8
server/deviceflowhandlers.go

@ -58,6 +58,7 @@ func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
pollIntervalSeconds := 5
switch r.Method {
@ -106,7 +107,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
Expiry: expireTime,
}
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, deviceReq); err != nil {
s.logger.Errorf("Failed to store device request; %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
@ -125,7 +126,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
},
}
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
if err := s.storage.CreateDeviceToken(ctx, deviceToken); err != nil {
s.logger.Errorf("Failed to store device token %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
@ -280,6 +281,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
switch r.Method {
case http.MethodGet:
userCode := r.FormValue("state")
@ -336,7 +338,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
return
}
resp, err := s.exchangeAuthCode(w, authCode, client)
resp, err := s.exchangeAuthCode(ctx, w, authCode, client)
if err != nil {
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")

14
server/deviceflowhandlers_test.go

@ -366,15 +366,15 @@ func TestDeviceCallback(t *testing.T) {
})
defer httpServer.Close()
if err := s.storage.CreateAuthCode(tc.testAuthCode); err != nil {
if err := s.storage.CreateAuthCode(ctx, tc.testAuthCode); err != nil {
t.Fatalf("failed to create auth code: %v", err)
}
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil {
t.Fatalf("failed to create device request: %v", err)
}
if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
if err := s.storage.CreateDeviceToken(ctx, tc.testDeviceToken); err != nil {
t.Fatalf("failed to create device token: %v", err)
}
@ -383,7 +383,7 @@ func TestDeviceCallback(t *testing.T) {
Secret: "",
RedirectURIs: []string{deviceCallbackURI},
}
if err := s.storage.CreateClient(client); err != nil {
if err := s.storage.CreateClient(ctx, client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
@ -660,11 +660,11 @@ func TestDeviceTokenResponse(t *testing.T) {
})
defer httpServer.Close()
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil {
t.Fatalf("Failed to store device token %v", err)
}
if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
if err := s.storage.CreateDeviceToken(ctx, tc.testDeviceToken); err != nil {
t.Fatalf("Failed to store device token %v", err)
}
@ -794,7 +794,7 @@ func TestVerifyCodeResponse(t *testing.T) {
})
defer httpServer.Close()
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil {
t.Fatalf("Failed to store device token %v", err)
}

38
server/handlers.go

@ -1,6 +1,7 @@
package server
import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
@ -187,6 +188,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
authReq, err := s.parseAuthorizationRequest(r)
if err != nil {
s.logger.Errorf("Failed to parse authorization request: %v", err)
@ -229,7 +231,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
// Actually create the auth request
authReq.Expiry = s.now().Add(s.authRequestsValidFor)
if err := s.storage.CreateAuthRequest(*authReq); err != nil {
if err := s.storage.CreateAuthRequest(ctx, *authReq); err != nil {
s.logger.Errorf("Failed to create authorization request: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.")
return
@ -305,6 +307,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
authID := r.URL.Query().Get("state")
if authID == "" {
s.renderError(r, w, http.StatusBadRequest, "User session error.")
@ -360,7 +363,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
password := r.FormValue("password")
scopes := parseScopes(authReq.Scopes)
identity, ok, err := pwConn.Login(r.Context(), scopes, username, password)
identity, ok, err := pwConn.Login(ctx, scopes, username, password)
if err != nil {
s.logger.Errorf("Failed to login user: %v", err)
s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err))
@ -372,7 +375,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
}
return
}
redirectURL, canSkipApproval, err := s.finalizeLogin(identity, authReq, conn.Connector)
redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector)
if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.")
@ -397,6 +400,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var authID string
switch r.Method {
case http.MethodGet: // OAuth2 callback
@ -471,7 +475,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
return
}
redirectURL, canSkipApproval, err := s.finalizeLogin(identity, authReq, conn.Connector)
redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector)
if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.")
@ -494,7 +498,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
// finalizeLogin associates the user's identity with the current AuthRequest, then returns
// the approval page's path.
func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, bool, error) {
func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, bool, error) {
claims := storage.Claims{
UserID: identity.UserID,
Username: identity.Username,
@ -566,7 +570,7 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth
// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
return "", false, err
}
@ -649,6 +653,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) {
ctx := r.Context()
if s.now().After(authReq.Expiry) {
s.renderError(r, w, http.StatusBadRequest, "User session has expired.")
return
@ -701,7 +706,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
ConnectorData: authReq.ConnectorData,
PKCE: authReq.PKCE,
}
if err := s.storage.CreateAuthCode(code); err != nil {
if err := s.storage.CreateAuthCode(ctx, code); err != nil {
s.logger.Errorf("Failed to create auth code: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
@ -876,6 +881,7 @@ func (s *Server) calculateCodeChallenge(codeVerifier, codeChallengeMethod string
// handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
ctx := r.Context()
code := r.PostFormValue("code")
redirectURI := r.PostFormValue("redirect_uri")
@ -926,7 +932,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return
}
tokenResponse, err := s.exchangeAuthCode(w, authCode, client)
tokenResponse, err := s.exchangeAuthCode(ctx, w, authCode, client)
if err != nil {
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
@ -934,7 +940,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.writeAccessToken(w, tokenResponse)
}
func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
@ -1002,7 +1008,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo
return nil, err
}
if err := s.storage.CreateRefresh(refresh); err != nil {
if err := s.storage.CreateRefresh(ctx, refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err
@ -1047,7 +1053,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo
// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
@ -1080,6 +1086,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo
}
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
const prefix = "Bearer "
auth := r.Header.Get("authorization")
@ -1091,7 +1098,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
rawIDToken := auth[len(prefix):]
verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
idToken, err := verifier.Verify(r.Context(), rawIDToken)
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden)
return
@ -1108,6 +1115,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client) {
ctx := r.Context()
// Parse the fields
if err := r.ParseForm(); err != nil {
s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest)
@ -1177,7 +1185,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
// Login
username := q.Get("username")
password := q.Get("password")
identity, ok, err := passwordConnector.Login(r.Context(), parseScopes(scopes), username, password)
identity, ok, err := passwordConnector.Login(ctx, parseScopes(scopes), username, password)
if err != nil {
s.logger.Errorf("Failed to login user: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest)
@ -1252,7 +1260,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
return
}
if err := s.storage.CreateRefresh(refresh); err != nil {
if err := s.storage.CreateRefresh(ctx, refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
@ -1298,7 +1306,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true

17
server/handlers_test.go

@ -213,7 +213,7 @@ func TestHandleAuthCode(t *testing.T) {
Secret: "testclientsecret",
RedirectURIs: []string{redirectURL},
}
err = s.storage.CreateClient(client)
err = s.storage.CreateClient(ctx, client)
require.NoError(t, err)
oauth2Client.config = &oauth2.Config{
@ -233,6 +233,7 @@ func TestHandleAuthCode(t *testing.T) {
}
func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) {
ctx := context.Background()
c := storage.Client{
ID: "test",
Secret: "barfoo",
@ -241,7 +242,7 @@ func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) {
LogoURL: "https://goo.gl/JIyzIC",
}
err := s.CreateClient(c)
err := s.CreateClient(ctx, c)
require.NoError(t, err)
c1 := storage.Connector{
@ -254,7 +255,7 @@ func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) {
}`),
}
err = s.CreateConnector(c1)
err = s.CreateConnector(ctx, c1)
require.NoError(t, err)
c2 := storage.Connector{
@ -263,7 +264,7 @@ func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) {
Name: "mockURLID",
}
err = s.CreateConnector(c2)
err = s.CreateConnector(ctx, c2)
require.NoError(t, err)
}
@ -467,13 +468,13 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) {
ResourceVersion: "1",
Config: []byte("{\"username\": \"foo\", \"password\": \"password\"}"),
}
if err := s.storage.CreateConnector(sc); err != nil {
if err := s.storage.CreateConnector(ctx, sc); err != nil {
t.Fatalf("create connector: %v", err)
}
if _, err := s.OpenConnector(sc); err != nil {
t.Fatalf("open connector: %v", err)
}
if err := s.storage.CreateAuthRequest(tc.authReq); err != nil {
if err := s.storage.CreateAuthRequest(ctx, tc.authReq); err != nil {
t.Fatalf("failed to create AuthRequest: %v", err)
}
@ -614,7 +615,7 @@ func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) {
})
defer httpServer.Close()
if err := s.storage.CreateAuthRequest(tc.authReq); err != nil {
if err := s.storage.CreateAuthRequest(ctx, tc.authReq); err != nil {
t.Fatalf("failed to create AuthRequest: %v", err)
}
rr := httptest.NewRecorder()
@ -712,7 +713,7 @@ func TestHandleTokenExchange(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.Storage.CreateClient(storage.Client{
c.Storage.CreateClient(ctx, storage.Client{
ID: "client_1",
Secret: "secret_1",
})

9
server/refreshhandlers_test.go

@ -18,6 +18,7 @@ import (
)
func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) {
ctx := context.Background()
c := storage.Client{
ID: "test",
Secret: "barfoo",
@ -26,7 +27,7 @@ func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bo
LogoURL: "https://goo.gl/JIyzIC",
}
err := s.CreateClient(c)
err := s.CreateClient(ctx, c)
require.NoError(t, err)
c1 := storage.Connector{
@ -36,7 +37,7 @@ func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bo
Config: nil,
}
err = s.CreateConnector(c1)
err = s.CreateConnector(ctx, c1)
require.NoError(t, err)
refresh := storage.RefreshToken{
@ -64,7 +65,7 @@ func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bo
refresh.ObsoleteToken = "bar"
}
err = s.CreateRefresh(refresh)
err = s.CreateRefresh(ctx, refresh)
require.NoError(t, err)
offlineSessions := storage.OfflineSessions{
@ -74,7 +75,7 @@ func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bo
ConnectorData: nil,
}
err = s.CreateOfflineSessions(offlineSessions)
err = s.CreateOfflineSessions(ctx, offlineSessions)
require.NoError(t, err)
}

29
server/server_test.go

@ -119,7 +119,7 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
Name: "Mock",
ResourceVersion: "1",
}
if err := config.Storage.CreateConnector(connector); err != nil {
if err := config.Storage.CreateConnector(ctx, connector); err != nil {
t.Fatalf("create connector: %v", err)
}
@ -172,10 +172,10 @@ func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateCo
Name: "Mock",
ResourceVersion: "1",
}
if err := config.Storage.CreateConnector(connector); err != nil {
if err := config.Storage.CreateConnector(ctx, connector); err != nil {
t.Fatalf("create connector: %v", err)
}
if err := config.Storage.CreateConnector(connector2); err != nil {
if err := config.Storage.CreateConnector(ctx, connector2); err != nil {
t.Fatalf("create connector: %v", err)
}
@ -837,11 +837,11 @@ func TestOAuth2CodeFlow(t *testing.T) {
Secret: clientSecret,
RedirectURIs: []string{redirectURL},
}
if err := s.storage.CreateClient(client); err != nil {
if err := s.storage.CreateClient(ctx, client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
if err := s.storage.CreateRefresh(storage.RefreshToken{
if err := s.storage.CreateRefresh(ctx, storage.RefreshToken{
ID: "existedrefrestoken",
ClientID: "unexcistedclientid",
}); err != nil {
@ -955,7 +955,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) {
Secret: "testclientsecret",
RedirectURIs: []string{redirectURL},
}
if err := s.storage.CreateClient(client); err != nil {
if err := s.storage.CreateClient(ctx, client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
@ -1113,7 +1113,7 @@ func TestCrossClientScopes(t *testing.T) {
Secret: "testclientsecret",
RedirectURIs: []string{redirectURL},
}
if err := s.storage.CreateClient(client); err != nil {
if err := s.storage.CreateClient(ctx, client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
@ -1123,7 +1123,7 @@ func TestCrossClientScopes(t *testing.T) {
TrustedPeers: []string{"testclient"},
}
if err := s.storage.CreateClient(peer); err != nil {
if err := s.storage.CreateClient(ctx, peer); err != nil {
t.Fatalf("failed to create client: %v", err)
}
@ -1236,7 +1236,7 @@ func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) {
Secret: "testclientsecret",
RedirectURIs: []string{redirectURL},
}
if err := s.storage.CreateClient(client); err != nil {
if err := s.storage.CreateClient(ctx, client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
@ -1246,7 +1246,7 @@ func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) {
TrustedPeers: []string{"testclient"},
}
if err := s.storage.CreateClient(peer); err != nil {
if err := s.storage.CreateClient(ctx, peer); err != nil {
t.Fatalf("failed to create client: %v", err)
}
@ -1276,6 +1276,7 @@ func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) {
}
func TestPasswordDB(t *testing.T) {
ctx := context.Background()
s := memory.New(logger)
conn := newPasswordDB(s)
@ -1286,7 +1287,7 @@ func TestPasswordDB(t *testing.T) {
t.Fatal(err)
}
s.CreatePassword(storage.Password{
s.CreatePassword(ctx, storage.Password{
Email: "jane@example.com",
Username: "jane",
UserID: "foobar",
@ -1534,7 +1535,7 @@ func TestRefreshTokenFlow(t *testing.T) {
Secret: "testclientsecret",
RedirectURIs: []string{redirectURL},
}
if err := s.storage.CreateClient(client); err != nil {
if err := s.storage.CreateClient(ctx, client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
@ -1633,11 +1634,11 @@ func TestOAuth2DeviceFlow(t *testing.T) {
RedirectURIs: []string{deviceCallbackURI},
Public: true,
}
if err := s.storage.CreateClient(client); err != nil {
if err := s.storage.CreateClient(ctx, client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
if err := s.storage.CreateRefresh(storage.RefreshToken{
if err := s.storage.CreateRefresh(ctx, storage.RefreshToken{
ID: "existedrefrestoken",
ClientID: "unexcistedclientid",
}); err != nil {

72
storage/conformance/conformance.go

@ -2,6 +2,7 @@
package conformance
import (
"context"
"reflect"
"sort"
"testing"
@ -80,6 +81,7 @@ func mustBeErrAlreadyExists(t *testing.T, kind string, err error) {
}
func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
codeChallenge := storage.PKCE{
CodeChallenge: "code_challenge_test",
CodeChallengeMethod: "plain",
@ -111,12 +113,12 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
identity := storage.Claims{Email: "foobar"}
if err := s.CreateAuthRequest(a1); err != nil {
if err := s.CreateAuthRequest(ctx, a1); err != nil {
t.Fatalf("failed creating auth request: %v", err)
}
// Attempt to create same AuthRequest twice.
err := s.CreateAuthRequest(a1)
err := s.CreateAuthRequest(ctx, a1)
mustBeErrAlreadyExists(t, "auth request", err)
a2 := storage.AuthRequest{
@ -142,7 +144,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
HMACKey: []byte("hmac_key"),
}
if err := s.CreateAuthRequest(a2); err != nil {
if err := s.CreateAuthRequest(ctx, a2); err != nil {
t.Fatalf("failed creating auth request: %v", err)
}
@ -179,6 +181,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
}
func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
a1 := storage.AuthCode{
ID: storage.NewID(),
ClientID: "client1",
@ -201,7 +204,7 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
},
}
if err := s.CreateAuthCode(a1); err != nil {
if err := s.CreateAuthCode(ctx, a1); err != nil {
t.Fatalf("failed creating auth code: %v", err)
}
@ -224,10 +227,10 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
}
// Attempt to create same AuthCode twice.
err := s.CreateAuthCode(a1)
err := s.CreateAuthCode(ctx, a1)
mustBeErrAlreadyExists(t, "auth code", err)
if err := s.CreateAuthCode(a2); err != nil {
if err := s.CreateAuthCode(ctx, a2); err != nil {
t.Fatalf("failed creating auth code: %v", err)
}
@ -256,6 +259,7 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
}
func testClientCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
id1 := storage.NewID()
c1 := storage.Client{
ID: id1,
@ -267,12 +271,12 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
err := s.DeleteClient(id1)
mustBeErrNotFound(t, "client", err)
if err := s.CreateClient(c1); err != nil {
if err := s.CreateClient(ctx, c1); err != nil {
t.Fatalf("create client: %v", err)
}
// Attempt to create same Client twice.
err = s.CreateClient(c1)
err = s.CreateClient(ctx, c1)
mustBeErrAlreadyExists(t, "client", err)
id2 := storage.NewID()
@ -284,7 +288,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
LogoURL: "https://goo.gl/JIyzIC",
}
if err := s.CreateClient(c2); err != nil {
if err := s.CreateClient(ctx, c2); err != nil {
t.Fatalf("create client: %v", err)
}
@ -325,6 +329,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
}
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
id := storage.NewID()
refresh := storage.RefreshToken{
ID: id,
@ -345,12 +350,12 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
},
ConnectorData: []byte(`{"some":"data"}`),
}
if err := s.CreateRefresh(refresh); err != nil {
if err := s.CreateRefresh(ctx, refresh); err != nil {
t.Fatalf("create refresh token: %v", err)
}
// Attempt to create same Refresh Token twice.
err := s.CreateRefresh(refresh)
err := s.CreateRefresh(ctx, refresh)
mustBeErrAlreadyExists(t, "refresh token", err)
getAndCompare := func(id string, want storage.RefreshToken) {
@ -401,7 +406,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
ConnectorData: []byte(`{"some":"data"}`),
}
if err := s.CreateRefresh(refresh2); err != nil {
if err := s.CreateRefresh(ctx, refresh2); err != nil {
t.Fatalf("create second refresh token: %v", err)
}
@ -443,6 +448,7 @@ func (n byEmail) Less(i, j int) bool { return n[i].Email < n[j].Email }
func (n byEmail) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
func testPasswordCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
// Use bcrypt.MinCost to keep the tests short.
passwordHash1, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost)
if err != nil {
@ -455,12 +461,12 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
Username: "jane",
UserID: "foobar",
}
if err := s.CreatePassword(password1); err != nil {
if err := s.CreatePassword(ctx, password1); err != nil {
t.Fatalf("create password token: %v", err)
}
// Attempt to create same Password twice.
err = s.CreatePassword(password1)
err = s.CreatePassword(ctx, password1)
mustBeErrAlreadyExists(t, "password", err)
passwordHash2, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.MinCost)
@ -474,7 +480,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
Username: "john",
UserID: "barfoo",
}
if err := s.CreatePassword(password2); err != nil {
if err := s.CreatePassword(ctx, password2); err != nil {
t.Fatalf("create password token: %v", err)
}
@ -533,6 +539,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
}
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
userID1 := storage.NewID()
session1 := storage.OfflineSessions{
UserID: userID1,
@ -543,12 +550,12 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
// Creating an OfflineSession with an empty Refresh list to ensure that
// an empty map is translated as expected by the storage.
if err := s.CreateOfflineSessions(session1); err != nil {
if err := s.CreateOfflineSessions(ctx, session1); err != nil {
t.Fatalf("create offline session with UserID = %s: %v", session1.UserID, err)
}
// Attempt to create same OfflineSession twice.
err := s.CreateOfflineSessions(session1)
err := s.CreateOfflineSessions(ctx, session1)
mustBeErrAlreadyExists(t, "offline session", err)
userID2 := storage.NewID()
@ -559,7 +566,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
ConnectorData: []byte(`{"some":"data"}`),
}
if err := s.CreateOfflineSessions(session2); err != nil {
if err := s.CreateOfflineSessions(ctx, session2); err != nil {
t.Fatalf("create offline session with UserID = %s: %v", session2.UserID, err)
}
@ -607,6 +614,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
}
func testConnectorCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
id1 := storage.NewID()
config1 := []byte(`{"issuer": "https://accounts.google.com"}`)
c1 := storage.Connector{
@ -616,12 +624,12 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) {
Config: config1,
}
if err := s.CreateConnector(c1); err != nil {
if err := s.CreateConnector(ctx, c1); err != nil {
t.Fatalf("create connector with ID = %s: %v", c1.ID, err)
}
// Attempt to create same Connector twice.
err := s.CreateConnector(c1)
err := s.CreateConnector(ctx, c1)
mustBeErrAlreadyExists(t, "connector", err)
id2 := storage.NewID()
@ -633,7 +641,7 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) {
Config: config2,
}
if err := s.CreateConnector(c2); err != nil {
if err := s.CreateConnector(ctx, c2); err != nil {
t.Fatalf("create connector with ID = %s: %v", c2.ID, err)
}
@ -744,6 +752,7 @@ func testKeysCRUD(t *testing.T, s storage.Storage) {
}
func testGC(t *testing.T, s storage.Storage) {
ctx := context.Background()
est, err := time.LoadLocation("America/New_York")
if err != nil {
t.Fatal(err)
@ -772,7 +781,7 @@ func testGC(t *testing.T, s storage.Storage) {
},
}
if err := s.CreateAuthCode(c); err != nil {
if err := s.CreateAuthCode(ctx, c); err != nil {
t.Fatalf("failed creating auth code: %v", err)
}
@ -823,7 +832,7 @@ func testGC(t *testing.T, s storage.Storage) {
HMACKey: []byte("hmac_key"),
}
if err := s.CreateAuthRequest(a); err != nil {
if err := s.CreateAuthRequest(ctx, a); err != nil {
t.Fatalf("failed creating auth request: %v", err)
}
@ -860,7 +869,7 @@ func testGC(t *testing.T, s storage.Storage) {
Expiry: expiry,
}
if err := s.CreateDeviceRequest(d); err != nil {
if err := s.CreateDeviceRequest(ctx, d); err != nil {
t.Fatalf("failed creating device request: %v", err)
}
@ -900,7 +909,7 @@ func testGC(t *testing.T, s storage.Storage) {
},
}
if err := s.CreateDeviceToken(dt); err != nil {
if err := s.CreateDeviceToken(ctx, dt); err != nil {
t.Fatalf("failed creating device token: %v", err)
}
@ -931,6 +940,7 @@ func testGC(t *testing.T, s storage.Storage) {
// testTimezones tests that backends either fully support timezones or
// do the correct standardization.
func testTimezones(t *testing.T, s storage.Storage) {
ctx := context.Background()
est, err := time.LoadLocation("America/New_York")
if err != nil {
t.Fatal(err)
@ -956,7 +966,7 @@ func testTimezones(t *testing.T, s storage.Storage) {
Groups: []string{"a", "b"},
},
}
if err := s.CreateAuthCode(c); err != nil {
if err := s.CreateAuthCode(ctx, c); err != nil {
t.Fatalf("failed creating auth code: %v", err)
}
got, err := s.GetAuthCode(c.ID)
@ -975,6 +985,7 @@ func testTimezones(t *testing.T, s storage.Storage) {
}
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
d1 := storage.DeviceRequest{
UserCode: storage.NewUserCode(),
DeviceCode: storage.NewID(),
@ -984,12 +995,12 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
Expiry: neverExpire.Round(time.Second),
}
if err := s.CreateDeviceRequest(d1); err != nil {
if err := s.CreateDeviceRequest(ctx, d1); err != nil {
t.Fatalf("failed creating device request: %v", err)
}
// Attempt to create same DeviceRequest twice.
err := s.CreateDeviceRequest(d1)
err := s.CreateDeviceRequest(ctx, d1)
mustBeErrAlreadyExists(t, "device request", err)
got, err := s.GetDeviceRequest(d1.UserCode)
@ -1004,6 +1015,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
}
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
codeChallenge := storage.PKCE{
CodeChallenge: "code_challenge_test",
CodeChallengeMethod: "plain",
@ -1020,12 +1032,12 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
PKCE: codeChallenge,
}
if err := s.CreateDeviceToken(d1); err != nil {
if err := s.CreateDeviceToken(ctx, d1); err != nil {
t.Fatalf("failed creating device token: %v", err)
}
// Attempt to create same Device Token twice.
err := s.CreateDeviceToken(d1)
err := s.CreateDeviceToken(ctx, d1)
mustBeErrAlreadyExists(t, "device token", err)
// Update the device token, simulate a redemption

10
storage/conformance/transactions.go

@ -1,6 +1,7 @@
package conformance
import (
"context"
"testing"
"time"
@ -26,6 +27,7 @@ func RunTransactionTests(t *testing.T, newStorage func() storage.Storage) {
}
func testClientConcurrentUpdate(t *testing.T, s storage.Storage) {
ctx := context.Background()
c := storage.Client{
ID: storage.NewID(),
Secret: "foobar",
@ -34,7 +36,7 @@ func testClientConcurrentUpdate(t *testing.T, s storage.Storage) {
LogoURL: "https://goo.gl/JIyzIC",
}
if err := s.CreateClient(c); err != nil {
if err := s.CreateClient(ctx, c); err != nil {
t.Fatalf("create client: %v", err)
}
@ -55,6 +57,7 @@ func testClientConcurrentUpdate(t *testing.T, s storage.Storage) {
}
func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) {
ctx := context.Background()
a := storage.AuthRequest{
ID: storage.NewID(),
ClientID: "foobar",
@ -78,7 +81,7 @@ func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) {
HMACKey: []byte("hmac_key"),
}
if err := s.CreateAuthRequest(a); err != nil {
if err := s.CreateAuthRequest(ctx, a); err != nil {
t.Fatalf("failed creating auth request: %v", err)
}
@ -99,6 +102,7 @@ func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) {
}
func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) {
ctx := context.Background()
// Use bcrypt.MinCost to keep the tests short.
passwordHash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost)
if err != nil {
@ -111,7 +115,7 @@ func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) {
Username: "jane",
UserID: "foobar",
}
if err := s.CreatePassword(password); err != nil {
if err := s.CreatePassword(ctx, password); err != nil {
t.Fatalf("create password token: %v", err)
}

4
storage/ent/client/authcode.go

@ -7,7 +7,7 @@ import (
)
// CreateAuthCode saves provided auth code into the database.
func (d *Database) CreateAuthCode(code storage.AuthCode) error {
func (d *Database) CreateAuthCode(ctx context.Context, code storage.AuthCode) error {
_, err := d.client.AuthCode.Create().
SetID(code.ID).
SetClientID(code.ClientID).
@ -26,7 +26,7 @@ func (d *Database) CreateAuthCode(code storage.AuthCode) error {
SetExpiry(code.Expiry.UTC()).
SetConnectorID(code.ConnectorID).
SetConnectorData(code.ConnectorData).
Save(context.TODO())
Save(ctx)
if err != nil {
return convertDBError("create auth code: %w", err)
}

4
storage/ent/client/authrequest.go

@ -8,7 +8,7 @@ import (
)
// CreateAuthRequest saves provided auth request into the database.
func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error {
func (d *Database) CreateAuthRequest(ctx context.Context, authRequest storage.AuthRequest) error {
_, err := d.client.AuthRequest.Create().
SetID(authRequest.ID).
SetClientID(authRequest.ClientID).
@ -32,7 +32,7 @@ func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error {
SetConnectorID(authRequest.ConnectorID).
SetConnectorData(authRequest.ConnectorData).
SetHmacKey(authRequest.HMACKey).
Save(context.TODO())
Save(ctx)
if err != nil {
return convertDBError("create auth request: %w", err)
}

4
storage/ent/client/client.go

@ -7,7 +7,7 @@ import (
)
// CreateClient saves provided oauth2 client settings into the database.
func (d *Database) CreateClient(client storage.Client) error {
func (d *Database) CreateClient(ctx context.Context, client storage.Client) error {
_, err := d.client.OAuth2Client.Create().
SetID(client.ID).
SetName(client.Name).
@ -16,7 +16,7 @@ func (d *Database) CreateClient(client storage.Client) error {
SetLogoURL(client.LogoURL).
SetRedirectUris(client.RedirectURIs).
SetTrustedPeers(client.TrustedPeers).
Save(context.TODO())
Save(ctx)
if err != nil {
return convertDBError("create oauth2 client: %w", err)
}

4
storage/ent/client/connector.go

@ -7,14 +7,14 @@ import (
)
// CreateConnector saves a connector into the database.
func (d *Database) CreateConnector(connector storage.Connector) error {
func (d *Database) CreateConnector(ctx context.Context, connector storage.Connector) error {
_, err := d.client.Connector.Create().
SetID(connector.ID).
SetName(connector.Name).
SetType(connector.Type).
SetResourceVersion(connector.ResourceVersion).
SetConfig(connector.Config).
Save(context.TODO())
Save(ctx)
if err != nil {
return convertDBError("create connector: %w", err)
}

4
storage/ent/client/devicerequest.go

@ -8,7 +8,7 @@ import (
)
// CreateDeviceRequest saves provided device request into the database.
func (d *Database) CreateDeviceRequest(request storage.DeviceRequest) error {
func (d *Database) CreateDeviceRequest(ctx context.Context, request storage.DeviceRequest) error {
_, err := d.client.DeviceRequest.Create().
SetClientID(request.ClientID).
SetClientSecret(request.ClientSecret).
@ -17,7 +17,7 @@ func (d *Database) CreateDeviceRequest(request storage.DeviceRequest) error {
SetDeviceCode(request.DeviceCode).
// Save utc time into database because ent doesn't support comparing dates with different timezones
SetExpiry(request.Expiry.UTC()).
Save(context.TODO())
Save(ctx)
if err != nil {
return convertDBError("create device request: %w", err)
}

4
storage/ent/client/devicetoken.go

@ -8,7 +8,7 @@ import (
)
// CreateDeviceToken saves provided token into the database.
func (d *Database) CreateDeviceToken(token storage.DeviceToken) error {
func (d *Database) CreateDeviceToken(ctx context.Context, token storage.DeviceToken) error {
_, err := d.client.DeviceToken.Create().
SetDeviceCode(token.DeviceCode).
SetToken([]byte(token.Token)).
@ -19,7 +19,7 @@ func (d *Database) CreateDeviceToken(token storage.DeviceToken) error {
SetStatus(token.Status).
SetCodeChallenge(token.PKCE.CodeChallenge).
SetCodeChallengeMethod(token.PKCE.CodeChallengeMethod).
Save(context.TODO())
Save(ctx)
if err != nil {
return convertDBError("create device token: %w", err)
}

4
storage/ent/client/offlinesession.go

@ -9,7 +9,7 @@ import (
)
// CreateOfflineSessions saves provided offline session into the database.
func (d *Database) CreateOfflineSessions(session storage.OfflineSessions) error {
func (d *Database) CreateOfflineSessions(ctx context.Context, session storage.OfflineSessions) error {
encodedRefresh, err := json.Marshal(session.Refresh)
if err != nil {
return fmt.Errorf("encode refresh offline session: %w", err)
@ -22,7 +22,7 @@ func (d *Database) CreateOfflineSessions(session storage.OfflineSessions) error
SetConnID(session.ConnID).
SetConnectorData(session.ConnectorData).
SetRefresh(encodedRefresh).
Save(context.TODO())
Save(ctx)
if err != nil {
return convertDBError("create offline session: %w", err)
}

4
storage/ent/client/password.go

@ -9,13 +9,13 @@ import (
)
// CreatePassword saves provided password into the database.
func (d *Database) CreatePassword(password storage.Password) error {
func (d *Database) CreatePassword(ctx context.Context, password storage.Password) error {
_, err := d.client.Password.Create().
SetEmail(password.Email).
SetHash(password.Hash).
SetUsername(password.Username).
SetUserID(password.UserID).
Save(context.TODO())
Save(ctx)
if err != nil {
return convertDBError("create password: %w", err)
}

4
storage/ent/client/refreshtoken.go

@ -7,7 +7,7 @@ import (
)
// CreateRefresh saves provided refresh token into the database.
func (d *Database) CreateRefresh(refresh storage.RefreshToken) error {
func (d *Database) CreateRefresh(ctx context.Context, refresh storage.RefreshToken) error {
_, err := d.client.RefreshToken.Create().
SetID(refresh.ID).
SetClientID(refresh.ClientID).
@ -26,7 +26,7 @@ func (d *Database) CreateRefresh(refresh storage.RefreshToken) error {
// Save utc time into database because ent doesn't support comparing dates with different timezones
SetLastUsed(refresh.LastUsed.UTC()).
SetCreatedAt(refresh.CreatedAt.UTC()).
Save(context.TODO())
Save(ctx)
if err != nil {
return convertDBError("create refresh token: %w", err)
}

38
storage/etcd/etcd.go

@ -29,6 +29,8 @@ const (
defaultStorageTimeout = 5 * time.Second
)
var _ storage.Storage = (*conn)(nil)
type conn struct {
db *clientv3.Client
logger log.Logger
@ -107,9 +109,7 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
return result, delErr
}
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) error {
return c.txnCreate(ctx, keyID(authRequestPrefix, a.ID), fromStorageAuthRequest(a))
}
@ -147,9 +147,7 @@ func (c *conn) DeleteAuthRequest(id string) error {
return c.deleteKey(ctx, keyID(authRequestPrefix, id))
}
func (c *conn) CreateAuthCode(a storage.AuthCode) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
func (c *conn) CreateAuthCode(ctx context.Context, a storage.AuthCode) error {
return c.txnCreate(ctx, keyID(authCodePrefix, a.ID), fromStorageAuthCode(a))
}
@ -170,9 +168,7 @@ func (c *conn) DeleteAuthCode(id string) error {
return c.deleteKey(ctx, keyID(authCodePrefix, id))
}
func (c *conn) CreateRefresh(r storage.RefreshToken) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
func (c *conn) CreateRefresh(ctx context.Context, r storage.RefreshToken) error {
return c.txnCreate(ctx, keyID(refreshTokenPrefix, r.ID), fromStorageRefreshToken(r))
}
@ -227,9 +223,7 @@ func (c *conn) ListRefreshTokens() (tokens []storage.RefreshToken, err error) {
return tokens, nil
}
func (c *conn) CreateClient(cli storage.Client) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error {
return c.txnCreate(ctx, keyID(clientPrefix, cli.ID), cli)
}
@ -281,9 +275,7 @@ func (c *conn) ListClients() (clients []storage.Client, err error) {
return clients, nil
}
func (c *conn) CreatePassword(p storage.Password) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
func (c *conn) CreatePassword(ctx context.Context, p storage.Password) error {
return c.txnCreate(ctx, passwordPrefix+strings.ToLower(p.Email), p)
}
@ -335,9 +327,7 @@ func (c *conn) ListPasswords() (passwords []storage.Password, err error) {
return passwords, nil
}
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
func (c *conn) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessions) error {
return c.txnCreate(ctx, keySession(s.UserID, s.ConnID), fromStorageOfflineSessions(s))
}
@ -375,9 +365,7 @@ func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
return c.deleteKey(ctx, keySession(userID, connID))
}
func (c *conn) CreateConnector(connector storage.Connector) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) error {
return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector)
}
@ -568,9 +556,7 @@ func keySession(userID, connID string) string {
return offlineSessionPrefix + strings.ToLower(userID+"|"+connID)
}
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
func (c *conn) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) error {
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
}
@ -599,9 +585,7 @@ func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest
return requests, nil
}
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
func (c *conn) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) error {
return c.txnCreate(ctx, keyID(deviceTokenPrefix, t.DeviceCode), fromStorageDeviceToken(t))
}

4
storage/health.go

@ -9,7 +9,7 @@ import (
// NewCustomHealthCheckFunc returns a new health check function.
func NewCustomHealthCheckFunc(s Storage, now func() time.Time) func(context.Context) (details interface{}, err error) {
return func(_ context.Context) (details interface{}, err error) {
return func(ctx context.Context) (details interface{}, err error) {
a := AuthRequest{
ID: NewID(),
ClientID: NewID(),
@ -19,7 +19,7 @@ func NewCustomHealthCheckFunc(s Storage, now func() time.Time) func(context.Cont
HMACKey: NewHMACKey(crypto.SHA256),
}
if err := s.CreateAuthRequest(a); err != nil {
if err := s.CreateAuthRequest(ctx, a); err != nil {
return nil, fmt.Errorf("create auth request: %v", err)
}

20
storage/kubernetes/storage.go

@ -40,6 +40,8 @@ const (
resourceDeviceToken = "devicetokens"
)
var _ storage.Storage = (*client)(nil)
const (
gcResultLimit = 500
)
@ -232,31 +234,31 @@ func (cli *client) Close() error {
return nil
}
func (cli *client) CreateAuthRequest(a storage.AuthRequest) error {
func (cli *client) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) error {
return cli.post(resourceAuthRequest, cli.fromStorageAuthRequest(a))
}
func (cli *client) CreateClient(c storage.Client) error {
func (cli *client) CreateClient(ctx context.Context, c storage.Client) error {
return cli.post(resourceClient, cli.fromStorageClient(c))
}
func (cli *client) CreateAuthCode(c storage.AuthCode) error {
func (cli *client) CreateAuthCode(ctx context.Context, c storage.AuthCode) error {
return cli.post(resourceAuthCode, cli.fromStorageAuthCode(c))
}
func (cli *client) CreatePassword(p storage.Password) error {
func (cli *client) CreatePassword(ctx context.Context, p storage.Password) error {
return cli.post(resourcePassword, cli.fromStoragePassword(p))
}
func (cli *client) CreateRefresh(r storage.RefreshToken) error {
func (cli *client) CreateRefresh(ctx context.Context, r storage.RefreshToken) error {
return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
}
func (cli *client) CreateOfflineSessions(o storage.OfflineSessions) error {
func (cli *client) CreateOfflineSessions(ctx context.Context, o storage.OfflineSessions) error {
return cli.post(resourceOfflineSessions, cli.fromStorageOfflineSessions(o))
}
func (cli *client) CreateConnector(c storage.Connector) error {
func (cli *client) CreateConnector(ctx context.Context, c storage.Connector) error {
return cli.post(resourceConnector, cli.fromStorageConnector(c))
}
@ -681,7 +683,7 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e
return result, delErr
}
func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
func (cli *client) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) error {
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
}
@ -693,7 +695,7 @@ func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, err
return toStorageDeviceRequest(req), nil
}
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
func (cli *client) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) error {
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
}

3
storage/kubernetes/storage_test.go

@ -302,6 +302,7 @@ func TestRetryOnConflict(t *testing.T) {
}
func TestRefreshTokenLock(t *testing.T) {
ctx := context.Background()
if os.Getenv(kubeconfigPathVariableName) == "" {
t.Skipf("variable %q not set, skipping kubernetes storage tests\n", kubeconfigPathVariableName)
}
@ -345,7 +346,7 @@ func TestRefreshTokenLock(t *testing.T) {
ConnectorData: []byte(`{"some":"data"}`),
}
err = kubeClient.CreateRefresh(r)
err = kubeClient.CreateRefresh(ctx, r)
require.NoError(t, err)
t.Run("Timeout lock error", func(t *testing.T) {

21
storage/memory/memory.go

@ -2,6 +2,7 @@
package memory
import (
"context"
"strings"
"sync"
"time"
@ -10,6 +11,8 @@ import (
"github.com/dexidp/dex/storage"
)
var _ storage.Storage = (*memStorage)(nil)
// New returns an in memory storage.
func New(logger log.Logger) storage.Storage {
return &memStorage{
@ -98,7 +101,7 @@ func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err
return result, nil
}
func (s *memStorage) CreateClient(c storage.Client) (err error) {
func (s *memStorage) CreateClient(ctx context.Context, c storage.Client) (err error) {
s.tx(func() {
if _, ok := s.clients[c.ID]; ok {
err = storage.ErrAlreadyExists
@ -109,7 +112,7 @@ func (s *memStorage) CreateClient(c storage.Client) (err error) {
return
}
func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) {
func (s *memStorage) CreateAuthCode(ctx context.Context, c storage.AuthCode) (err error) {
s.tx(func() {
if _, ok := s.authCodes[c.ID]; ok {
err = storage.ErrAlreadyExists
@ -120,7 +123,7 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) {
return
}
func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) {
func (s *memStorage) CreateRefresh(ctx context.Context, r storage.RefreshToken) (err error) {
s.tx(func() {
if _, ok := s.refreshTokens[r.ID]; ok {
err = storage.ErrAlreadyExists
@ -131,7 +134,7 @@ func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) {
return
}
func (s *memStorage) CreateAuthRequest(a storage.AuthRequest) (err error) {
func (s *memStorage) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) (err error) {
s.tx(func() {
if _, ok := s.authReqs[a.ID]; ok {
err = storage.ErrAlreadyExists
@ -142,7 +145,7 @@ func (s *memStorage) CreateAuthRequest(a storage.AuthRequest) (err error) {
return
}
func (s *memStorage) CreatePassword(p storage.Password) (err error) {
func (s *memStorage) CreatePassword(ctx context.Context, p storage.Password) (err error) {
lowerEmail := strings.ToLower(p.Email)
s.tx(func() {
if _, ok := s.passwords[lowerEmail]; ok {
@ -154,7 +157,7 @@ func (s *memStorage) CreatePassword(p storage.Password) (err error) {
return
}
func (s *memStorage) CreateOfflineSessions(o storage.OfflineSessions) (err error) {
func (s *memStorage) CreateOfflineSessions(ctx context.Context, o storage.OfflineSessions) (err error) {
id := offlineSessionID{
userID: o.UserID,
connID: o.ConnID,
@ -169,7 +172,7 @@ func (s *memStorage) CreateOfflineSessions(o storage.OfflineSessions) (err error
return
}
func (s *memStorage) CreateConnector(connector storage.Connector) (err error) {
func (s *memStorage) CreateConnector(ctx context.Context, connector storage.Connector) (err error) {
s.tx(func() {
if _, ok := s.connectors[connector.ID]; ok {
err = storage.ErrAlreadyExists
@ -481,7 +484,7 @@ func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector
return
}
func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) {
func (s *memStorage) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) (err error) {
s.tx(func() {
if _, ok := s.deviceRequests[d.UserCode]; ok {
err = storage.ErrAlreadyExists
@ -503,7 +506,7 @@ func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceReques
return
}
func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
func (s *memStorage) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) (err error) {
s.tx(func() {
if _, ok := s.deviceTokens[t.DeviceCode]; ok {
err = storage.ErrAlreadyExists

18
storage/memory/static_test.go

@ -1,6 +1,7 @@
package memory
import (
"context"
"fmt"
"os"
"strings"
@ -12,6 +13,7 @@ import (
)
func TestStaticClients(t *testing.T) {
ctx := context.Background()
logger := &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
@ -23,7 +25,7 @@ func TestStaticClients(t *testing.T) {
c2 := storage.Client{ID: "bar", Secret: "bar_secret"}
c3 := storage.Client{ID: "spam", Secret: "spam_secret"}
backing.CreateClient(c1)
backing.CreateClient(ctx, c1)
s := storage.WithStaticClients(backing, []storage.Client{c2})
tests := []struct {
@ -82,7 +84,7 @@ func TestStaticClients(t *testing.T) {
{
name: "create client",
action: func() error {
return s.CreateClient(c3)
return s.CreateClient(ctx, c3)
},
},
}
@ -99,6 +101,7 @@ func TestStaticClients(t *testing.T) {
}
func TestStaticPasswords(t *testing.T) {
ctx := context.Background()
logger := &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
@ -111,7 +114,7 @@ func TestStaticPasswords(t *testing.T) {
p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"}
p4 := storage.Password{Email: "Spam@example.com", Username: "Spam_secret"}
backing.CreatePassword(p1)
backing.CreatePassword(ctx, p1)
s := storage.WithStaticPasswords(backing, []storage.Password{p2}, logger)
tests := []struct {
@ -164,10 +167,10 @@ func TestStaticPasswords(t *testing.T) {
{
name: "create passwords",
action: func() error {
if err := s.CreatePassword(p4); err != nil {
if err := s.CreatePassword(ctx, p4); err != nil {
return err
}
return s.CreatePassword(p3)
return s.CreatePassword(ctx, p3)
},
wantErr: true,
},
@ -211,6 +214,7 @@ func TestStaticPasswords(t *testing.T) {
}
func TestStaticConnectors(t *testing.T) {
ctx := context.Background()
logger := &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
@ -226,7 +230,7 @@ func TestStaticConnectors(t *testing.T) {
c2 := storage.Connector{ID: storage.NewID(), Type: "ldap", Name: "ldap", ResourceVersion: "1", Config: config2}
c3 := storage.Connector{ID: storage.NewID(), Type: "saml", Name: "saml", ResourceVersion: "1", Config: config3}
backing.CreateConnector(c1)
backing.CreateConnector(ctx, c1)
s := storage.WithStaticConnectors(backing, []storage.Connector{c2})
tests := []struct {
@ -285,7 +289,7 @@ func TestStaticConnectors(t *testing.T) {
{
name: "create connector",
action: func() error {
return s.CreateConnector(c3)
return s.CreateConnector(ctx, c3)
},
},
}

23
storage/sql/crud.go

@ -1,6 +1,7 @@
package sql
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
@ -83,6 +84,8 @@ type scanner interface {
Scan(dest ...interface{}) error
}
var _ storage.Storage = (*conn)(nil)
func (c *conn) GarbageCollect(now time.Time) (storage.GCResult, error) {
result := storage.GCResult{}
@ -121,7 +124,7 @@ func (c *conn) GarbageCollect(now time.Time) (storage.GCResult, error) {
return result, err
}
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) error {
_, err := c.Exec(`
insert into auth_request (
id, client_id, response_types, scopes, redirect_uri, nonce, state,
@ -229,7 +232,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
return a, nil
}
func (c *conn) CreateAuthCode(a storage.AuthCode) error {
func (c *conn) CreateAuthCode(ctx context.Context, a storage.AuthCode) error {
_, err := c.Exec(`
insert into auth_code (
id, client_id, scopes, nonce, redirect_uri,
@ -280,7 +283,7 @@ func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
return a, nil
}
func (c *conn) CreateRefresh(r storage.RefreshToken) error {
func (c *conn) CreateRefresh(ctx context.Context, r storage.RefreshToken) error {
_, err := c.Exec(`
insert into refresh_token (
id, client_id, scopes, nonce,
@ -521,7 +524,7 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage
})
}
func (c *conn) CreateClient(cli storage.Client) error {
func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error {
_, err := c.Exec(`
insert into client (
id, secret, redirect_uris, trusted_peers, public, name, logo_url
@ -591,7 +594,7 @@ func scanClient(s scanner) (cli storage.Client, err error) {
return cli, nil
}
func (c *conn) CreatePassword(p storage.Password) error {
func (c *conn) CreatePassword(ctx context.Context, p storage.Password) error {
p.Email = strings.ToLower(p.Email)
_, err := c.Exec(`
insert into password (
@ -688,7 +691,7 @@ func scanPassword(s scanner) (p storage.Password, err error) {
return p, nil
}
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
func (c *conn) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessions) error {
_, err := c.Exec(`
insert into offline_session (
user_id, conn_id, refresh, connector_data
@ -761,7 +764,7 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
return o, nil
}
func (c *conn) CreateConnector(connector storage.Connector) error {
func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) error {
_, err := c.Exec(`
insert into connector (
id, type, name, resource_version, config
@ -907,7 +910,7 @@ func (c *conn) delete(table, field, id string) error {
return nil
}
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
func (c *conn) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) error {
_, err := c.Exec(`
insert into device_request (
user_code, device_code, client_id, client_secret, scopes, expiry
@ -926,7 +929,7 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
return nil
}
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
func (c *conn) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) error {
_, err := c.Exec(`
insert into device_token (
device_code, status, token, expiry, last_request, poll_interval, code_challenge, code_challenge_method
@ -1001,7 +1004,7 @@ func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.Dev
_, err = tx.Exec(`
update device_token
set
status = $1,
status = $1,
token = $2,
last_request = $3,
poll_interval = $4,

13
storage/static.go

@ -1,6 +1,7 @@
package storage
import (
"context"
"errors"
"strings"
@ -60,11 +61,11 @@ func (s staticClientsStorage) ListClients() ([]Client, error) {
return append(clients[:n], s.clients...), nil
}
func (s staticClientsStorage) CreateClient(c Client) error {
func (s staticClientsStorage) CreateClient(ctx context.Context, c Client) error {
if s.isStatic(c.ID) {
return errors.New("static clients: read-only cannot create client")
}
return s.Storage.CreateClient(c)
return s.Storage.CreateClient(ctx, c)
}
func (s staticClientsStorage) DeleteClient(id string) error {
@ -140,11 +141,11 @@ func (s staticPasswordsStorage) ListPasswords() ([]Password, error) {
return append(passwords[:n], s.passwords...), nil
}
func (s staticPasswordsStorage) CreatePassword(p Password) error {
func (s staticPasswordsStorage) CreatePassword(ctx context.Context, p Password) error {
if s.isStatic(p.Email) {
return errors.New("static passwords: read-only cannot create password")
}
return s.Storage.CreatePassword(p)
return s.Storage.CreatePassword(ctx, p)
}
func (s staticPasswordsStorage) DeletePassword(email string) error {
@ -210,11 +211,11 @@ func (s staticConnectorsStorage) ListConnectors() ([]Connector, error) {
return append(connectors[:n], s.connectors...), nil
}
func (s staticConnectorsStorage) CreateConnector(c Connector) error {
func (s staticConnectorsStorage) CreateConnector(ctx context.Context, c Connector) error {
if s.isStatic(c.ID) {
return errors.New("static connectors: read-only cannot create connector")
}
return s.Storage.CreateConnector(c)
return s.Storage.CreateConnector(ctx, c)
}
func (s staticConnectorsStorage) DeleteConnector(id string) error {

19
storage/storage.go

@ -1,6 +1,7 @@
package storage
import (
"context"
"crypto"
"crypto/rand"
"encoding/base32"
@ -76,15 +77,15 @@ type Storage interface {
Close() error
// TODO(ericchiang): Let the storages set the IDs of these objects.
CreateAuthRequest(a AuthRequest) error
CreateClient(c Client) error
CreateAuthCode(c AuthCode) error
CreateRefresh(r RefreshToken) error
CreatePassword(p Password) error
CreateOfflineSessions(s OfflineSessions) error
CreateConnector(c Connector) error
CreateDeviceRequest(d DeviceRequest) error
CreateDeviceToken(d DeviceToken) error
CreateAuthRequest(ctx context.Context, a AuthRequest) error
CreateClient(ctx context.Context, c Client) error
CreateAuthCode(ctx context.Context, c AuthCode) error
CreateRefresh(ctx context.Context, r RefreshToken) error
CreatePassword(ctx context.Context, p Password) error
CreateOfflineSessions(ctx context.Context, s OfflineSessions) error
CreateConnector(ctx context.Context, c Connector) error
CreateDeviceRequest(ctx context.Context, d DeviceRequest) error
CreateDeviceToken(ctx context.Context, d DeviceToken) error
// TODO(ericchiang): return (T, bool, error) so we can indicate not found
// requests that way instead of using ErrNotFound.

Loading…
Cancel
Save