Browse Source

Passing context storage (#3941)

Signed-off-by: Bob Maertz <1771054+bobmaertz@users.noreply.github.com>
pull/3967/head
Bob Maertz 1 year ago committed by GitHub
parent
commit
ad31b5d6f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 28
      server/api.go
  2. 4
      server/api_test.go
  3. 18
      server/deviceflowhandlers.go
  4. 70
      server/handlers.go
  5. 8
      server/handlers_test.go
  6. 2
      server/introspectionhandler.go
  7. 13
      server/oauth2.go
  8. 2
      server/oauth2_test.go
  9. 10
      server/refreshhandlers.go
  10. 4
      server/rotation.go
  11. 5
      server/rotation_test.go
  12. 16
      server/server.go
  13. 14
      server/server_test.go
  14. 120
      storage/conformance/conformance.go
  15. 17
      storage/conformance/transactions.go
  16. 8
      storage/ent/client/authcode.go
  17. 12
      storage/ent/client/authrequest.go
  18. 20
      storage/ent/client/client.go
  19. 20
      storage/ent/client/connector.go
  20. 4
      storage/ent/client/devicerequest.go
  21. 12
      storage/ent/client/devicetoken.go
  22. 18
      storage/ent/client/keys.go
  23. 10
      storage/ent/client/main.go
  24. 16
      storage/ent/client/offlinesession.go
  25. 20
      storage/ent/client/password.go
  26. 20
      storage/ent/client/refreshtoken.go
  27. 121
      storage/etcd/etcd.go
  28. 2
      storage/health.go
  29. 68
      storage/kubernetes/storage.go
  30. 14
      storage/kubernetes/storage_test.go
  31. 60
      storage/memory/memory.go
  32. 34
      storage/memory/static_test.go
  33. 127
      storage/sql/crud.go
  34. 48
      storage/static.go
  35. 62
      storage/storage.go

28
server/api.go

@ -51,7 +51,7 @@ type dexAPI struct {
}
func (d dexAPI) GetClient(ctx context.Context, req *api.GetClientReq) (*api.GetClientResp, error) {
c, err := d.s.GetClient(req.Id)
c, err := d.s.GetClient(ctx, req.Id)
if err != nil {
return nil, err
}
@ -108,7 +108,7 @@ func (d dexAPI) UpdateClient(ctx context.Context, req *api.UpdateClientReq) (*ap
return nil, errors.New("update client: no client ID supplied")
}
err := d.s.UpdateClient(req.Id, func(old storage.Client) (storage.Client, error) {
err := d.s.UpdateClient(ctx, req.Id, func(old storage.Client) (storage.Client, error) {
if req.RedirectUris != nil {
old.RedirectURIs = req.RedirectUris
}
@ -134,7 +134,7 @@ func (d dexAPI) UpdateClient(ctx context.Context, req *api.UpdateClientReq) (*ap
}
func (d dexAPI) DeleteClient(ctx context.Context, req *api.DeleteClientReq) (*api.DeleteClientResp, error) {
err := d.s.DeleteClient(req.Id)
err := d.s.DeleteClient(ctx, req.Id)
if err != nil {
if err == storage.ErrNotFound {
return &api.DeleteClientResp{NotFound: true}, nil
@ -219,7 +219,7 @@ func (d dexAPI) UpdatePassword(ctx context.Context, req *api.UpdatePasswordReq)
return old, nil
}
if err := d.s.UpdatePassword(req.Email, updater); err != nil {
if err := d.s.UpdatePassword(ctx, req.Email, updater); err != nil {
if err == storage.ErrNotFound {
return &api.UpdatePasswordResp{NotFound: true}, nil
}
@ -235,7 +235,7 @@ func (d dexAPI) DeletePassword(ctx context.Context, req *api.DeletePasswordReq)
return nil, errors.New("no email supplied")
}
err := d.s.DeletePassword(req.Email)
err := d.s.DeletePassword(ctx, req.Email)
if err != nil {
if err == storage.ErrNotFound {
return &api.DeletePasswordResp{NotFound: true}, nil
@ -268,7 +268,7 @@ func (d dexAPI) GetDiscovery(ctx context.Context, req *api.DiscoveryReq) (*api.D
}
func (d dexAPI) ListPasswords(ctx context.Context, req *api.ListPasswordReq) (*api.ListPasswordResp, error) {
passwordList, err := d.s.ListPasswords()
passwordList, err := d.s.ListPasswords(ctx)
if err != nil {
d.logger.Error("failed to list passwords", "err", err)
return nil, fmt.Errorf("list passwords: %v", err)
@ -298,7 +298,7 @@ func (d dexAPI) VerifyPassword(ctx context.Context, req *api.VerifyPasswordReq)
return nil, errors.New("no password to verify supplied")
}
password, err := d.s.GetPassword(req.Email)
password, err := d.s.GetPassword(ctx, req.Email)
if err != nil {
if err == storage.ErrNotFound {
return &api.VerifyPasswordResp{
@ -327,7 +327,7 @@ func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.
return nil, err
}
offlineSessions, err := d.s.GetOfflineSessions(id.UserId, id.ConnId)
offlineSessions, err := d.s.GetOfflineSessions(ctx, id.UserId, id.ConnId)
if err != nil {
if err == storage.ErrNotFound {
// This means that this user-client pair does not have a refresh token yet.
@ -381,7 +381,7 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*
return old, nil
}
if err := d.s.UpdateOfflineSessions(id.UserId, id.ConnId, updater); err != nil {
if err := d.s.UpdateOfflineSessions(ctx, id.UserId, id.ConnId, updater); err != nil {
if err == storage.ErrNotFound {
return &api.RevokeRefreshResp{NotFound: true}, nil
}
@ -397,7 +397,7 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*
//
// TODO(ericchiang): we don't have any good recourse if this call fails.
// Consider garbage collection of refresh tokens with no associated ref.
if err := d.s.DeleteRefresh(refreshID); err != nil {
if err := d.s.DeleteRefresh(ctx, refreshID); err != nil {
d.logger.Error("failed to delete refresh token", "err", err)
return nil, err
}
@ -448,7 +448,7 @@ func (d dexAPI) CreateConnector(ctx context.Context, req *api.CreateConnectorReq
return &api.CreateConnectorResp{}, nil
}
func (d dexAPI) UpdateConnector(_ context.Context, req *api.UpdateConnectorReq) (*api.UpdateConnectorResp, error) {
func (d dexAPI) UpdateConnector(ctx context.Context, req *api.UpdateConnectorReq) (*api.UpdateConnectorResp, error) {
if !featureflags.APIConnectorsCRUD.Enabled() {
return nil, fmt.Errorf("%s feature flag is not enabled", featureflags.APIConnectorsCRUD.Name)
}
@ -485,7 +485,7 @@ func (d dexAPI) UpdateConnector(_ context.Context, req *api.UpdateConnectorReq)
return old, nil
}
if err := d.s.UpdateConnector(req.Id, updater); err != nil {
if err := d.s.UpdateConnector(ctx, req.Id, updater); err != nil {
if err == storage.ErrNotFound {
return &api.UpdateConnectorResp{NotFound: true}, nil
}
@ -505,7 +505,7 @@ func (d dexAPI) DeleteConnector(ctx context.Context, req *api.DeleteConnectorReq
return nil, errors.New("no id supplied")
}
err := d.s.DeleteConnector(req.Id)
err := d.s.DeleteConnector(ctx, req.Id)
if err != nil {
if err == storage.ErrNotFound {
return &api.DeleteConnectorResp{NotFound: true}, nil
@ -521,7 +521,7 @@ func (d dexAPI) ListConnectors(ctx context.Context, req *api.ListConnectorReq) (
return nil, fmt.Errorf("%s feature flag is not enabled", featureflags.APIConnectorsCRUD.Name)
}
connectorList, err := d.s.ListConnectors()
connectorList, err := d.s.ListConnectors(ctx)
if err != nil {
d.logger.Error("api: failed to list connectors", "err", err)
return nil, fmt.Errorf("list connectors: %v", err)

4
server/api_test.go

@ -149,7 +149,7 @@ func TestPassword(t *testing.T) {
t.Fatalf("Unable to update password: %v", err)
}
pass, err := s.GetPassword(updateReq.Email)
pass, err := s.GetPassword(ctx, updateReq.Email)
if err != nil {
t.Fatalf("Unable to retrieve password: %v", err)
}
@ -449,7 +449,7 @@ func TestUpdateClient(t *testing.T) {
t.Errorf("expected in response NotFound: %t", tc.want.NotFound)
}
client, err := s.GetClient(tc.req.Id)
client, err := s.GetClient(ctx, tc.req.Id)
if err != nil {
t.Errorf("no client found in the storage: %v", err)
}

18
server/deviceflowhandlers.go

@ -199,6 +199,7 @@ func (s *Server) handleDeviceTokenDeprecated(w http.ResponseWriter, r *http.Requ
}
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
deviceCode := r.Form.Get("device_code")
if deviceCode == "" {
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
@ -208,7 +209,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
now := s.now()
// Grab the device token, check validity
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
deviceToken, err := s.storage.GetDeviceToken(ctx, deviceCode)
if err != nil {
if err != storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "failed to get device code", "err", err)
@ -240,7 +241,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
return old, nil
}
// Update device token last request time in storage
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
if err := s.storage.UpdateDeviceToken(ctx, deviceCode, updater); err != nil {
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return
@ -299,7 +300,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
return
}
authCode, err := s.storage.GetAuthCode(code)
authCode, err := s.storage.GetAuthCode(ctx, code)
if err != nil || s.now().After(authCode.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
@ -311,7 +312,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}
// Grab the device request from storage
deviceReq, err := s.storage.GetDeviceRequest(userCode)
deviceReq, err := s.storage.GetDeviceRequest(ctx, userCode)
if err != nil || s.now().After(deviceReq.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
@ -322,7 +323,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
return
}
client, err := s.storage.GetClient(deviceReq.ClientID)
client, err := s.storage.GetClient(ctx, deviceReq.ClientID)
if err != nil {
if err != storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "failed to get client", "err", err)
@ -345,7 +346,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}
// Grab the device token from storage
old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode)
old, err := s.storage.GetDeviceToken(ctx, deviceReq.DeviceCode)
if err != nil || s.now().After(old.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
@ -373,7 +374,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}
// Update refresh token in the storage, store the token and mark as complete
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
if err := s.storage.UpdateDeviceToken(ctx, deviceReq.DeviceCode, updater); err != nil {
s.logger.ErrorContext(r.Context(), "failed to update device token", "err", err)
s.renderError(r, w, http.StatusBadRequest, "")
return
@ -391,6 +392,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
@ -409,7 +411,7 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
userCode = strings.ToUpper(userCode)
// Find the user code in the available requests
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
deviceRequest, err := s.storage.GetDeviceRequest(ctx, userCode)
if err != nil || s.now().After(deviceRequest.Expiry) {
if err != nil && err != storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "failed to get device request", "err", err)

70
server/handlers.go

@ -32,8 +32,9 @@ const (
)
func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// TODO(ericchiang): Cache this.
keys, err := s.storage.GetKeys()
keys, err := s.storage.GetKeys(ctx)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get keys", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
@ -135,6 +136,7 @@ func (s *Server) constructDiscovery() discovery {
// handleAuthorization handles the OAuth2 auth endpoint.
func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Extract the arguments
if err := r.ParseForm(); err != nil {
s.logger.ErrorContext(r.Context(), "failed to parse arguments", "err", err)
@ -144,8 +146,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
}
connectorID := r.Form.Get("connector_id")
connectors, err := s.storage.ListConnectors()
connectors, err := s.storage.ListConnectors(ctx)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get list of connectors", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.")
@ -219,7 +220,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
return
}
conn, err := s.getConnector(connID)
conn, err := s.getConnector(ctx, connID)
if err != nil {
s.logger.ErrorContext(r.Context(), "Failed to get connector", "err", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
@ -314,6 +315,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
authID := r.URL.Query().Get("state")
if authID == "" {
s.renderError(r, w, http.StatusBadRequest, "User session error.")
@ -322,7 +324,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
backLink := r.URL.Query().Get("back")
authReq, err := s.storage.GetAuthRequest(authID)
authReq, err := s.storage.GetAuthRequest(ctx, authID)
if err != nil {
if err == storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "invalid 'state' parameter provided", "err", err)
@ -345,7 +347,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
return
}
conn, err := s.getConnector(authReq.ConnectorID)
conn, err := s.getConnector(ctx, authReq.ConnectorID)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get connector", "connector_id", authReq.ConnectorID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
@ -390,7 +392,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
}
if canSkipApproval {
authReq, err = s.storage.GetAuthRequest(authReq.ID)
authReq, err = s.storage.GetAuthRequest(ctx, authReq.ID)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get finalized auth request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.")
@ -425,7 +427,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
return
}
authReq, err := s.storage.GetAuthRequest(authID)
authReq, err := s.storage.GetAuthRequest(ctx, authID)
if err != nil {
if err == storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "invalid 'state' parameter provided", "err", err)
@ -448,7 +450,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
return
}
conn, err := s.getConnector(authReq.ConnectorID)
conn, err := s.getConnector(ctx, authReq.ConnectorID)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get connector", "connector_id", authReq.ConnectorID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
@ -490,7 +492,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
}
if canSkipApproval {
authReq, err = s.storage.GetAuthRequest(authReq.ID)
authReq, err = s.storage.GetAuthRequest(ctx, authReq.ID)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get finalized auth request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.")
@ -521,7 +523,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
a.ConnectorData = identity.ConnectorData
return a, nil
}
if err := s.storage.UpdateAuthRequest(authReq.ID, updater); err != nil {
if err := s.storage.UpdateAuthRequest(ctx, authReq.ID, updater); err != nil {
return "", false, fmt.Errorf("failed to update auth request: %v", err)
}
@ -545,7 +547,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
if offlineAccessRequested && canRefresh {
// Try to retrieve an existing OfflineSession object for the corresponding user.
session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID)
session, err := s.storage.GetOfflineSessions(ctx, identity.UserID, authReq.ConnectorID)
switch {
case err != nil && err == storage.ErrNotFound:
offlineSessions := storage.OfflineSessions{
@ -563,7 +565,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
}
case err == nil:
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if len(identity.ConnectorData) > 0 {
old.ConnectorData = identity.ConnectorData
}
@ -594,6 +596,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
}
func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
macEncoded := r.FormValue("hmac")
if macEncoded == "" {
s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request")
@ -605,7 +608,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
return
}
authReq, err := s.storage.GetAuthRequest(r.FormValue("req"))
authReq, err := s.storage.GetAuthRequest(ctx, r.FormValue("req"))
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get auth request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Database error.")
@ -629,7 +632,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
client, err := s.storage.GetClient(authReq.ClientID)
client, err := s.storage.GetClient(ctx, authReq.ClientID)
if err != nil {
s.logger.ErrorContext(r.Context(), "Failed to get client", "client_id", authReq.ClientID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve client.")
@ -654,7 +657,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
return
}
if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil {
if err := s.storage.DeleteAuthRequest(ctx, authReq.ID); err != nil {
if err != storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "Failed to delete authorization request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
@ -786,6 +789,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
}
func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, handler func(http.ResponseWriter, *http.Request, storage.Client)) {
ctx := r.Context()
clientID, clientSecret, ok := r.BasicAuth()
if ok {
var err error
@ -802,7 +806,7 @@ func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, h
clientSecret = r.PostFormValue("client_secret")
}
client, err := s.storage.GetClient(clientID)
client, err := s.storage.GetClient(ctx, clientID)
if err != nil {
if err != storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "failed to get client", "err", err)
@ -885,7 +889,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return
}
authCode, err := s.storage.GetAuthCode(code)
authCode, err := s.storage.GetAuthCode(ctx, code)
if err != nil || s.now().After(authCode.Expiry) || authCode.ClientID != client.ID {
if err != storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "failed to get auth code", "err", err)
@ -950,7 +954,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
return nil, err
}
if err := s.storage.DeleteAuthCode(authCode.ID); err != nil {
if err := s.storage.DeleteAuthCode(ctx, authCode.ID); err != nil {
s.logger.ErrorContext(ctx, "failed to delete auth code", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err
@ -960,7 +964,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
// Ensure the connector supports refresh tokens.
//
// Connectors like `saml` do not implement RefreshConnector.
conn, err := s.getConnector(authCode.ConnectorID)
conn, err := s.getConnector(ctx, authCode.ConnectorID)
if err != nil {
s.logger.ErrorContext(ctx, "connector not found", "connector_id", authCode.ConnectorID, "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
@ -1016,7 +1020,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
defer func() {
if deleteToken {
// Delete newly created refresh token from storage.
if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
if err := s.storage.DeleteRefresh(ctx, refresh.ID); err != nil {
s.logger.ErrorContext(ctx, "failed to delete refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
@ -1032,7 +1036,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
}
// Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if session, err := s.storage.GetOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.ErrorContext(ctx, "failed to get offline session", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
@ -1057,7 +1061,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound {
if err := s.storage.DeleteRefresh(ctx, oldTokenRef.ID); err != nil && err != storage.ErrNotFound {
s.logger.ErrorContext(ctx, "failed to delete refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
@ -1066,7 +1070,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
}
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
return old, nil
}); err != nil {
@ -1140,7 +1144,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
continue
}
isTrusted, err := s.validateCrossClientTrust(r.Context(), client.ID, peerID)
isTrusted, err := s.validateCrossClientTrust(ctx, client.ID, peerID)
if err != nil {
s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Error validating cross client trust %v.", err), http.StatusBadRequest)
return
@ -1165,7 +1169,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
// Which connector
connID := s.passwordConnector
conn, err := s.getConnector(connID)
conn, err := s.getConnector(ctx, connID)
if err != nil {
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return
@ -1201,14 +1205,14 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
Groups: identity.Groups,
}
accessToken, _, err := s.newAccessToken(r.Context(), client.ID, claims, scopes, nonce, connID)
accessToken, _, err := s.newAccessToken(ctx, client.ID, claims, scopes, nonce, connID)
if err != nil {
s.logger.ErrorContext(r.Context(), "password grant failed to create new access token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
idToken, expiry, err := s.newIDToken(r.Context(), client.ID, claims, scopes, nonce, accessToken, "", connID)
idToken, expiry, err := s.newIDToken(ctx, client.ID, claims, scopes, nonce, accessToken, "", connID)
if err != nil {
s.logger.ErrorContext(r.Context(), "password grant failed to create new ID token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
@ -1268,7 +1272,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
defer func() {
if deleteToken {
// Delete newly created refresh token from storage.
if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
if err := s.storage.DeleteRefresh(ctx, refresh.ID); err != nil {
s.logger.ErrorContext(r.Context(), "failed to delete refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
@ -1284,7 +1288,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
}
// Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if session, err := s.storage.GetOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.ErrorContext(r.Context(), "failed to get offline session", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
@ -1310,7 +1314,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
if err := s.storage.DeleteRefresh(ctx, oldTokenRef.ID); err != nil {
if err == storage.ErrNotFound {
s.logger.Warn("database inconsistent, refresh token missing", "token_id", oldTokenRef.ID)
} else {
@ -1323,7 +1327,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
}
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
old.ConnectorData = identity.ConnectorData
return old, nil
@ -1371,7 +1375,7 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli
return
}
conn, err := s.getConnector(connID)
conn, err := s.getConnector(ctx, connID)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get connector", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)

8
server/handlers_test.go

@ -138,7 +138,7 @@ type emptyStorage struct {
storage.Storage
}
func (*emptyStorage) GetAuthRequest(string) (storage.AuthRequest, error) {
func (*emptyStorage) GetAuthRequest(context.Context, string) (storage.AuthRequest, error) {
return storage.AuthRequest{}, storage.ErrNotFound
}
@ -407,7 +407,7 @@ func TestHandlePassword(t *testing.T) {
err := json.Unmarshal(rr.Body.Bytes(), &ref)
require.NoError(t, err)
newSess, err := s.storage.GetOfflineSessions("0-385-28089-0", "test")
newSess, err := s.storage.GetOfflineSessions(ctx, "0-385-28089-0", "test")
if tc.offlineSessionCreated {
require.NoError(t, err)
require.Equal(t, `{"test": "true"}`, string(newSess.ConnectorData))
@ -562,7 +562,7 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) {
cb, _ := url.Parse(resp.Header.Get("Location"))
require.Equal(t, tc.expectedRes, cb.Path)
offlineSession, err := s.storage.GetOfflineSessions("0-385-28089-0", connID)
offlineSession, err := s.storage.GetOfflineSessions(ctx, "0-385-28089-0", connID)
if tc.offlineSessionCreated {
require.NoError(t, err)
require.NotEmpty(t, offlineSession)
@ -701,7 +701,7 @@ func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) {
cb, _ := url.Parse(resp.Header.Get("Location"))
require.Equal(t, tc.expectedRes, cb.Path)
offlineSession, err := s.storage.GetOfflineSessions("0-385-28089-0", connID)
offlineSession, err := s.storage.GetOfflineSessions(ctx, "0-385-28089-0", connID)
if tc.offlineSessionCreated {
require.NoError(t, err)
require.NotEmpty(t, offlineSession)

2
server/introspectionhandler.go

@ -263,7 +263,7 @@ func (s *Server) introspectAccessToken(ctx context.Context, token string) (*Intr
return nil, newIntrospectInternalServerError()
}
client, err := s.storage.GetClient(clientID)
client, err := s.storage.GetClient(ctx, clientID)
if err != nil {
s.logger.ErrorContext(ctx, "error while fetching client from storage", "err", err.Error())
return nil, newIntrospectInternalServerError()

13
server/oauth2.go

@ -351,7 +351,7 @@ func genSubject(userID string, connID string) (string, error) {
}
func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) {
keys, err := s.storage.GetKeys()
keys, err := s.storage.GetKeys(ctx)
if err != nil {
s.logger.ErrorContext(ctx, "failed to get keys", "err", err)
return "", expiry, err
@ -453,6 +453,7 @@ func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage
// parse the initial request from the OAuth2 client.
func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthRequest, error) {
ctx := r.Context()
if err := r.ParseForm(); err != nil {
return nil, newDisplayedErr(http.StatusBadRequest, "Failed to parse request.")
}
@ -477,7 +478,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
codeChallengeMethod = codeChallengeMethodPlain
}
client, err := s.storage.GetClient(clientID)
client, err := s.storage.GetClient(ctx, clientID)
if err != nil {
if err == storage.ErrNotFound {
return nil, newDisplayedErr(http.StatusNotFound, "Invalid client_id (%q).", clientID)
@ -499,7 +500,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
}
if connectorID != "" {
connectors, err := s.storage.ListConnectors()
connectors, err := s.storage.ListConnectors(ctx)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to list connectors", "err", err)
return nil, newRedirectedErr(errServerError, "Unable to retrieve connectors")
@ -634,7 +635,7 @@ func (s *Server) validateCrossClientTrust(ctx context.Context, clientID, peerID
if peerID == clientID {
return true, nil
}
peer, err := s.storage.GetClient(peerID)
peer, err := s.storage.GetClient(ctx, peerID)
if err != nil {
if err != storage.ErrNotFound {
s.logger.ErrorContext(ctx, "failed to get client", "err", err)
@ -707,7 +708,7 @@ type storageKeySet struct {
storage.Storage
}
func (s *storageKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
jws, err := jose.ParseSigned(jwt, []jose.SignatureAlgorithm{jose.RS256, jose.RS384, jose.RS512, jose.ES256, jose.ES384, jose.ES512})
if err != nil {
return nil, err
@ -719,7 +720,7 @@ func (s *storageKeySet) VerifySignature(_ context.Context, jwt string) (payload
break
}
skeys, err := s.Storage.GetKeys()
skeys, err := s.Storage.GetKeys(ctx)
if err != nil {
return nil, err
}

2
server/oauth2_test.go

@ -599,7 +599,7 @@ func TestValidRedirectURI(t *testing.T) {
func TestStorageKeySet(t *testing.T) {
s := memory.New(logger)
if err := s.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
if err := s.UpdateKeys(context.TODO(), func(keys storage.Keys) (storage.Keys, error) {
keys.SigningKey = &jose.JSONWebKey{
Key: testKey,
KeyID: "testkey",

10
server/refreshhandlers.go

@ -84,7 +84,7 @@ func (s *Server) getRefreshTokenFromStorage(ctx context.Context, clientID *strin
refreshCtx := refreshContext{requestToken: token}
// Get RefreshToken
refresh, err := s.storage.GetRefresh(token.RefreshId)
refresh, err := s.storage.GetRefresh(ctx, token.RefreshId)
if err != nil {
if err != storage.ErrNotFound {
s.logger.ErrorContext(ctx, "failed to get refresh token", "err", err)
@ -126,14 +126,14 @@ func (s *Server) getRefreshTokenFromStorage(ctx context.Context, clientID *strin
refreshCtx.storageToken = &refresh
// Get Connector
refreshCtx.connector, err = s.getConnector(refresh.ConnectorID)
refreshCtx.connector, err = s.getConnector(ctx, refresh.ConnectorID)
if err != nil {
s.logger.ErrorContext(ctx, "connector not found", "connector_id", refresh.ConnectorID, "err", err)
return nil, newInternalServerError()
}
// Get Connector Data
session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
session, err := s.storage.GetOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID)
switch {
case err != nil:
if err != storage.ErrNotFound {
@ -223,7 +223,7 @@ func (s *Server) updateOfflineSession(ctx context.Context, refresh *storage.Refr
// Update LastUsed time stamp in refresh token reference object
// in offline session for the user.
err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater)
err := s.storage.UpdateOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater)
if err != nil {
s.logger.ErrorContext(ctx, "failed to update offline session", "err", err)
return newInternalServerError()
@ -314,7 +314,7 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (
}
// Update refresh token in the storage.
err := s.storage.UpdateRefreshToken(rCtx.storageToken.ID, refreshTokenUpdater)
err := s.storage.UpdateRefreshToken(ctx, rCtx.storageToken.ID, refreshTokenUpdater)
if err != nil {
s.logger.ErrorContext(ctx, "failed to update refresh token", "err", err)
return nil, ident, newInternalServerError()

4
server/rotation.go

@ -95,7 +95,7 @@ func (s *Server) startKeyRotation(ctx context.Context, strategy rotationStrategy
}
func (k keyRotator) rotate() error {
keys, err := k.GetKeys()
keys, err := k.GetKeys(context.Background())
if err != nil && err != storage.ErrNotFound {
return fmt.Errorf("get keys: %v", err)
}
@ -128,7 +128,7 @@ func (k keyRotator) rotate() error {
}
var nextRotation time.Time
err = k.Storage.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
err = k.Storage.UpdateKeys(context.Background(), func(keys storage.Keys) (storage.Keys, error) {
tNow := k.now()
// if you are running multiple instances of dex, another instance

5
server/rotation_test.go

@ -1,6 +1,7 @@
package server
import (
"context"
"io"
"log/slog"
"sort"
@ -14,7 +15,7 @@ import (
)
func signingKeyID(t *testing.T, s storage.Storage) string {
keys, err := s.GetKeys()
keys, err := s.GetKeys(context.TODO())
if err != nil {
t.Fatal(err)
}
@ -22,7 +23,7 @@ func signingKeyID(t *testing.T, s storage.Storage) string {
}
func verificationKeyIDs(t *testing.T, s storage.Storage) (ids []string) {
keys, err := s.GetKeys()
keys, err := s.GetKeys(context.TODO())
if err != nil {
t.Fatal(err)
}

16
server/server.go

@ -316,7 +316,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
// Retrieves connector objects in backend storage. This list includes the static connectors
// defined in the ConfigMap and dynamic connectors retrieved from the storage.
storageConnectors, err := c.Storage.ListConnectors()
storageConnectors, err := c.Storage.ListConnectors(ctx)
if err != nil {
return nil, fmt.Errorf("server: failed to list connector objects from storage: %v", err)
}
@ -535,7 +535,7 @@ type passwordDB struct {
}
func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, password string) (connector.Identity, bool, error) {
p, err := db.s.GetPassword(email)
p, err := db.s.GetPassword(ctx, email)
if err != nil {
if err != storage.ErrNotFound {
return connector.Identity{}, false, fmt.Errorf("get password: %v", err)
@ -560,7 +560,7 @@ func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, passw
func (db passwordDB) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
// If the user has been deleted, the refresh token will be rejected.
p, err := db.s.GetPassword(identity.Email)
p, err := db.s.GetPassword(ctx, identity.Email)
if err != nil {
if err == storage.ErrNotFound {
return connector.Identity{}, errors.New("user not found")
@ -602,13 +602,13 @@ type keyCacher struct {
keys atomic.Value // Always holds nil or type *storage.Keys.
}
func (k *keyCacher) GetKeys() (storage.Keys, error) {
func (k *keyCacher) GetKeys(ctx context.Context) (storage.Keys, error) {
keys, ok := k.keys.Load().(*storage.Keys)
if ok && keys != nil && k.now().Before(keys.NextRotation) {
return *keys, nil
}
storageKeys, err := k.Storage.GetKeys()
storageKeys, err := k.Storage.GetKeys(ctx)
if err != nil {
return storageKeys, err
}
@ -626,7 +626,7 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura
case <-ctx.Done():
return
case <-time.After(frequency):
if r, err := s.storage.GarbageCollect(now()); err != nil {
if r, err := s.storage.GarbageCollect(ctx, now()); err != nil {
s.logger.ErrorContext(ctx, "garbage collection failed", "err", err)
} else if !r.IsEmpty() {
s.logger.InfoContext(ctx, "garbage collection run, delete auth",
@ -719,8 +719,8 @@ func (s *Server) OpenConnector(conn storage.Connector) (Connector, error) {
// getConnector retrieves the connector object with the given id from the storage
// and updates the connector list for server if necessary.
func (s *Server) getConnector(id string) (Connector, error) {
storageConnector, err := s.storage.GetConnector(id)
func (s *Server) getConnector(ctx context.Context, id string) (Connector, error) {
storageConnector, err := s.storage.GetConnector(ctx, id)
if err != nil {
return Connector{}, fmt.Errorf("failed to get connector object from storage: %v", err)
}

14
server/server_test.go

@ -875,7 +875,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
t.Fatal(err)
}
tokens, err := s.storage.ListRefreshTokens()
tokens, err := s.storage.ListRefreshTokens(ctx)
if err != nil {
t.Fatalf("failed to get existed refresh token: %v", err)
}
@ -1369,15 +1369,15 @@ type storageWithKeysTrigger struct {
f func()
}
func (s storageWithKeysTrigger) GetKeys() (storage.Keys, error) {
func (s storageWithKeysTrigger) GetKeys(ctx context.Context) (storage.Keys, error) {
s.f()
return s.Storage.GetKeys()
return s.Storage.GetKeys(ctx)
}
func TestKeyCacher(t *testing.T) {
tNow := time.Now()
now := func() time.Time { return tNow }
ctx := context.TODO()
s := memory.New(logger)
tests := []struct {
@ -1390,7 +1390,7 @@ func TestKeyCacher(t *testing.T) {
},
{
before: func() {
s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
s.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) {
old.NextRotation = tNow.Add(time.Minute)
return old, nil
})
@ -1410,7 +1410,7 @@ func TestKeyCacher(t *testing.T) {
{
before: func() {
tNow = tNow.Add(time.Hour)
s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
s.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) {
old.NextRotation = tNow.Add(time.Minute)
return old, nil
})
@ -1428,7 +1428,7 @@ func TestKeyCacher(t *testing.T) {
for i, tc := range tests {
gotCall = false
tc.before()
s.GetKeys()
s.GetKeys(context.TODO())
if gotCall != tc.wantCallToStorage {
t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall)
}

120
storage/conformance/conformance.go

@ -148,7 +148,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("failed creating auth request: %v", err)
}
if err := s.UpdateAuthRequest(a1.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
if err := s.UpdateAuthRequest(ctx, a1.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
old.Claims = identity
old.ConnectorID = "connID"
return old, nil
@ -156,7 +156,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("failed to update auth request: %v", err)
}
got, err := s.GetAuthRequest(a1.ID)
got, err := s.GetAuthRequest(ctx, a1.ID)
if err != nil {
t.Fatalf("failed to get auth req: %v", err)
}
@ -168,15 +168,15 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", codeChallenge, got.PKCE)
}
if err := s.DeleteAuthRequest(a1.ID); err != nil {
if err := s.DeleteAuthRequest(ctx, a1.ID); err != nil {
t.Fatalf("failed to delete auth request: %v", err)
}
if err := s.DeleteAuthRequest(a2.ID); err != nil {
if err := s.DeleteAuthRequest(ctx, a2.ID); err != nil {
t.Fatalf("failed to delete auth request: %v", err)
}
_, err = s.GetAuthRequest(a1.ID)
_, err = s.GetAuthRequest(ctx, a1.ID)
mustBeErrNotFound(t, "auth request", err)
}
@ -234,7 +234,7 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("failed creating auth code: %v", err)
}
got, err := s.GetAuthCode(a1.ID)
got, err := s.GetAuthCode(ctx, a1.ID)
if err != nil {
t.Fatalf("failed to get auth code: %v", err)
}
@ -246,15 +246,15 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
t.Errorf("auth code retrieved from storage did not match: %s", diff)
}
if err := s.DeleteAuthCode(a1.ID); err != nil {
if err := s.DeleteAuthCode(ctx, a1.ID); err != nil {
t.Fatalf("delete auth code: %v", err)
}
if err := s.DeleteAuthCode(a2.ID); err != nil {
if err := s.DeleteAuthCode(ctx, a2.ID); err != nil {
t.Fatalf("delete auth code: %v", err)
}
_, err = s.GetAuthCode(a1.ID)
_, err = s.GetAuthCode(ctx, a1.ID)
mustBeErrNotFound(t, "auth code", err)
}
@ -268,7 +268,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
Name: "dex client",
LogoURL: "https://goo.gl/JIyzIC",
}
err := s.DeleteClient(id1)
err := s.DeleteClient(ctx, id1)
mustBeErrNotFound(t, "client", err)
if err := s.CreateClient(ctx, c1); err != nil {
@ -293,7 +293,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
}
getAndCompare := func(_ string, want storage.Client) {
gc, err := s.GetClient(id1)
gc, err := s.GetClient(ctx, id1)
if err != nil {
t.Errorf("get client: %v", err)
return
@ -306,7 +306,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
getAndCompare(id1, c1)
newSecret := "barfoo"
err = s.UpdateClient(id1, func(old storage.Client) (storage.Client, error) {
err = s.UpdateClient(ctx, id1, func(old storage.Client) (storage.Client, error) {
old.Secret = newSecret
return old, nil
})
@ -316,15 +316,15 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
c1.Secret = newSecret
getAndCompare(id1, c1)
if err := s.DeleteClient(id1); err != nil {
if err := s.DeleteClient(ctx, id1); err != nil {
t.Fatalf("delete client: %v", err)
}
if err := s.DeleteClient(id2); err != nil {
if err := s.DeleteClient(ctx, id2); err != nil {
t.Fatalf("delete client: %v", err)
}
_, err = s.GetClient(id1)
_, err = s.GetClient(ctx, id1)
mustBeErrNotFound(t, "client", err)
}
@ -359,7 +359,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
mustBeErrAlreadyExists(t, "refresh token", err)
getAndCompare := func(id string, want storage.RefreshToken) {
gr, err := s.GetRefresh(id)
gr, err := s.GetRefresh(ctx, id)
if err != nil {
t.Errorf("get refresh: %v", err)
return
@ -419,7 +419,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
r.LastUsed = updatedAt
return r, nil
}
if err := s.UpdateRefreshToken(id, updater); err != nil {
if err := s.UpdateRefreshToken(ctx, id, updater); err != nil {
t.Errorf("failed to update refresh token: %v", err)
}
refresh.Token = "spam"
@ -429,15 +429,15 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
// Ensure that updating the first token doesn't impact the second. Issue #847.
getAndCompare(id2, refresh2)
if err := s.DeleteRefresh(id); err != nil {
if err := s.DeleteRefresh(ctx, id); err != nil {
t.Fatalf("failed to delete refresh request: %v", err)
}
if err := s.DeleteRefresh(id2); err != nil {
if err := s.DeleteRefresh(ctx, id2); err != nil {
t.Fatalf("failed to delete refresh request: %v", err)
}
_, err = s.GetRefresh(id)
_, err = s.GetRefresh(ctx, id)
mustBeErrNotFound(t, "refresh token", err)
}
@ -485,7 +485,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
}
getAndCompare := func(id string, want storage.Password) {
gr, err := s.GetPassword(id)
gr, err := s.GetPassword(ctx, id)
if err != nil {
t.Errorf("get password %q: %v", id, err)
return
@ -498,7 +498,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
getAndCompare("jane@example.com", password1)
getAndCompare("JANE@example.com", password1) // Emails should be case insensitive
if err := s.UpdatePassword(password1.Email, func(old storage.Password) (storage.Password, error) {
if err := s.UpdatePassword(ctx, password1.Email, func(old storage.Password) (storage.Password, error) {
old.Username = "jane doe"
return old, nil
}); err != nil {
@ -512,7 +512,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
passwordList = append(passwordList, password1, password2)
listAndCompare := func(want []storage.Password) {
passwords, err := s.ListPasswords()
passwords, err := s.ListPasswords(ctx)
if err != nil {
t.Errorf("list password: %v", err)
return
@ -526,15 +526,15 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
listAndCompare(passwordList)
if err := s.DeletePassword(password1.Email); err != nil {
if err := s.DeletePassword(ctx, password1.Email); err != nil {
t.Fatalf("failed to delete password: %v", err)
}
if err := s.DeletePassword(password2.Email); err != nil {
if err := s.DeletePassword(ctx, password2.Email); err != nil {
t.Fatalf("failed to delete password: %v", err)
}
_, err = s.GetPassword(password1.Email)
_, err = s.GetPassword(ctx, password1.Email)
mustBeErrNotFound(t, "password", err)
}
@ -571,7 +571,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
}
getAndCompare := func(userID string, connID string, want storage.OfflineSessions) {
gr, err := s.GetOfflineSessions(userID, connID)
gr, err := s.GetOfflineSessions(ctx, userID, connID)
if err != nil {
t.Errorf("get offline session: %v", err)
return
@ -592,7 +592,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
}
session1.Refresh[tokenRef.ClientID] = &tokenRef
if err := s.UpdateOfflineSessions(session1.UserID, session1.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if err := s.UpdateOfflineSessions(ctx, session1.UserID, session1.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
return old, nil
}); err != nil {
@ -601,15 +601,15 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
getAndCompare(userID1, "Conn1", session1)
if err := s.DeleteOfflineSessions(session1.UserID, session1.ConnID); err != nil {
if err := s.DeleteOfflineSessions(ctx, session1.UserID, session1.ConnID); err != nil {
t.Fatalf("failed to delete offline session: %v", err)
}
if err := s.DeleteOfflineSessions(session2.UserID, session2.ConnID); err != nil {
if err := s.DeleteOfflineSessions(ctx, session2.UserID, session2.ConnID); err != nil {
t.Fatalf("failed to delete offline session: %v", err)
}
_, err = s.GetOfflineSessions(session1.UserID, session1.ConnID)
_, err = s.GetOfflineSessions(ctx, session1.UserID, session1.ConnID)
mustBeErrNotFound(t, "offline session", err)
}
@ -646,7 +646,7 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) {
}
getAndCompare := func(id string, want storage.Connector) {
gr, err := s.GetConnector(id)
gr, err := s.GetConnector(ctx, id)
if err != nil {
t.Errorf("get connector: %v", err)
return
@ -660,7 +660,7 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) {
getAndCompare(id1, c1)
if err := s.UpdateConnector(c1.ID, func(old storage.Connector) (storage.Connector, error) {
if err := s.UpdateConnector(ctx, c1.ID, func(old storage.Connector) (storage.Connector, error) {
old.Type = "oidc"
return old, nil
}); err != nil {
@ -672,7 +672,7 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) {
connectorList := []storage.Connector{c1, c2}
listAndCompare := func(want []storage.Connector) {
connectors, err := s.ListConnectors()
connectors, err := s.ListConnectors(ctx)
if err != nil {
t.Errorf("list connectors: %v", err)
return
@ -690,21 +690,23 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) {
}
listAndCompare(connectorList)
if err := s.DeleteConnector(c1.ID); err != nil {
if err := s.DeleteConnector(ctx, c1.ID); err != nil {
t.Fatalf("failed to delete connector: %v", err)
}
if err := s.DeleteConnector(c2.ID); err != nil {
if err := s.DeleteConnector(ctx, c2.ID); err != nil {
t.Fatalf("failed to delete connector: %v", err)
}
_, err = s.GetConnector(c1.ID)
_, err = s.GetConnector(ctx, c1.ID)
mustBeErrNotFound(t, "connector", err)
}
func testKeysCRUD(t *testing.T, s storage.Storage) {
ctx := context.TODO()
updateAndCompare := func(k storage.Keys) {
err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) {
err := s.UpdateKeys(ctx, func(oldKeys storage.Keys) (storage.Keys, error) {
return k, nil
})
if err != nil {
@ -712,7 +714,7 @@ func testKeysCRUD(t *testing.T, s storage.Storage) {
return
}
if got, err := s.GetKeys(); err != nil {
if got, err := s.GetKeys(ctx); err != nil {
t.Errorf("failed to get keys: %v", err)
} else {
got.NextRotation = got.NextRotation.UTC()
@ -786,24 +788,24 @@ func testGC(t *testing.T, s storage.Storage) {
}
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if result.AuthCodes != 0 || result.AuthRequests != 0 {
t.Errorf("expected no garbage collection results, got %#v", result)
}
if _, err := s.GetAuthCode(c.ID); err != nil {
if _, err := s.GetAuthCode(ctx, c.ID); err != nil {
t.Errorf("expected to be able to get auth code after GC: %v", err)
}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.AuthCodes != 1 {
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
}
if _, err := s.GetAuthCode(c.ID); err == nil {
if _, err := s.GetAuthCode(ctx, c.ID); err == nil {
t.Errorf("expected auth code to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
@ -837,24 +839,24 @@ func testGC(t *testing.T, s storage.Storage) {
}
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if result.AuthCodes != 0 || result.AuthRequests != 0 {
t.Errorf("expected no garbage collection results, got %#v", result)
}
if _, err := s.GetAuthRequest(a.ID); err != nil {
if _, err := s.GetAuthRequest(ctx, a.ID); err != nil {
t.Errorf("expected to be able to get auth request after GC: %v", err)
}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.AuthRequests != 1 {
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
}
if _, err := s.GetAuthRequest(a.ID); err == nil {
if _, err := s.GetAuthRequest(ctx, a.ID); err == nil {
t.Errorf("expected auth request to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
@ -874,23 +876,23 @@ func testGC(t *testing.T, s storage.Storage) {
}
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if result.DeviceRequests != 0 {
t.Errorf("expected no device garbage collection results, got %#v", result)
}
if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
if _, err := s.GetDeviceRequest(ctx, d.UserCode); err != nil {
t.Errorf("expected to be able to get auth request after GC: %v", err)
}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.DeviceRequests != 1 {
t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests)
}
if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
if _, err := s.GetDeviceRequest(ctx, d.UserCode); err == nil {
t.Errorf("expected device request to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
@ -914,23 +916,23 @@ func testGC(t *testing.T, s storage.Storage) {
}
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if result.DeviceTokens != 0 {
t.Errorf("expected no device token garbage collection results, got %#v", result)
}
if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil {
if _, err := s.GetDeviceToken(ctx, dt.DeviceCode); err != nil {
t.Errorf("expected to be able to get device token after GC: %v", err)
}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.DeviceTokens != 1 {
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
}
if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil {
if _, err := s.GetDeviceToken(ctx, dt.DeviceCode); err == nil {
t.Errorf("expected device token to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
@ -969,7 +971,7 @@ func testTimezones(t *testing.T, s storage.Storage) {
if err := s.CreateAuthCode(ctx, c); err != nil {
t.Fatalf("failed creating auth code: %v", err)
}
got, err := s.GetAuthCode(c.ID)
got, err := s.GetAuthCode(ctx, c.ID)
if err != nil {
t.Fatalf("failed to get auth code: %v", err)
}
@ -1003,7 +1005,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
err := s.CreateDeviceRequest(ctx, d1)
mustBeErrAlreadyExists(t, "device request", err)
got, err := s.GetDeviceRequest(d1.UserCode)
got, err := s.GetDeviceRequest(ctx, d1.UserCode)
if err != nil {
t.Fatalf("failed to get device request: %v", err)
}
@ -1041,7 +1043,7 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
mustBeErrAlreadyExists(t, "device token", err)
// Update the device token, simulate a redemption
if err := s.UpdateDeviceToken(d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) {
if err := s.UpdateDeviceToken(ctx, d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) {
old.Token = "token data"
old.Status = "complete"
return old, nil
@ -1050,7 +1052,7 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
}
// Retrieve the device token
got, err := s.GetDeviceToken(d1.DeviceCode)
got, err := s.GetDeviceToken(ctx, d1.DeviceCode)
if err != nil {
t.Fatalf("failed to get device token: %v", err)
}

17
storage/conformance/transactions.go

@ -42,9 +42,9 @@ func testClientConcurrentUpdate(t *testing.T, s storage.Storage) {
var err1, err2 error
err1 = s.UpdateClient(c.ID, func(old storage.Client) (storage.Client, error) {
err1 = s.UpdateClient(ctx, c.ID, func(old storage.Client) (storage.Client, error) {
old.Secret = "new secret 1"
err2 = s.UpdateClient(c.ID, func(old storage.Client) (storage.Client, error) {
err2 = s.UpdateClient(ctx, c.ID, func(old storage.Client) (storage.Client, error) {
old.Secret = "new secret 2"
return old, nil
})
@ -87,9 +87,9 @@ func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) {
var err1, err2 error
err1 = s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
err1 = s.UpdateAuthRequest(ctx, a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
old.State = "state 1"
err2 = s.UpdateAuthRequest(a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
err2 = s.UpdateAuthRequest(ctx, a.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) {
old.State = "state 2"
return old, nil
})
@ -121,9 +121,9 @@ func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) {
var err1, err2 error
err1 = s.UpdatePassword(password.Email, func(old storage.Password) (storage.Password, error) {
err1 = s.UpdatePassword(ctx, password.Email, func(old storage.Password) (storage.Password, error) {
old.Username = "user 1"
err2 = s.UpdatePassword(password.Email, func(old storage.Password) (storage.Password, error) {
err2 = s.UpdatePassword(ctx, password.Email, func(old storage.Password) (storage.Password, error) {
old.Username = "user 2"
return old, nil
})
@ -163,8 +163,9 @@ func testKeysConcurrentUpdate(t *testing.T, s storage.Storage) {
var err1, err2 error
err1 = s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
err2 = s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
ctx := context.TODO()
err1 = s.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) {
err2 = s.UpdateKeys(ctx, func(old storage.Keys) (storage.Keys, error) {
return keys1, nil
})
return keys2, nil

8
storage/ent/client/authcode.go

@ -34,8 +34,8 @@ func (d *Database) CreateAuthCode(ctx context.Context, code storage.AuthCode) er
}
// GetAuthCode extracts an auth code from the database by id.
func (d *Database) GetAuthCode(id string) (storage.AuthCode, error) {
authCode, err := d.client.AuthCode.Get(context.TODO(), id)
func (d *Database) GetAuthCode(ctx context.Context, id string) (storage.AuthCode, error) {
authCode, err := d.client.AuthCode.Get(ctx, id)
if err != nil {
return storage.AuthCode{}, convertDBError("get auth code: %w", err)
}
@ -43,8 +43,8 @@ func (d *Database) GetAuthCode(id string) (storage.AuthCode, error) {
}
// DeleteAuthCode deletes an auth code from the database by id.
func (d *Database) DeleteAuthCode(id string) error {
err := d.client.AuthCode.DeleteOneID(id).Exec(context.TODO())
func (d *Database) DeleteAuthCode(ctx context.Context, id string) error {
err := d.client.AuthCode.DeleteOneID(id).Exec(ctx)
if err != nil {
return convertDBError("delete auth code: %w", err)
}

12
storage/ent/client/authrequest.go

@ -40,8 +40,8 @@ func (d *Database) CreateAuthRequest(ctx context.Context, authRequest storage.Au
}
// GetAuthRequest extracts an auth request from the database by id.
func (d *Database) GetAuthRequest(id string) (storage.AuthRequest, error) {
authRequest, err := d.client.AuthRequest.Get(context.TODO(), id)
func (d *Database) GetAuthRequest(ctx context.Context, id string) (storage.AuthRequest, error) {
authRequest, err := d.client.AuthRequest.Get(ctx, id)
if err != nil {
return storage.AuthRequest{}, convertDBError("get auth request: %w", err)
}
@ -49,8 +49,8 @@ func (d *Database) GetAuthRequest(id string) (storage.AuthRequest, error) {
}
// DeleteAuthRequest deletes an auth request from the database by id.
func (d *Database) DeleteAuthRequest(id string) error {
err := d.client.AuthRequest.DeleteOneID(id).Exec(context.TODO())
func (d *Database) DeleteAuthRequest(ctx context.Context, id string) error {
err := d.client.AuthRequest.DeleteOneID(id).Exec(ctx)
if err != nil {
return convertDBError("delete auth request: %w", err)
}
@ -58,8 +58,8 @@ func (d *Database) DeleteAuthRequest(id string) error {
}
// UpdateAuthRequest changes an auth request by id using an updater function and saves it to the database.
func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error {
tx, err := d.BeginTx(context.TODO())
func (d *Database) UpdateAuthRequest(ctx context.Context, id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error {
tx, err := d.BeginTx(ctx)
if err != nil {
return fmt.Errorf("update auth request tx: %w", err)
}

20
storage/ent/client/client.go

@ -24,8 +24,8 @@ func (d *Database) CreateClient(ctx context.Context, client storage.Client) erro
}
// ListClients extracts an array of oauth2 clients from the database.
func (d *Database) ListClients() ([]storage.Client, error) {
clients, err := d.client.OAuth2Client.Query().All(context.TODO())
func (d *Database) ListClients(ctx context.Context) ([]storage.Client, error) {
clients, err := d.client.OAuth2Client.Query().All(ctx)
if err != nil {
return nil, convertDBError("list clients: %w", err)
}
@ -38,8 +38,8 @@ func (d *Database) ListClients() ([]storage.Client, error) {
}
// GetClient extracts an oauth2 client from the database by id.
func (d *Database) GetClient(id string) (storage.Client, error) {
client, err := d.client.OAuth2Client.Get(context.TODO(), id)
func (d *Database) GetClient(ctx context.Context, id string) (storage.Client, error) {
client, err := d.client.OAuth2Client.Get(ctx, id)
if err != nil {
return storage.Client{}, convertDBError("get client: %w", err)
}
@ -47,8 +47,8 @@ func (d *Database) GetClient(id string) (storage.Client, error) {
}
// DeleteClient deletes an oauth2 client from the database by id.
func (d *Database) DeleteClient(id string) error {
err := d.client.OAuth2Client.DeleteOneID(id).Exec(context.TODO())
func (d *Database) DeleteClient(ctx context.Context, id string) error {
err := d.client.OAuth2Client.DeleteOneID(id).Exec(ctx)
if err != nil {
return convertDBError("delete client: %w", err)
}
@ -56,13 +56,13 @@ func (d *Database) DeleteClient(id string) error {
}
// UpdateClient changes an oauth2 client by id using an updater function and saves it to the database.
func (d *Database) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
tx, err := d.BeginTx(context.TODO())
func (d *Database) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error {
tx, err := d.BeginTx(ctx)
if err != nil {
return convertDBError("update client tx: %w", err)
}
client, err := tx.OAuth2Client.Get(context.TODO(), id)
client, err := tx.OAuth2Client.Get(ctx, id)
if err != nil {
return rollback(tx, "update client database: %w", err)
}
@ -79,7 +79,7 @@ func (d *Database) UpdateClient(id string, updater func(old storage.Client) (sto
SetLogoURL(newClient.LogoURL).
SetRedirectUris(newClient.RedirectURIs).
SetTrustedPeers(newClient.TrustedPeers).
Save(context.TODO())
Save(ctx)
if err != nil {
return rollback(tx, "update client uploading: %w", err)
}

20
storage/ent/client/connector.go

@ -22,8 +22,8 @@ func (d *Database) CreateConnector(ctx context.Context, connector storage.Connec
}
// ListConnectors extracts an array of connectors from the database.
func (d *Database) ListConnectors() ([]storage.Connector, error) {
connectors, err := d.client.Connector.Query().All(context.TODO())
func (d *Database) ListConnectors(ctx context.Context) ([]storage.Connector, error) {
connectors, err := d.client.Connector.Query().All(ctx)
if err != nil {
return nil, convertDBError("list connectors: %w", err)
}
@ -36,8 +36,8 @@ func (d *Database) ListConnectors() ([]storage.Connector, error) {
}
// GetConnector extracts a connector from the database by id.
func (d *Database) GetConnector(id string) (storage.Connector, error) {
connector, err := d.client.Connector.Get(context.TODO(), id)
func (d *Database) GetConnector(ctx context.Context, id string) (storage.Connector, error) {
connector, err := d.client.Connector.Get(ctx, id)
if err != nil {
return storage.Connector{}, convertDBError("get connector: %w", err)
}
@ -45,8 +45,8 @@ func (d *Database) GetConnector(id string) (storage.Connector, error) {
}
// DeleteConnector deletes a connector from the database by id.
func (d *Database) DeleteConnector(id string) error {
err := d.client.Connector.DeleteOneID(id).Exec(context.TODO())
func (d *Database) DeleteConnector(ctx context.Context, id string) error {
err := d.client.Connector.DeleteOneID(id).Exec(ctx)
if err != nil {
return convertDBError("delete connector: %w", err)
}
@ -54,13 +54,13 @@ func (d *Database) DeleteConnector(id string) error {
}
// UpdateConnector changes a connector by id using an updater function and saves it to the database.
func (d *Database) UpdateConnector(id string, updater func(old storage.Connector) (storage.Connector, error)) error {
tx, err := d.BeginTx(context.TODO())
func (d *Database) UpdateConnector(ctx context.Context, id string, updater func(old storage.Connector) (storage.Connector, error)) error {
tx, err := d.BeginTx(ctx)
if err != nil {
return convertDBError("update connector tx: %w", err)
}
connector, err := tx.Connector.Get(context.TODO(), id)
connector, err := tx.Connector.Get(ctx, id)
if err != nil {
return rollback(tx, "update connector database: %w", err)
}
@ -75,7 +75,7 @@ func (d *Database) UpdateConnector(id string, updater func(old storage.Connector
SetType(newConnector.Type).
SetResourceVersion(newConnector.ResourceVersion).
SetConfig(newConnector.Config).
Save(context.TODO())
Save(ctx)
if err != nil {
return rollback(tx, "update connector uploading: %w", err)
}

4
storage/ent/client/devicerequest.go

@ -25,10 +25,10 @@ func (d *Database) CreateDeviceRequest(ctx context.Context, request storage.Devi
}
// GetDeviceRequest extracts a device request from the database by user code.
func (d *Database) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
func (d *Database) GetDeviceRequest(ctx context.Context, userCode string) (storage.DeviceRequest, error) {
deviceRequest, err := d.client.DeviceRequest.Query().
Where(devicerequest.UserCode(userCode)).
Only(context.TODO())
Only(ctx)
if err != nil {
return storage.DeviceRequest{}, convertDBError("get device request: %w", err)
}

12
storage/ent/client/devicetoken.go

@ -27,10 +27,10 @@ func (d *Database) CreateDeviceToken(ctx context.Context, token storage.DeviceTo
}
// GetDeviceToken extracts a token from the database by device code.
func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
func (d *Database) GetDeviceToken(ctx context.Context, deviceCode string) (storage.DeviceToken, error) {
deviceToken, err := d.client.DeviceToken.Query().
Where(devicetoken.DeviceCode(deviceCode)).
Only(context.TODO())
Only(ctx)
if err != nil {
return storage.DeviceToken{}, convertDBError("get device token: %w", err)
}
@ -38,15 +38,15 @@ func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error
}
// UpdateDeviceToken changes a token by device code using an updater function and saves it to the database.
func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
tx, err := d.BeginTx(context.TODO())
func (d *Database) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
tx, err := d.BeginTx(ctx)
if err != nil {
return convertDBError("update device token tx: %w", err)
}
token, err := tx.DeviceToken.Query().
Where(devicetoken.DeviceCode(deviceCode)).
Only(context.TODO())
Only(ctx)
if err != nil {
return rollback(tx, "update device token database: %w", err)
}
@ -67,7 +67,7 @@ func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage
SetStatus(newToken.Status).
SetCodeChallenge(newToken.PKCE.CodeChallenge).
SetCodeChallengeMethod(newToken.PKCE.CodeChallengeMethod).
Save(context.TODO())
Save(ctx)
if err != nil {
return rollback(tx, "update device token uploading: %w", err)
}

18
storage/ent/client/keys.go

@ -8,8 +8,8 @@ import (
"github.com/dexidp/dex/storage/ent/db"
)
func getKeys(client *db.KeysClient) (storage.Keys, error) {
rawKeys, err := client.Get(context.TODO(), keysRowID)
func getKeys(ctx context.Context, client *db.KeysClient) (storage.Keys, error) {
rawKeys, err := client.Get(ctx, keysRowID)
if err != nil {
return storage.Keys{}, convertDBError("get keys: %w", err)
}
@ -18,20 +18,20 @@ func getKeys(client *db.KeysClient) (storage.Keys, error) {
}
// GetKeys returns signing keys, public keys and verification keys from the database.
func (d *Database) GetKeys() (storage.Keys, error) {
return getKeys(d.client.Keys)
func (d *Database) GetKeys(ctx context.Context) (storage.Keys, error) {
return getKeys(ctx, d.client.Keys)
}
// UpdateKeys rotates keys using updater function.
func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
func (d *Database) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) error {
firstUpdate := false
tx, err := d.BeginTx(context.TODO())
tx, err := d.BeginTx(ctx)
if err != nil {
return convertDBError("update keys tx: %w", err)
}
storageKeys, err := getKeys(tx.Keys)
storageKeys, err := getKeys(ctx, tx.Keys)
if err != nil {
if !errors.Is(err, storage.ErrNotFound) {
return rollback(tx, "update keys get: %w", err)
@ -53,7 +53,7 @@ func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro
SetSigningKey(*newKeys.SigningKey).
SetSigningKeyPub(*newKeys.SigningKeyPub).
SetVerificationKeys(newKeys.VerificationKeys).
Save(context.TODO())
Save(ctx)
if err != nil {
return rollback(tx, "create keys: %w", err)
}
@ -68,7 +68,7 @@ func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro
SetSigningKey(*newKeys.SigningKey).
SetSigningKeyPub(*newKeys.SigningKeyPub).
SetVerificationKeys(newKeys.VerificationKeys).
Exec(context.TODO())
Exec(ctx)
if err != nil {
return rollback(tx, "update keys uploading: %w", err)
}

10
storage/ent/client/main.go

@ -70,13 +70,13 @@ func (d *Database) BeginTx(ctx context.Context) (*db.Tx, error) {
}
// GarbageCollect removes expired entities from the database.
func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) {
func (d *Database) GarbageCollect(ctx context.Context, now time.Time) (storage.GCResult, error) {
result := storage.GCResult{}
utcNow := now.UTC()
q, err := d.client.AuthRequest.Delete().
Where(authrequest.ExpiryLT(utcNow)).
Exec(context.TODO())
Exec(ctx)
if err != nil {
return result, convertDBError("gc auth request: %w", err)
}
@ -84,7 +84,7 @@ func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) {
q, err = d.client.AuthCode.Delete().
Where(authcode.ExpiryLT(utcNow)).
Exec(context.TODO())
Exec(ctx)
if err != nil {
return result, convertDBError("gc auth code: %w", err)
}
@ -92,7 +92,7 @@ func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) {
q, err = d.client.DeviceRequest.Delete().
Where(devicerequest.ExpiryLT(utcNow)).
Exec(context.TODO())
Exec(ctx)
if err != nil {
return result, convertDBError("gc device request: %w", err)
}
@ -100,7 +100,7 @@ func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) {
q, err = d.client.DeviceToken.Delete().
Where(devicetoken.ExpiryLT(utcNow)).
Exec(context.TODO())
Exec(ctx)
if err != nil {
return result, convertDBError("gc device token: %w", err)
}

16
storage/ent/client/offlinesession.go

@ -30,10 +30,10 @@ func (d *Database) CreateOfflineSessions(ctx context.Context, session storage.Of
}
// GetOfflineSessions extracts an offline session from the database by user id and connector id.
func (d *Database) GetOfflineSessions(userID, connID string) (storage.OfflineSessions, error) {
func (d *Database) GetOfflineSessions(ctx context.Context, userID, connID string) (storage.OfflineSessions, error) {
id := offlineSessionID(userID, connID, d.hasher)
offlineSession, err := d.client.OfflineSession.Get(context.TODO(), id)
offlineSession, err := d.client.OfflineSession.Get(ctx, id)
if err != nil {
return storage.OfflineSessions{}, convertDBError("get offline session: %w", err)
}
@ -41,10 +41,10 @@ func (d *Database) GetOfflineSessions(userID, connID string) (storage.OfflineSes
}
// DeleteOfflineSessions deletes an offline session from the database by user id and connector id.
func (d *Database) DeleteOfflineSessions(userID, connID string) error {
func (d *Database) DeleteOfflineSessions(ctx context.Context, userID, connID string) error {
id := offlineSessionID(userID, connID, d.hasher)
err := d.client.OfflineSession.DeleteOneID(id).Exec(context.TODO())
err := d.client.OfflineSession.DeleteOneID(id).Exec(ctx)
if err != nil {
return convertDBError("delete offline session: %w", err)
}
@ -52,15 +52,15 @@ func (d *Database) DeleteOfflineSessions(userID, connID string) error {
}
// UpdateOfflineSessions changes an offline session by user id and connector id using an updater function.
func (d *Database) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
func (d *Database) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
id := offlineSessionID(userID, connID, d.hasher)
tx, err := d.BeginTx(context.TODO())
tx, err := d.BeginTx(ctx)
if err != nil {
return convertDBError("update offline session tx: %w", err)
}
offlineSession, err := tx.OfflineSession.Get(context.TODO(), id)
offlineSession, err := tx.OfflineSession.Get(ctx, id)
if err != nil {
return rollback(tx, "update offline session database: %w", err)
}
@ -80,7 +80,7 @@ func (d *Database) UpdateOfflineSessions(userID string, connID string, updater f
SetConnID(newOfflineSession.ConnID).
SetConnectorData(newOfflineSession.ConnectorData).
SetRefresh(encodedRefresh).
Save(context.TODO())
Save(ctx)
if err != nil {
return rollback(tx, "update offline session uploading: %w", err)
}

20
storage/ent/client/password.go

@ -23,8 +23,8 @@ func (d *Database) CreatePassword(ctx context.Context, password storage.Password
}
// ListPasswords extracts an array of passwords from the database.
func (d *Database) ListPasswords() ([]storage.Password, error) {
passwords, err := d.client.Password.Query().All(context.TODO())
func (d *Database) ListPasswords(ctx context.Context) ([]storage.Password, error) {
passwords, err := d.client.Password.Query().All(ctx)
if err != nil {
return nil, convertDBError("list passwords: %w", err)
}
@ -37,11 +37,11 @@ func (d *Database) ListPasswords() ([]storage.Password, error) {
}
// GetPassword extracts a password from the database by email.
func (d *Database) GetPassword(email string) (storage.Password, error) {
func (d *Database) GetPassword(ctx context.Context, email string) (storage.Password, error) {
email = strings.ToLower(email)
passwordFromStorage, err := d.client.Password.Query().
Where(password.Email(email)).
Only(context.TODO())
Only(ctx)
if err != nil {
return storage.Password{}, convertDBError("get password: %w", err)
}
@ -49,11 +49,11 @@ func (d *Database) GetPassword(email string) (storage.Password, error) {
}
// DeletePassword deletes a password from the database by email.
func (d *Database) DeletePassword(email string) error {
func (d *Database) DeletePassword(ctx context.Context, email string) error {
email = strings.ToLower(email)
_, err := d.client.Password.Delete().
Where(password.Email(email)).
Exec(context.TODO())
Exec(ctx)
if err != nil {
return convertDBError("delete password: %w", err)
}
@ -61,17 +61,17 @@ func (d *Database) DeletePassword(email string) error {
}
// UpdatePassword changes a password by email using an updater function and saves it to the database.
func (d *Database) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error {
func (d *Database) UpdatePassword(ctx context.Context, email string, updater func(old storage.Password) (storage.Password, error)) error {
email = strings.ToLower(email)
tx, err := d.BeginTx(context.TODO())
tx, err := d.BeginTx(ctx)
if err != nil {
return convertDBError("update connector tx: %w", err)
}
passwordToUpdate, err := tx.Password.Query().
Where(password.Email(email)).
Only(context.TODO())
Only(ctx)
if err != nil {
return rollback(tx, "update password database: %w", err)
}
@ -87,7 +87,7 @@ func (d *Database) UpdatePassword(email string, updater func(old storage.Passwor
SetHash(newPassword.Hash).
SetUsername(newPassword.Username).
SetUserID(newPassword.UserID).
Save(context.TODO())
Save(ctx)
if err != nil {
return rollback(tx, "update password uploading: %w", err)
}

20
storage/ent/client/refreshtoken.go

@ -34,8 +34,8 @@ func (d *Database) CreateRefresh(ctx context.Context, refresh storage.RefreshTok
}
// ListRefreshTokens extracts an array of refresh tokens from the database.
func (d *Database) ListRefreshTokens() ([]storage.RefreshToken, error) {
refreshTokens, err := d.client.RefreshToken.Query().All(context.TODO())
func (d *Database) ListRefreshTokens(ctx context.Context) ([]storage.RefreshToken, error) {
refreshTokens, err := d.client.RefreshToken.Query().All(ctx)
if err != nil {
return nil, convertDBError("list refresh tokens: %w", err)
}
@ -48,8 +48,8 @@ func (d *Database) ListRefreshTokens() ([]storage.RefreshToken, error) {
}
// GetRefresh extracts a refresh token from the database by id.
func (d *Database) GetRefresh(id string) (storage.RefreshToken, error) {
refreshToken, err := d.client.RefreshToken.Get(context.TODO(), id)
func (d *Database) GetRefresh(ctx context.Context, id string) (storage.RefreshToken, error) {
refreshToken, err := d.client.RefreshToken.Get(ctx, id)
if err != nil {
return storage.RefreshToken{}, convertDBError("get refresh token: %w", err)
}
@ -57,8 +57,8 @@ func (d *Database) GetRefresh(id string) (storage.RefreshToken, error) {
}
// DeleteRefresh deletes a refresh token from the database by id.
func (d *Database) DeleteRefresh(id string) error {
err := d.client.RefreshToken.DeleteOneID(id).Exec(context.TODO())
func (d *Database) DeleteRefresh(ctx context.Context, id string) error {
err := d.client.RefreshToken.DeleteOneID(id).Exec(ctx)
if err != nil {
return convertDBError("delete refresh token: %w", err)
}
@ -66,13 +66,13 @@ func (d *Database) DeleteRefresh(id string) error {
}
// UpdateRefreshToken changes a refresh token by id using an updater function and saves it to the database.
func (d *Database) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
tx, err := d.BeginTx(context.TODO())
func (d *Database) UpdateRefreshToken(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
tx, err := d.BeginTx(ctx)
if err != nil {
return convertDBError("update refresh token tx: %w", err)
}
token, err := tx.RefreshToken.Get(context.TODO(), id)
token, err := tx.RefreshToken.Get(ctx, id)
if err != nil {
return rollback(tx, "update refresh token database: %w", err)
}
@ -99,7 +99,7 @@ func (d *Database) UpdateRefreshToken(id string, updater func(old storage.Refres
// Save utc time into database because ent doesn't support comparing dates with different timezones
SetLastUsed(newtToken.LastUsed.UTC()).
SetCreatedAt(newtToken.CreatedAt.UTC()).
Save(context.TODO())
Save(ctx)
if err != nil {
return rollback(tx, "update refresh token uploading: %w", err)
}

121
storage/etcd/etcd.go

@ -40,8 +40,8 @@ func (c *conn) Close() error {
return c.db.Close()
}
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) GarbageCollect(ctx context.Context, now time.Time) (result storage.GCResult, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
authRequests, err := c.listAuthRequests(ctx)
if err != nil {
@ -113,8 +113,9 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err
return c.txnCreate(ctx, keyID(authRequestPrefix, a.ID), fromStorageAuthRequest(a))
}
func (c *conn) GetAuthRequest(id string) (a storage.AuthRequest, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) GetAuthRequest(ctx context.Context, id string) (a storage.AuthRequest, err error) {
// TODO: Add this to other funcs??
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
var req AuthRequest
if err = c.getKey(ctx, keyID(authRequestPrefix, id), &req); err != nil {
@ -123,8 +124,8 @@ func (c *conn) GetAuthRequest(id string) (a storage.AuthRequest, err error) {
return toStorageAuthRequest(req), nil
}
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keyID(authRequestPrefix, id), func(currentValue []byte) ([]byte, error) {
var current AuthRequest
@ -141,8 +142,8 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
})
}
func (c *conn) DeleteAuthRequest(id string) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) DeleteAuthRequest(ctx context.Context, id string) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.deleteKey(ctx, keyID(authRequestPrefix, id))
}
@ -151,8 +152,8 @@ func (c *conn) CreateAuthCode(ctx context.Context, a storage.AuthCode) error {
return c.txnCreate(ctx, keyID(authCodePrefix, a.ID), fromStorageAuthCode(a))
}
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) GetAuthCode(ctx context.Context, id string) (a storage.AuthCode, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
var ac AuthCode
err = c.getKey(ctx, keyID(authCodePrefix, id), &ac)
@ -162,8 +163,8 @@ func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
return a, err
}
func (c *conn) DeleteAuthCode(id string) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) DeleteAuthCode(ctx context.Context, id string) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.deleteKey(ctx, keyID(authCodePrefix, id))
}
@ -172,8 +173,8 @@ func (c *conn) CreateRefresh(ctx context.Context, r storage.RefreshToken) error
return c.txnCreate(ctx, keyID(refreshTokenPrefix, r.ID), fromStorageRefreshToken(r))
}
func (c *conn) GetRefresh(id string) (r storage.RefreshToken, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) GetRefresh(ctx context.Context, id string) (r storage.RefreshToken, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
var token RefreshToken
if err = c.getKey(ctx, keyID(refreshTokenPrefix, id), &token); err != nil {
@ -182,8 +183,8 @@ func (c *conn) GetRefresh(id string) (r storage.RefreshToken, err error) {
return toStorageRefreshToken(token), nil
}
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) UpdateRefreshToken(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keyID(refreshTokenPrefix, id), func(currentValue []byte) ([]byte, error) {
var current RefreshToken
@ -200,14 +201,14 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
})
}
func (c *conn) DeleteRefresh(id string) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) DeleteRefresh(ctx context.Context, id string) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.deleteKey(ctx, keyID(refreshTokenPrefix, id))
}
func (c *conn) ListRefreshTokens() (tokens []storage.RefreshToken, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) ListRefreshTokens(ctx context.Context) (tokens []storage.RefreshToken, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
res, err := c.db.Get(ctx, refreshTokenPrefix, clientv3.WithPrefix())
if err != nil {
@ -227,15 +228,15 @@ func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error {
return c.txnCreate(ctx, keyID(clientPrefix, cli.ID), cli)
}
func (c *conn) GetClient(id string) (cli storage.Client, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) GetClient(ctx context.Context, id string) (cli storage.Client, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
err = c.getKey(ctx, keyID(clientPrefix, id), &cli)
return cli, err
}
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keyID(clientPrefix, id), func(currentValue []byte) ([]byte, error) {
var current storage.Client
@ -252,14 +253,14 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage
})
}
func (c *conn) DeleteClient(id string) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) DeleteClient(ctx context.Context, id string) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.deleteKey(ctx, keyID(clientPrefix, id))
}
func (c *conn) ListClients() (clients []storage.Client, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) ListClients(ctx context.Context) (clients []storage.Client, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
res, err := c.db.Get(ctx, clientPrefix, clientv3.WithPrefix())
if err != nil {
@ -279,15 +280,15 @@ func (c *conn) CreatePassword(ctx context.Context, p storage.Password) error {
return c.txnCreate(ctx, passwordPrefix+strings.ToLower(p.Email), p)
}
func (c *conn) GetPassword(email string) (p storage.Password, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) GetPassword(ctx context.Context, email string) (p storage.Password, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
err = c.getKey(ctx, keyEmail(passwordPrefix, email), &p)
return p, err
}
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) UpdatePassword(ctx context.Context, email string, updater func(p storage.Password) (storage.Password, error)) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keyEmail(passwordPrefix, email), func(currentValue []byte) ([]byte, error) {
var current storage.Password
@ -304,14 +305,14 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
})
}
func (c *conn) DeletePassword(email string) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) DeletePassword(ctx context.Context, email string) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.deleteKey(ctx, keyEmail(passwordPrefix, email))
}
func (c *conn) ListPasswords() (passwords []storage.Password, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) ListPasswords(ctx context.Context) (passwords []storage.Password, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
res, err := c.db.Get(ctx, passwordPrefix, clientv3.WithPrefix())
if err != nil {
@ -331,8 +332,8 @@ func (c *conn) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessi
return c.txnCreate(ctx, keySession(s.UserID, s.ConnID), fromStorageOfflineSessions(s))
}
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keySession(userID, connID), func(currentValue []byte) ([]byte, error) {
var current OfflineSessions
@ -349,8 +350,8 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
})
}
func (c *conn) GetOfflineSessions(userID string, connID string) (s storage.OfflineSessions, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) GetOfflineSessions(ctx context.Context, userID string, connID string) (s storage.OfflineSessions, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
var os OfflineSessions
if err = c.getKey(ctx, keySession(userID, connID), &os); err != nil {
@ -359,8 +360,8 @@ func (c *conn) GetOfflineSessions(userID string, connID string) (s storage.Offli
return toStorageOfflineSessions(os), nil
}
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) DeleteOfflineSessions(ctx context.Context, userID string, connID string) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.deleteKey(ctx, keySession(userID, connID))
}
@ -369,15 +370,15 @@ func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector)
return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector)
}
func (c *conn) GetConnector(id string) (conn storage.Connector, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) GetConnector(ctx context.Context, id string) (conn storage.Connector, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
err = c.getKey(ctx, keyID(connectorPrefix, id), &conn)
return conn, err
}
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) UpdateConnector(ctx context.Context, id string, updater func(s storage.Connector) (storage.Connector, error)) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keyID(connectorPrefix, id), func(currentValue []byte) ([]byte, error) {
var current storage.Connector
@ -394,14 +395,14 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
})
}
func (c *conn) DeleteConnector(id string) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) DeleteConnector(ctx context.Context, id string) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.deleteKey(ctx, keyID(connectorPrefix, id))
}
func (c *conn) ListConnectors() (connectors []storage.Connector, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) ListConnectors(ctx context.Context) (connectors []storage.Connector, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
res, err := c.db.Get(ctx, connectorPrefix, clientv3.WithPrefix())
if err != nil {
@ -417,8 +418,8 @@ func (c *conn) ListConnectors() (connectors []storage.Connector, err error) {
return connectors, nil
}
func (c *conn) GetKeys() (keys storage.Keys, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) GetKeys(ctx context.Context) (keys storage.Keys, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
res, err := c.db.Get(ctx, keysName)
if err != nil {
@ -430,8 +431,8 @@ func (c *conn) GetKeys() (keys storage.Keys, err error) {
return keys, err
}
func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keysName, func(currentValue []byte) ([]byte, error) {
var current storage.Keys
@ -560,8 +561,8 @@ func (c *conn) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest)
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
}
func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) GetDeviceRequest(ctx context.Context, userCode string) (r storage.DeviceRequest, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
var dr DeviceRequest
if err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &dr); err == nil {
@ -589,8 +590,8 @@ func (c *conn) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) err
return c.txnCreate(ctx, keyID(deviceTokenPrefix, t.DeviceCode), fromStorageDeviceToken(t))
}
func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) GetDeviceToken(ctx context.Context, deviceCode string) (t storage.DeviceToken, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
var dt DeviceToken
if err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &dt); err == nil {
@ -614,8 +615,8 @@ func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken
return deviceTokens, nil
}
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
func (c *conn) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) {
var current DeviceToken

2
storage/health.go

@ -23,7 +23,7 @@ func NewCustomHealthCheckFunc(s Storage, now func() time.Time) func(context.Cont
return nil, fmt.Errorf("create auth request: %v", err)
}
if err := s.DeleteAuthRequest(a.ID); err != nil {
if err := s.DeleteAuthRequest(ctx, a.ID); err != nil {
return nil, fmt.Errorf("delete auth request: %v", err)
}

68
storage/kubernetes/storage.go

@ -262,7 +262,7 @@ func (cli *client) CreateConnector(ctx context.Context, c storage.Connector) err
return cli.post(resourceConnector, cli.fromStorageConnector(c))
}
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
func (cli *client) GetAuthRequest(ctx context.Context, id string) (storage.AuthRequest, error) {
var req AuthRequest
if err := cli.get(resourceAuthRequest, id, &req); err != nil {
return storage.AuthRequest{}, err
@ -270,7 +270,7 @@ func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
return toStorageAuthRequest(req), nil
}
func (cli *client) GetAuthCode(id string) (storage.AuthCode, error) {
func (cli *client) GetAuthCode(ctx context.Context, id string) (storage.AuthCode, error) {
var code AuthCode
if err := cli.get(resourceAuthCode, id, &code); err != nil {
return storage.AuthCode{}, err
@ -278,7 +278,7 @@ func (cli *client) GetAuthCode(id string) (storage.AuthCode, error) {
return toStorageAuthCode(code), nil
}
func (cli *client) GetClient(id string) (storage.Client, error) {
func (cli *client) GetClient(ctx context.Context, id string) (storage.Client, error) {
c, err := cli.getClient(id)
if err != nil {
return storage.Client{}, err
@ -298,7 +298,7 @@ func (cli *client) getClient(id string) (Client, error) {
return c, nil
}
func (cli *client) GetPassword(email string) (storage.Password, error) {
func (cli *client) GetPassword(ctx context.Context, email string) (storage.Password, error) {
p, err := cli.getPassword(email)
if err != nil {
return storage.Password{}, err
@ -320,7 +320,7 @@ func (cli *client) getPassword(email string) (Password, error) {
return p, nil
}
func (cli *client) GetKeys() (storage.Keys, error) {
func (cli *client) GetKeys(ctx context.Context) (storage.Keys, error) {
var keys Keys
if err := cli.get(resourceKeys, keysName, &keys); err != nil {
return storage.Keys{}, err
@ -328,7 +328,7 @@ func (cli *client) GetKeys() (storage.Keys, error) {
return toStorageKeys(keys), nil
}
func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) {
func (cli *client) GetRefresh(ctx context.Context, id string) (storage.RefreshToken, error) {
r, err := cli.getRefreshToken(id)
if err != nil {
return storage.RefreshToken{}, err
@ -341,7 +341,7 @@ func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) {
return
}
func (cli *client) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
func (cli *client) GetOfflineSessions(ctx context.Context, userID string, connID string) (storage.OfflineSessions, error) {
o, err := cli.getOfflineSessions(userID, connID)
if err != nil {
return storage.OfflineSessions{}, err
@ -360,7 +360,7 @@ func (cli *client) getOfflineSessions(userID string, connID string) (o OfflineSe
return o, nil
}
func (cli *client) GetConnector(id string) (storage.Connector, error) {
func (cli *client) GetConnector(ctx context.Context, id string) (storage.Connector, error) {
var c Connector
if err := cli.get(resourceConnector, id, &c); err != nil {
return storage.Connector{}, err
@ -368,15 +368,15 @@ func (cli *client) GetConnector(id string) (storage.Connector, error) {
return toStorageConnector(c), nil
}
func (cli *client) ListClients() ([]storage.Client, error) {
func (cli *client) ListClients(ctx context.Context) ([]storage.Client, error) {
return nil, errors.New("not implemented")
}
func (cli *client) ListRefreshTokens() ([]storage.RefreshToken, error) {
func (cli *client) ListRefreshTokens(ctx context.Context) ([]storage.RefreshToken, error) {
return nil, errors.New("not implemented")
}
func (cli *client) ListPasswords() (passwords []storage.Password, err error) {
func (cli *client) ListPasswords(ctx context.Context) (passwords []storage.Password, err error) {
var passwordList PasswordList
if err = cli.list(resourcePassword, &passwordList); err != nil {
return passwords, fmt.Errorf("failed to list passwords: %v", err)
@ -395,7 +395,7 @@ func (cli *client) ListPasswords() (passwords []storage.Password, err error) {
return
}
func (cli *client) ListConnectors() (connectors []storage.Connector, err error) {
func (cli *client) ListConnectors(ctx context.Context) (connectors []storage.Connector, err error) {
var connectorList ConnectorList
if err = cli.list(resourceConnector, &connectorList); err != nil {
return connectors, fmt.Errorf("failed to list connectors: %v", err)
@ -409,15 +409,15 @@ func (cli *client) ListConnectors() (connectors []storage.Connector, err error)
return
}
func (cli *client) DeleteAuthRequest(id string) error {
func (cli *client) DeleteAuthRequest(ctx context.Context, id string) error {
return cli.delete(resourceAuthRequest, id)
}
func (cli *client) DeleteAuthCode(code string) error {
func (cli *client) DeleteAuthCode(ctx context.Context, code string) error {
return cli.delete(resourceAuthCode, code)
}
func (cli *client) DeleteClient(id string) error {
func (cli *client) DeleteClient(ctx context.Context, id string) error {
// Check for hash collision.
c, err := cli.getClient(id)
if err != nil {
@ -426,11 +426,11 @@ func (cli *client) DeleteClient(id string) error {
return cli.delete(resourceClient, c.ObjectMeta.Name)
}
func (cli *client) DeleteRefresh(id string) error {
func (cli *client) DeleteRefresh(ctx context.Context, id string) error {
return cli.delete(resourceRefreshToken, id)
}
func (cli *client) DeletePassword(email string) error {
func (cli *client) DeletePassword(ctx context.Context, email string) error {
// Check for hash collision.
p, err := cli.getPassword(email)
if err != nil {
@ -439,7 +439,7 @@ func (cli *client) DeletePassword(email string) error {
return cli.delete(resourcePassword, p.ObjectMeta.Name)
}
func (cli *client) DeleteOfflineSessions(userID string, connID string) error {
func (cli *client) DeleteOfflineSessions(ctx context.Context, userID string, connID string) error {
// Check for hash collision.
o, err := cli.getOfflineSessions(userID, connID)
if err != nil {
@ -448,11 +448,11 @@ func (cli *client) DeleteOfflineSessions(userID string, connID string) error {
return cli.delete(resourceOfflineSessions, o.ObjectMeta.Name)
}
func (cli *client) DeleteConnector(id string) error {
func (cli *client) DeleteConnector(ctx context.Context, id string) error {
return cli.delete(resourceConnector, id)
}
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
func (cli *client) UpdateRefreshToken(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
lock := newRefreshTokenLock(cli)
if err := lock.Lock(id); err != nil {
@ -460,7 +460,7 @@ func (cli *client) UpdateRefreshToken(id string, updater func(old storage.Refres
}
defer lock.Unlock(id)
return retryOnConflict(context.TODO(), func() error {
return retryOnConflict(ctx, func() error {
r, err := cli.getRefreshToken(id)
if err != nil {
return err
@ -479,7 +479,7 @@ func (cli *client) UpdateRefreshToken(id string, updater func(old storage.Refres
})
}
func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
func (cli *client) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error {
c, err := cli.getClient(id)
if err != nil {
return err
@ -496,7 +496,7 @@ func (cli *client) UpdateClient(id string, updater func(old storage.Client) (sto
return cli.put(resourceClient, c.ObjectMeta.Name, newClient)
}
func (cli *client) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error {
func (cli *client) UpdatePassword(ctx context.Context, email string, updater func(old storage.Password) (storage.Password, error)) error {
p, err := cli.getPassword(email)
if err != nil {
return err
@ -513,8 +513,8 @@ func (cli *client) UpdatePassword(email string, updater func(old storage.Passwor
return cli.put(resourcePassword, p.ObjectMeta.Name, newPassword)
}
func (cli *client) UpdateOfflineSessions(userID string, connID string, updater func(old storage.OfflineSessions) (storage.OfflineSessions, error)) error {
return retryOnConflict(context.TODO(), func() error {
func (cli *client) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(old storage.OfflineSessions) (storage.OfflineSessions, error)) error {
return retryOnConflict(ctx, func() error {
o, err := cli.getOfflineSessions(userID, connID)
if err != nil {
return err
@ -531,7 +531,7 @@ func (cli *client) UpdateOfflineSessions(userID string, connID string, updater f
})
}
func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
func (cli *client) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) error {
firstUpdate := false
var keys Keys
if err := cli.get(resourceKeys, keysName, &keys); err != nil {
@ -576,7 +576,7 @@ func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro
return err
}
func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
func (cli *client) UpdateAuthRequest(ctx context.Context, id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
var req AuthRequest
err := cli.get(resourceAuthRequest, id, &req)
if err != nil {
@ -593,8 +593,8 @@ func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthReque
return cli.put(resourceAuthRequest, id, newReq)
}
func (cli *client) UpdateConnector(id string, updater func(a storage.Connector) (storage.Connector, error)) error {
return retryOnConflict(context.TODO(), func() error {
func (cli *client) UpdateConnector(ctx context.Context, id string, updater func(a storage.Connector) (storage.Connector, error)) error {
return retryOnConflict(ctx, func() error {
var c Connector
err := cli.get(resourceConnector, id, &c)
if err != nil {
@ -612,7 +612,7 @@ func (cli *client) UpdateConnector(id string, updater func(a storage.Connector)
})
}
func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
func (cli *client) GarbageCollect(ctx context.Context, now time.Time) (result storage.GCResult, err error) {
var authRequests AuthRequestList
if err := cli.listN(resourceAuthRequest, &authRequests, gcResultLimit); err != nil {
return result, fmt.Errorf("failed to list auth requests: %v", err)
@ -687,7 +687,7 @@ func (cli *client) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequ
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
}
func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
func (cli *client) GetDeviceRequest(ctx context.Context, userCode string) (storage.DeviceRequest, error) {
var req DeviceRequest
if err := cli.get(resourceDeviceRequest, strings.ToLower(userCode), &req); err != nil {
return storage.DeviceRequest{}, err
@ -699,7 +699,7 @@ func (cli *client) CreateDeviceToken(ctx context.Context, t storage.DeviceToken)
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
}
func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
func (cli *client) GetDeviceToken(ctx context.Context, deviceCode string) (storage.DeviceToken, error) {
var token DeviceToken
if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil {
return storage.DeviceToken{}, err
@ -712,8 +712,8 @@ func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error)
return
}
func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
return retryOnConflict(context.TODO(), func() error {
func (cli *client) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
return retryOnConflict(ctx, func() error {
r, err := cli.getDeviceToken(deviceCode)
if err != nil {
return err

14
storage/kubernetes/storage_test.go

@ -221,7 +221,7 @@ func TestUpdateKeys(t *testing.T) {
for _, test := range tests {
client := newStatusCodesResponseTestClient(test.getResponseCode, test.actionResponseCode)
err := client.UpdateKeys(test.updater)
err := client.UpdateKeys(context.TODO(), test.updater)
if err != nil {
if !test.wantErr {
t.Fatalf("Test %q: %v", test.name, err)
@ -339,9 +339,9 @@ func TestRefreshTokenLock(t *testing.T) {
require.NoError(t, err)
t.Run("Timeout lock error", func(t *testing.T) {
err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
err = kubeClient.UpdateRefreshToken(ctx, r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
r.Token = "update-result-1"
err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
err := kubeClient.UpdateRefreshToken(ctx, r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
r.Token = "timeout-err"
return r, nil
})
@ -350,7 +350,7 @@ func TestRefreshTokenLock(t *testing.T) {
})
require.NoError(t, err)
token, err := kubeClient.GetRefresh(r.ID)
token, err := kubeClient.GetRefresh(context.TODO(), r.ID)
require.NoError(t, err)
require.Equal(t, "update-result-1", token.Token)
})
@ -359,13 +359,13 @@ func TestRefreshTokenLock(t *testing.T) {
var lockBroken bool
lockTimeout = -time.Hour
err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
err = kubeClient.UpdateRefreshToken(ctx, r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
r.Token = "update-result-2"
if lockBroken {
return r, nil
}
err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
err := kubeClient.UpdateRefreshToken(ctx, r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
r.Token = "should-break-the-lock-and-finish-updating"
return r, nil
})
@ -376,7 +376,7 @@ func TestRefreshTokenLock(t *testing.T) {
})
require.NoError(t, err)
token, err := kubeClient.GetRefresh(r.ID)
token, err := kubeClient.GetRefresh(context.TODO(), r.ID)
require.NoError(t, err)
// Because concurrent update breaks the lock, the final result will be the value of the first update
require.Equal(t, "update-result-2", token.Token)

60
storage/memory/memory.go

@ -71,7 +71,7 @@ func (s *memStorage) tx(f func()) {
func (s *memStorage) Close() error { return nil }
func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
func (s *memStorage) GarbageCollect(ctx context.Context, now time.Time) (result storage.GCResult, err error) {
s.tx(func() {
for id, a := range s.authCodes {
if now.After(a.Expiry) {
@ -183,7 +183,7 @@ func (s *memStorage) CreateConnector(ctx context.Context, connector storage.Conn
return
}
func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
func (s *memStorage) GetAuthCode(ctx context.Context, id string) (c storage.AuthCode, err error) {
s.tx(func() {
var ok bool
if c, ok = s.authCodes[id]; !ok {
@ -194,7 +194,7 @@ func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
return
}
func (s *memStorage) GetPassword(email string) (p storage.Password, err error) {
func (s *memStorage) GetPassword(ctx context.Context, email string) (p storage.Password, err error) {
email = strings.ToLower(email)
s.tx(func() {
var ok bool
@ -205,7 +205,7 @@ func (s *memStorage) GetPassword(email string) (p storage.Password, err error) {
return
}
func (s *memStorage) GetClient(id string) (client storage.Client, err error) {
func (s *memStorage) GetClient(ctx context.Context, id string) (client storage.Client, err error) {
s.tx(func() {
var ok bool
if client, ok = s.clients[id]; !ok {
@ -215,12 +215,12 @@ func (s *memStorage) GetClient(id string) (client storage.Client, err error) {
return
}
func (s *memStorage) GetKeys() (keys storage.Keys, err error) {
func (s *memStorage) GetKeys(ctx context.Context) (keys storage.Keys, err error) {
s.tx(func() { keys = s.keys })
return
}
func (s *memStorage) GetRefresh(id string) (tok storage.RefreshToken, err error) {
func (s *memStorage) GetRefresh(ctx context.Context, id string) (tok storage.RefreshToken, err error) {
s.tx(func() {
var ok bool
if tok, ok = s.refreshTokens[id]; !ok {
@ -231,7 +231,7 @@ func (s *memStorage) GetRefresh(id string) (tok storage.RefreshToken, err error)
return
}
func (s *memStorage) GetAuthRequest(id string) (req storage.AuthRequest, err error) {
func (s *memStorage) GetAuthRequest(ctx context.Context, id string) (req storage.AuthRequest, err error) {
s.tx(func() {
var ok bool
if req, ok = s.authReqs[id]; !ok {
@ -242,7 +242,7 @@ func (s *memStorage) GetAuthRequest(id string) (req storage.AuthRequest, err err
return
}
func (s *memStorage) GetOfflineSessions(userID string, connID string) (o storage.OfflineSessions, err error) {
func (s *memStorage) GetOfflineSessions(ctx context.Context, userID string, connID string) (o storage.OfflineSessions, err error) {
id := offlineSessionID{
userID: userID,
connID: connID,
@ -257,7 +257,7 @@ func (s *memStorage) GetOfflineSessions(userID string, connID string) (o storage
return
}
func (s *memStorage) GetConnector(id string) (connector storage.Connector, err error) {
func (s *memStorage) GetConnector(ctx context.Context, id string) (connector storage.Connector, err error) {
s.tx(func() {
var ok bool
if connector, ok = s.connectors[id]; !ok {
@ -267,7 +267,7 @@ func (s *memStorage) GetConnector(id string) (connector storage.Connector, err e
return
}
func (s *memStorage) ListClients() (clients []storage.Client, err error) {
func (s *memStorage) ListClients(ctx context.Context) (clients []storage.Client, err error) {
s.tx(func() {
for _, client := range s.clients {
clients = append(clients, client)
@ -276,7 +276,7 @@ func (s *memStorage) ListClients() (clients []storage.Client, err error) {
return
}
func (s *memStorage) ListRefreshTokens() (tokens []storage.RefreshToken, err error) {
func (s *memStorage) ListRefreshTokens(ctx context.Context) (tokens []storage.RefreshToken, err error) {
s.tx(func() {
for _, refresh := range s.refreshTokens {
tokens = append(tokens, refresh)
@ -285,7 +285,7 @@ func (s *memStorage) ListRefreshTokens() (tokens []storage.RefreshToken, err err
return
}
func (s *memStorage) ListPasswords() (passwords []storage.Password, err error) {
func (s *memStorage) ListPasswords(ctx context.Context) (passwords []storage.Password, err error) {
s.tx(func() {
for _, password := range s.passwords {
passwords = append(passwords, password)
@ -294,7 +294,7 @@ func (s *memStorage) ListPasswords() (passwords []storage.Password, err error) {
return
}
func (s *memStorage) ListConnectors() (conns []storage.Connector, err error) {
func (s *memStorage) ListConnectors(ctx context.Context) (conns []storage.Connector, err error) {
s.tx(func() {
for _, c := range s.connectors {
conns = append(conns, c)
@ -303,7 +303,7 @@ func (s *memStorage) ListConnectors() (conns []storage.Connector, err error) {
return
}
func (s *memStorage) DeletePassword(email string) (err error) {
func (s *memStorage) DeletePassword(ctx context.Context, email string) (err error) {
email = strings.ToLower(email)
s.tx(func() {
if _, ok := s.passwords[email]; !ok {
@ -315,7 +315,7 @@ func (s *memStorage) DeletePassword(email string) (err error) {
return
}
func (s *memStorage) DeleteClient(id string) (err error) {
func (s *memStorage) DeleteClient(ctx context.Context, id string) (err error) {
s.tx(func() {
if _, ok := s.clients[id]; !ok {
err = storage.ErrNotFound
@ -326,7 +326,7 @@ func (s *memStorage) DeleteClient(id string) (err error) {
return
}
func (s *memStorage) DeleteRefresh(id string) (err error) {
func (s *memStorage) DeleteRefresh(ctx context.Context, id string) (err error) {
s.tx(func() {
if _, ok := s.refreshTokens[id]; !ok {
err = storage.ErrNotFound
@ -337,7 +337,7 @@ func (s *memStorage) DeleteRefresh(id string) (err error) {
return
}
func (s *memStorage) DeleteAuthCode(id string) (err error) {
func (s *memStorage) DeleteAuthCode(ctx context.Context, id string) (err error) {
s.tx(func() {
if _, ok := s.authCodes[id]; !ok {
err = storage.ErrNotFound
@ -348,7 +348,7 @@ func (s *memStorage) DeleteAuthCode(id string) (err error) {
return
}
func (s *memStorage) DeleteAuthRequest(id string) (err error) {
func (s *memStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) {
s.tx(func() {
if _, ok := s.authReqs[id]; !ok {
err = storage.ErrNotFound
@ -359,7 +359,7 @@ func (s *memStorage) DeleteAuthRequest(id string) (err error) {
return
}
func (s *memStorage) DeleteOfflineSessions(userID string, connID string) (err error) {
func (s *memStorage) DeleteOfflineSessions(ctx context.Context, userID string, connID string) (err error) {
id := offlineSessionID{
userID: userID,
connID: connID,
@ -374,7 +374,7 @@ func (s *memStorage) DeleteOfflineSessions(userID string, connID string) (err er
return
}
func (s *memStorage) DeleteConnector(id string) (err error) {
func (s *memStorage) DeleteConnector(ctx context.Context, id string) (err error) {
s.tx(func() {
if _, ok := s.connectors[id]; !ok {
err = storage.ErrNotFound
@ -385,7 +385,7 @@ func (s *memStorage) DeleteConnector(id string) (err error) {
return
}
func (s *memStorage) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) (err error) {
func (s *memStorage) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) (err error) {
s.tx(func() {
client, ok := s.clients[id]
if !ok {
@ -399,7 +399,7 @@ func (s *memStorage) UpdateClient(id string, updater func(old storage.Client) (s
return
}
func (s *memStorage) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) (err error) {
func (s *memStorage) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) (err error) {
s.tx(func() {
var keys storage.Keys
if keys, err = updater(s.keys); err == nil {
@ -409,7 +409,7 @@ func (s *memStorage) UpdateKeys(updater func(old storage.Keys) (storage.Keys, er
return
}
func (s *memStorage) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) (err error) {
func (s *memStorage) UpdateAuthRequest(ctx context.Context, id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) (err error) {
s.tx(func() {
req, ok := s.authReqs[id]
if !ok {
@ -423,7 +423,7 @@ func (s *memStorage) UpdateAuthRequest(id string, updater func(old storage.AuthR
return
}
func (s *memStorage) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) (err error) {
func (s *memStorage) UpdatePassword(ctx context.Context, email string, updater func(p storage.Password) (storage.Password, error)) (err error) {
email = strings.ToLower(email)
s.tx(func() {
req, ok := s.passwords[email]
@ -438,7 +438,7 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor
return
}
func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) {
func (s *memStorage) UpdateRefreshToken(ctx context.Context, id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) {
s.tx(func() {
r, ok := s.refreshTokens[id]
if !ok {
@ -452,7 +452,7 @@ func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.Refres
return
}
func (s *memStorage) UpdateOfflineSessions(userID string, connID string, updater func(o storage.OfflineSessions) (storage.OfflineSessions, error)) (err error) {
func (s *memStorage) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(o storage.OfflineSessions) (storage.OfflineSessions, error)) (err error) {
id := offlineSessionID{
userID: userID,
connID: connID,
@ -470,7 +470,7 @@ func (s *memStorage) UpdateOfflineSessions(userID string, connID string, updater
return
}
func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector) (storage.Connector, error)) (err error) {
func (s *memStorage) UpdateConnector(ctx context.Context, id string, updater func(c storage.Connector) (storage.Connector, error)) (err error) {
s.tx(func() {
r, ok := s.connectors[id]
if !ok {
@ -495,7 +495,7 @@ func (s *memStorage) CreateDeviceRequest(ctx context.Context, d storage.DeviceRe
return
}
func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceRequest, err error) {
func (s *memStorage) GetDeviceRequest(ctx context.Context, userCode string) (req storage.DeviceRequest, err error) {
s.tx(func() {
var ok bool
if req, ok = s.deviceRequests[userCode]; !ok {
@ -517,7 +517,7 @@ func (s *memStorage) CreateDeviceToken(ctx context.Context, t storage.DeviceToke
return
}
func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
func (s *memStorage) GetDeviceToken(ctx context.Context, deviceCode string) (t storage.DeviceToken, err error) {
s.tx(func() {
var ok bool
if t, ok = s.deviceTokens[deviceCode]; !ok {
@ -528,7 +528,7 @@ func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, e
return
}
func (s *memStorage) UpdateDeviceToken(deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) {
func (s *memStorage) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) {
s.tx(func() {
r, ok := s.deviceTokens[deviceCode]
if !ok {

34
storage/memory/static_test.go

@ -31,14 +31,14 @@ func TestStaticClients(t *testing.T) {
{
name: "get client from static storage",
action: func() error {
_, err := s.GetClient(c2.ID)
_, err := s.GetClient(ctx, c2.ID)
return err
},
},
{
name: "get client from backing storage",
action: func() error {
_, err := s.GetClient(c1.ID)
_, err := s.GetClient(ctx, c1.ID)
return err
},
},
@ -49,7 +49,7 @@ func TestStaticClients(t *testing.T) {
c.Secret = "new_" + c.Secret
return c, nil
}
return s.UpdateClient(c2.ID, updater)
return s.UpdateClient(ctx, c2.ID, updater)
},
wantErr: true,
},
@ -60,13 +60,13 @@ func TestStaticClients(t *testing.T) {
c.Secret = "new_" + c.Secret
return c, nil
}
return s.UpdateClient(c1.ID, updater)
return s.UpdateClient(ctx, c1.ID, updater)
},
},
{
name: "list clients",
action: func() error {
clients, err := s.ListClients()
clients, err := s.ListClients(ctx)
if err != nil {
return err
}
@ -116,21 +116,21 @@ func TestStaticPasswords(t *testing.T) {
{
name: "get password from static storage",
action: func() error {
_, err := s.GetPassword(p2.Email)
_, err := s.GetPassword(ctx, p2.Email)
return err
},
},
{
name: "get password from backing storage",
action: func() error {
_, err := s.GetPassword(p1.Email)
_, err := s.GetPassword(ctx, p1.Email)
return err
},
},
{
name: "get password from static storage with casing",
action: func() error {
_, err := s.GetPassword(strings.ToUpper(p2.Email))
_, err := s.GetPassword(ctx, strings.ToUpper(p2.Email))
return err
},
},
@ -141,7 +141,7 @@ func TestStaticPasswords(t *testing.T) {
p.Username = "new_" + p.Username
return p, nil
}
return s.UpdatePassword(p2.Email, updater)
return s.UpdatePassword(ctx, p2.Email, updater)
},
wantErr: true,
},
@ -152,7 +152,7 @@ func TestStaticPasswords(t *testing.T) {
p.Username = "new_" + p.Username
return p, nil
}
return s.UpdatePassword(p1.Email, updater)
return s.UpdatePassword(ctx, p1.Email, updater)
},
},
{
@ -168,7 +168,7 @@ func TestStaticPasswords(t *testing.T) {
{
name: "get password",
action: func() error {
p, err := s.GetPassword(p4.Email)
p, err := s.GetPassword(ctx, p4.Email)
if err != nil {
return err
}
@ -181,7 +181,7 @@ func TestStaticPasswords(t *testing.T) {
{
name: "list passwords",
action: func() error {
passwords, err := s.ListPasswords()
passwords, err := s.ListPasswords(ctx)
if err != nil {
return err
}
@ -228,14 +228,14 @@ func TestStaticConnectors(t *testing.T) {
{
name: "get connector from static storage",
action: func() error {
_, err := s.GetConnector(c2.ID)
_, err := s.GetConnector(ctx, c2.ID)
return err
},
},
{
name: "get connector from backing storage",
action: func() error {
_, err := s.GetConnector(c1.ID)
_, err := s.GetConnector(ctx, c1.ID)
return err
},
},
@ -246,7 +246,7 @@ func TestStaticConnectors(t *testing.T) {
c.Name = "New"
return c, nil
}
return s.UpdateConnector(c2.ID, updater)
return s.UpdateConnector(ctx, c2.ID, updater)
},
wantErr: true,
},
@ -257,13 +257,13 @@ func TestStaticConnectors(t *testing.T) {
c.Name = "New"
return c, nil
}
return s.UpdateConnector(c1.ID, updater)
return s.UpdateConnector(ctx, c1.ID, updater)
},
},
{
name: "list connectors",
action: func() error {
connectors, err := s.ListConnectors()
connectors, err := s.ListConnectors(ctx)
if err != nil {
return err
}

127
storage/sql/crud.go

@ -86,7 +86,7 @@ type scanner interface {
var _ storage.Storage = (*conn)(nil)
func (c *conn) GarbageCollect(now time.Time) (storage.GCResult, error) {
func (c *conn) GarbageCollect(ctc context.Context, now time.Time) (storage.GCResult, error) {
result := storage.GCResult{}
r, err := c.Exec(`delete from auth_request where expiry < $1`, now)
@ -158,9 +158,9 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err
return nil
}
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
func (c *conn) UpdateAuthRequest(ctx context.Context, id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getAuthRequest(tx, id)
r, err := getAuthRequest(ctx, tx, id)
if err != nil {
return err
}
@ -200,11 +200,11 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
})
}
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
return getAuthRequest(c, id)
func (c *conn) GetAuthRequest(ctx context.Context, id string) (storage.AuthRequest, error) {
return getAuthRequest(ctx, c, id)
}
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
func getAuthRequest(ctx context.Context, q querier, id string) (a storage.AuthRequest, err error) {
err = q.QueryRow(`
select
id, client_id, response_types, scopes, redirect_uri, nonce, state,
@ -258,7 +258,7 @@ func (c *conn) CreateAuthCode(ctx context.Context, a storage.AuthCode) error {
return nil
}
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
func (c *conn) GetAuthCode(ctx context.Context, id string) (a storage.AuthCode, err error) {
err = c.QueryRow(`
select
id, client_id, scopes, nonce, redirect_uri,
@ -310,9 +310,9 @@ func (c *conn) CreateRefresh(ctx context.Context, r storage.RefreshToken) error
return nil
}
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
func (c *conn) UpdateRefreshToken(ctx context.Context, id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getRefresh(tx, id)
r, err := getRefresh(ctx, tx, id)
if err != nil {
return err
}
@ -354,11 +354,11 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
})
}
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
return getRefresh(c, id)
func (c *conn) GetRefresh(ctx context.Context, id string) (storage.RefreshToken, error) {
return getRefresh(ctx, c, id)
}
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
func getRefresh(ctx context.Context, q querier, id string) (storage.RefreshToken, error) {
return scanRefresh(q.QueryRow(`
select
id, client_id, scopes, nonce,
@ -371,7 +371,7 @@ func getRefresh(q querier, id string) (storage.RefreshToken, error) {
`, id))
}
func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
func (c *conn) ListRefreshTokens(ctx context.Context) ([]storage.RefreshToken, error) {
rows, err := c.Query(`
select
id, client_id, scopes, nonce,
@ -418,12 +418,12 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
return r, nil
}
func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
func (c *conn) UpdateKeys(ctx context.Context, updater func(old storage.Keys) (storage.Keys, error)) error {
return c.ExecTx(func(tx *trans) error {
firstUpdate := false
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
// server. Test this, and consider adding a COUNT() command beforehand.
old, err := getKeys(tx)
old, err := getKeys(ctx, tx)
if err != nil {
if err != storage.ErrNotFound {
return fmt.Errorf("get keys: %v", err)
@ -471,11 +471,11 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
})
}
func (c *conn) GetKeys() (keys storage.Keys, err error) {
return getKeys(c)
func (c *conn) GetKeys(ctx context.Context) (keys storage.Keys, err error) {
return getKeys(ctx, c)
}
func getKeys(q querier) (keys storage.Keys, err error) {
func getKeys(ctx context.Context, q querier) (keys storage.Keys, err error) {
err = q.QueryRow(`
select
verification_keys, signing_key, signing_key_pub, next_rotation
@ -494,9 +494,9 @@ func getKeys(q querier) (keys storage.Keys, err error) {
return keys, nil
}
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
func (c *conn) UpdateClient(ctx context.Context, id string, updater func(old storage.Client) (storage.Client, error)) error {
return c.ExecTx(func(tx *trans) error {
cli, err := getClient(tx, id)
cli, err := getClient(ctx, tx, id)
if err != nil {
return err
}
@ -543,7 +543,7 @@ func (c *conn) CreateClient(ctx context.Context, cli storage.Client) error {
return nil
}
func getClient(q querier, id string) (storage.Client, error) {
func getClient(ctx context.Context, q querier, id string) (storage.Client, error) {
return scanClient(q.QueryRow(`
select
id, secret, redirect_uris, trusted_peers, public, name, logo_url
@ -551,11 +551,11 @@ func getClient(q querier, id string) (storage.Client, error) {
`, id))
}
func (c *conn) GetClient(id string) (storage.Client, error) {
return getClient(c, id)
func (c *conn) GetClient(ctx context.Context, id string) (storage.Client, error) {
return getClient(ctx, c, id)
}
func (c *conn) ListClients() ([]storage.Client, error) {
func (c *conn) ListClients(ctx context.Context) ([]storage.Client, error) {
rows, err := c.Query(`
select
id, secret, redirect_uris, trusted_peers, public, name, logo_url
@ -615,9 +615,9 @@ func (c *conn) CreatePassword(ctx context.Context, p storage.Password) error {
return nil
}
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
func (c *conn) UpdatePassword(ctx context.Context, email string, updater func(p storage.Password) (storage.Password, error)) error {
return c.ExecTx(func(tx *trans) error {
p, err := getPassword(tx, email)
p, err := getPassword(ctx, tx, email)
if err != nil {
return err
}
@ -641,11 +641,11 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
})
}
func (c *conn) GetPassword(email string) (storage.Password, error) {
return getPassword(c, email)
func (c *conn) GetPassword(ctx context.Context, email string) (storage.Password, error) {
return getPassword(ctx, c, email)
}
func getPassword(q querier, email string) (p storage.Password, err error) {
func getPassword(ctx context.Context, q querier, email string) (p storage.Password, err error) {
return scanPassword(q.QueryRow(`
select
email, hash, username, user_id
@ -653,7 +653,7 @@ func getPassword(q querier, email string) (p storage.Password, err error) {
`, strings.ToLower(email)))
}
func (c *conn) ListPasswords() ([]storage.Password, error) {
func (c *conn) ListPasswords(ctx context.Context) ([]storage.Password, error) {
rows, err := c.Query(`
select
email, hash, username, user_id
@ -711,9 +711,9 @@ func (c *conn) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessi
return nil
}
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
func (c *conn) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
return c.ExecTx(func(tx *trans) error {
s, err := getOfflineSessions(tx, userID, connID)
s, err := getOfflineSessions(ctx, tx, userID, connID)
if err != nil {
return err
}
@ -738,11 +738,11 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
})
}
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
return getOfflineSessions(c, userID, connID)
func (c *conn) GetOfflineSessions(ctx context.Context, userID string, connID string) (storage.OfflineSessions, error) {
return getOfflineSessions(ctx, c, userID, connID)
}
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
func getOfflineSessions(ctx context.Context, q querier, userID string, connID string) (storage.OfflineSessions, error) {
return scanOfflineSessions(q.QueryRow(`
select
user_id, conn_id, refresh, connector_data
@ -784,9 +784,9 @@ func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector)
return nil
}
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
func (c *conn) UpdateConnector(ctx context.Context, id string, updater func(s storage.Connector) (storage.Connector, error)) error {
return c.ExecTx(func(tx *trans) error {
connector, err := getConnector(tx, id)
connector, err := getConnector(ctx, tx, id)
if err != nil {
return err
}
@ -813,11 +813,11 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
})
}
func (c *conn) GetConnector(id string) (storage.Connector, error) {
return getConnector(c, id)
func (c *conn) GetConnector(ctx context.Context, id string) (storage.Connector, error) {
return getConnector(ctx, c, id)
}
func getConnector(q querier, id string) (storage.Connector, error) {
func getConnector(ctx context.Context, q querier, id string) (storage.Connector, error) {
return scanConnector(q.QueryRow(`
select
id, type, name, resource_version, config
@ -839,7 +839,7 @@ func scanConnector(s scanner) (c storage.Connector, err error) {
return c, nil
}
func (c *conn) ListConnectors() ([]storage.Connector, error) {
func (c *conn) ListConnectors(ctx context.Context) ([]storage.Connector, error) {
rows, err := c.Query(`
select
id, type, name, resource_version, config
@ -864,16 +864,31 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) {
return connectors, nil
}
func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) }
func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) }
func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) }
func (c *conn) DeleteRefresh(id string) error { return c.delete("refresh_token", "id", id) }
func (c *conn) DeletePassword(email string) error {
func (c *conn) DeleteAuthRequest(ctx context.Context, id string) error {
return c.delete("auth_request", "id", id)
}
func (c *conn) DeleteAuthCode(ctx context.Context, id string) error {
return c.delete("auth_code", "id", id)
}
func (c *conn) DeleteClient(ctx context.Context, id string) error {
return c.delete("client", "id", id)
}
func (c *conn) DeleteRefresh(ctx context.Context, id string) error {
return c.delete("refresh_token", "id", id)
}
func (c *conn) DeletePassword(ctx context.Context, email string) error {
return c.delete("password", "email", strings.ToLower(email))
}
func (c *conn) DeleteConnector(id string) error { return c.delete("connector", "id", id) }
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
func (c *conn) DeleteConnector(ctx context.Context, id string) error {
return c.delete("connector", "id", id)
}
func (c *conn) DeleteOfflineSessions(ctx context.Context, userID string, connID string) error {
result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID)
if err != nil {
return fmt.Errorf("delete offline_session: user_id = %s, conn_id = %s", userID, connID)
@ -948,11 +963,11 @@ func (c *conn) CreateDeviceToken(ctx context.Context, t storage.DeviceToken) err
return nil
}
func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
return getDeviceRequest(c, userCode)
func (c *conn) GetDeviceRequest(ctx context.Context, userCode string) (storage.DeviceRequest, error) {
return getDeviceRequest(ctx, c, userCode)
}
func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) {
func getDeviceRequest(ctx context.Context, q querier, userCode string) (d storage.DeviceRequest, err error) {
err = q.QueryRow(`
select
device_code, client_id, client_secret, scopes, expiry
@ -970,11 +985,11 @@ func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err
return d, nil
}
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
return getDeviceToken(c, deviceCode)
func (c *conn) GetDeviceToken(ctx context.Context, deviceCode string) (storage.DeviceToken, error) {
return getDeviceToken(ctx, c, deviceCode)
}
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
func getDeviceToken(ctx context.Context, q querier, deviceCode string) (a storage.DeviceToken, err error) {
err = q.QueryRow(`
select
status, token, expiry, last_request, poll_interval, code_challenge, code_challenge_method
@ -992,9 +1007,9 @@ func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err er
return a, nil
}
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
func (c *conn) UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getDeviceToken(tx, deviceCode)
r, err := getDeviceToken(ctx, tx, deviceCode)
if err != nil {
return err
}

48
storage/static.go

@ -31,11 +31,11 @@ func WithStaticClients(s Storage, staticClients []Client) Storage {
return staticClientsStorage{s, staticClients, clientsByID}
}
func (s staticClientsStorage) GetClient(id string) (Client, error) {
func (s staticClientsStorage) GetClient(ctx context.Context, id string) (Client, error) {
if client, ok := s.clientsByID[id]; ok {
return client, nil
}
return s.Storage.GetClient(id)
return s.Storage.GetClient(ctx, id)
}
func (s staticClientsStorage) isStatic(id string) bool {
@ -43,8 +43,8 @@ func (s staticClientsStorage) isStatic(id string) bool {
return ok
}
func (s staticClientsStorage) ListClients() ([]Client, error) {
clients, err := s.Storage.ListClients()
func (s staticClientsStorage) ListClients(ctx context.Context) ([]Client, error) {
clients, err := s.Storage.ListClients(ctx)
if err != nil {
return nil, err
}
@ -67,18 +67,18 @@ func (s staticClientsStorage) CreateClient(ctx context.Context, c Client) error
return s.Storage.CreateClient(ctx, c)
}
func (s staticClientsStorage) DeleteClient(id string) error {
func (s staticClientsStorage) DeleteClient(ctx context.Context, id string) error {
if s.isStatic(id) {
return errors.New("static clients: read-only cannot delete client")
}
return s.Storage.DeleteClient(id)
return s.Storage.DeleteClient(ctx, id)
}
func (s staticClientsStorage) UpdateClient(id string, updater func(old Client) (Client, error)) error {
func (s staticClientsStorage) UpdateClient(ctx context.Context, id string, updater func(old Client) (Client, error)) error {
if s.isStatic(id) {
return errors.New("static clients: read-only cannot update client")
}
return s.Storage.UpdateClient(id, updater)
return s.Storage.UpdateClient(ctx, id, updater)
}
type staticPasswordsStorage struct {
@ -112,18 +112,18 @@ func (s staticPasswordsStorage) isStatic(email string) bool {
return ok
}
func (s staticPasswordsStorage) GetPassword(email string) (Password, error) {
func (s staticPasswordsStorage) GetPassword(ctx context.Context, email string) (Password, error) {
// TODO(ericchiang): BLAH. We really need to figure out how to handle
// lower cased emails better.
email = strings.ToLower(email)
if password, ok := s.passwordsByEmail[email]; ok {
return password, nil
}
return s.Storage.GetPassword(email)
return s.Storage.GetPassword(ctx, email)
}
func (s staticPasswordsStorage) ListPasswords() ([]Password, error) {
passwords, err := s.Storage.ListPasswords()
func (s staticPasswordsStorage) ListPasswords(ctx context.Context) ([]Password, error) {
passwords, err := s.Storage.ListPasswords(ctx)
if err != nil {
return nil, err
}
@ -147,18 +147,18 @@ func (s staticPasswordsStorage) CreatePassword(ctx context.Context, p Password)
return s.Storage.CreatePassword(ctx, p)
}
func (s staticPasswordsStorage) DeletePassword(email string) error {
func (s staticPasswordsStorage) DeletePassword(ctx context.Context, email string) error {
if s.isStatic(email) {
return errors.New("static passwords: read-only cannot delete password")
}
return s.Storage.DeletePassword(email)
return s.Storage.DeletePassword(ctx, email)
}
func (s staticPasswordsStorage) UpdatePassword(email string, updater func(old Password) (Password, error)) error {
func (s staticPasswordsStorage) UpdatePassword(ctx context.Context, email string, updater func(old Password) (Password, error)) error {
if s.isStatic(email) {
return errors.New("static passwords: read-only cannot update password")
}
return s.Storage.UpdatePassword(email, updater)
return s.Storage.UpdatePassword(ctx, email, updater)
}
// staticConnectorsStorage represents a storage with read-only set of connectors.
@ -185,15 +185,15 @@ func (s staticConnectorsStorage) isStatic(id string) bool {
return ok
}
func (s staticConnectorsStorage) GetConnector(id string) (Connector, error) {
func (s staticConnectorsStorage) GetConnector(ctx context.Context, id string) (Connector, error) {
if connector, ok := s.connectorsByID[id]; ok {
return connector, nil
}
return s.Storage.GetConnector(id)
return s.Storage.GetConnector(ctx, id)
}
func (s staticConnectorsStorage) ListConnectors() ([]Connector, error) {
connectors, err := s.Storage.ListConnectors()
func (s staticConnectorsStorage) ListConnectors(ctx context.Context) ([]Connector, error) {
connectors, err := s.Storage.ListConnectors(ctx)
if err != nil {
return nil, err
}
@ -217,16 +217,16 @@ func (s staticConnectorsStorage) CreateConnector(ctx context.Context, c Connecto
return s.Storage.CreateConnector(ctx, c)
}
func (s staticConnectorsStorage) DeleteConnector(id string) error {
func (s staticConnectorsStorage) DeleteConnector(ctx context.Context, id string) error {
if s.isStatic(id) {
return errors.New("static connectors: read-only cannot delete connector")
}
return s.Storage.DeleteConnector(id)
return s.Storage.DeleteConnector(ctx, id)
}
func (s staticConnectorsStorage) UpdateConnector(id string, updater func(old Connector) (Connector, error)) error {
func (s staticConnectorsStorage) UpdateConnector(ctx context.Context, id string, updater func(old Connector) (Connector, error)) error {
if s.isStatic(id) {
return errors.New("static connectors: read-only cannot update connector")
}
return s.Storage.UpdateConnector(id, updater)
return s.Storage.UpdateConnector(ctx, id, updater)
}

62
storage/storage.go

@ -89,30 +89,30 @@ type Storage interface {
// TODO(ericchiang): return (T, bool, error) so we can indicate not found
// requests that way instead of using ErrNotFound.
GetAuthRequest(id string) (AuthRequest, error)
GetAuthCode(id string) (AuthCode, error)
GetClient(id string) (Client, error)
GetKeys() (Keys, error)
GetRefresh(id string) (RefreshToken, error)
GetPassword(email string) (Password, error)
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
GetConnector(id string) (Connector, error)
GetDeviceRequest(userCode string) (DeviceRequest, error)
GetDeviceToken(deviceCode string) (DeviceToken, error)
ListClients() ([]Client, error)
ListRefreshTokens() ([]RefreshToken, error)
ListPasswords() ([]Password, error)
ListConnectors() ([]Connector, error)
GetAuthRequest(ctx context.Context, id string) (AuthRequest, error)
GetAuthCode(ctx context.Context, id string) (AuthCode, error)
GetClient(ctx context.Context, id string) (Client, error)
GetKeys(ctx context.Context) (Keys, error)
GetRefresh(ctx context.Context, id string) (RefreshToken, error)
GetPassword(ctx context.Context, email string) (Password, error)
GetOfflineSessions(ctx context.Context, userID string, connID string) (OfflineSessions, error)
GetConnector(ctx context.Context, id string) (Connector, error)
GetDeviceRequest(ctx context.Context, userCode string) (DeviceRequest, error)
GetDeviceToken(ctx context.Context, deviceCode string) (DeviceToken, error)
ListClients(ctx context.Context) ([]Client, error)
ListRefreshTokens(ctx context.Context) ([]RefreshToken, error)
ListPasswords(ctx context.Context) ([]Password, error)
ListConnectors(ctx context.Context) ([]Connector, error)
// Delete methods MUST be atomic.
DeleteAuthRequest(id string) error
DeleteAuthCode(code string) error
DeleteClient(id string) error
DeleteRefresh(id string) error
DeletePassword(email string) error
DeleteOfflineSessions(userID string, connID string) error
DeleteConnector(id string) error
DeleteAuthRequest(ctx context.Context, id string) error
DeleteAuthCode(ctx context.Context, code string) error
DeleteClient(ctx context.Context, id string) error
DeleteRefresh(ctx context.Context, id string) error
DeletePassword(ctx context.Context, email string) error
DeleteOfflineSessions(ctx context.Context, userID string, connID string) error
DeleteConnector(ctx context.Context, id string) error
// Update methods take a function for updating an object then performs that update within
// a transaction. "updater" functions may be called multiple times by a single update call.
@ -128,18 +128,18 @@ type Storage interface {
// // update failed, handle error
// }
//
UpdateClient(id string, updater func(old Client) (Client, error)) error
UpdateKeys(updater func(old Keys) (Keys, error)) error
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
UpdatePassword(email string, updater func(p Password) (Password, error)) error
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error
UpdateClient(ctx context.Context, id string, updater func(old Client) (Client, error)) error
UpdateKeys(ctx context.Context, updater func(old Keys) (Keys, error)) error
UpdateAuthRequest(ctx context.Context, id string, updater func(a AuthRequest) (AuthRequest, error)) error
UpdateRefreshToken(ctx context.Context, id string, updater func(r RefreshToken) (RefreshToken, error)) error
UpdatePassword(ctx context.Context, email string, updater func(p Password) (Password, error)) error
UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
UpdateConnector(ctx context.Context, id string, updater func(c Connector) (Connector, error)) error
UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error
// GarbageCollect deletes all expired AuthCodes,
// AuthRequests, DeviceRequests, and DeviceTokens.
GarbageCollect(now time.Time) (GCResult, error)
GarbageCollect(ctx context.Context, now time.Time) (GCResult, error)
}
// Client represents an OAuth2 client.

Loading…
Cancel
Save