From ed5dee9960c86c7e19ce4e16d4b1db80c618d171 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Tue, 16 Feb 2016 18:19:23 -0800 Subject: [PATCH] db: clean up quote and executor function calls, improve translate docs --- db/client.go | 36 +++++++++++++---------- db/conn.go | 8 ------ db/connector_config.go | 21 +++++++------- db/db.go | 60 +++++++++++++++++++++++++++++++++++++++ db/key.go | 22 +++++++++----- db/password.go | 10 +++---- db/refresh.go | 29 +++++++++++-------- db/session.go | 14 ++++----- db/session_key.go | 16 +++++------ db/transaction.go | 33 --------------------- db/translate/translate.go | 8 ++++-- db/user.go | 38 ++++++++++++------------- 12 files changed, 169 insertions(+), 126 deletions(-) create mode 100644 db/db.go delete mode 100644 db/transaction.go diff --git a/db/client.go b/db/client.go index 22fcae85..eee4e755 100644 --- a/db/client.go +++ b/db/client.go @@ -86,15 +86,21 @@ func (m *clientIdentityModel) ClientIdentity() (*oidc.ClientIdentity, error) { } func NewClientIdentityRepo(dbm *gorp.DbMap) client.ClientIdentityRepo { - return &clientIdentityRepo{dbMap: dbm} + return newClientIdentityRepo(dbm) +} + +func newClientIdentityRepo(dbm *gorp.DbMap) *clientIdentityRepo { + return &clientIdentityRepo{db: &db{dbm}} } func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIdentity) (client.ClientIdentityRepo, error) { - tx, err := dbm.Begin() + repo := newClientIdentityRepo(dbm) + tx, err := repo.begin() if err != nil { return nil, err } defer tx.Rollback() + exec := repo.executor(tx) for _, c := range clients { dec, err := base64.URLEncoding.DecodeString(c.Credentials.Secret) if err != nil { @@ -104,7 +110,7 @@ func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIden if err != nil { return nil, err } - err = tx.Insert(cm) + err = exec.Insert(cm) if err != nil { return nil, err } @@ -112,15 +118,15 @@ func NewClientIdentityRepoFromClients(dbm *gorp.DbMap, clients []oidc.ClientIden if err := tx.Commit(); err != nil { return nil, err } - return NewClientIdentityRepo(dbm), nil + return repo, nil } type clientIdentityRepo struct { - dbMap *gorp.DbMap + *db } func (r *clientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, error) { - m, err := r.dbMap.Get(clientIdentityModel{}, clientID) + m, err := r.executor(nil).Get(clientIdentityModel{}, clientID) if err == sql.ErrNoRows || m == nil { return nil, client.ErrorNotFound } @@ -143,7 +149,7 @@ func (r *clientIdentityRepo) Metadata(clientID string) (*oidc.ClientMetadata, er } func (r *clientIdentityRepo) IsDexAdmin(clientID string) (bool, error) { - m, err := r.dbMap.Get(clientIdentityModel{}, clientID) + m, err := r.executor(nil).Get(clientIdentityModel{}, clientID) if m == nil || err != nil { return false, err } @@ -158,15 +164,15 @@ func (r *clientIdentityRepo) IsDexAdmin(clientID string) (bool, error) { } func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error { - tx, err := r.dbMap.Begin() + tx, err := r.begin() if err != nil { return err } defer tx.Rollback() + exec := r.executor(tx) - m, err := tx.Get(clientIdentityModel{}, clientID) + m, err := exec.Get(clientIdentityModel{}, clientID) if m == nil || err != nil { - rollback(tx) return err } @@ -177,7 +183,7 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error { } cim.DexAdmin = isAdmin - _, err = tx.Update(cim) + _, err = exec.Update(cim) if err != nil { return err } @@ -186,7 +192,7 @@ func (r *clientIdentityRepo) SetDexAdmin(clientID string, isAdmin bool) error { } func (r *clientIdentityRepo) Authenticate(creds oidc.ClientCredentials) (bool, error) { - m, err := r.dbMap.Get(clientIdentityModel{}, creds.ID) + m, err := r.executor(nil).Get(clientIdentityModel{}, creds.ID) if m == nil || err != nil { return false, err } @@ -222,7 +228,7 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli return nil, err } - if err := r.dbMap.Insert(cim); err != nil { + if err := r.executor(nil).Insert(cim); err != nil { switch sqlErr := err.(type) { case *pq.Error: if sqlErr.Code == pgErrorCodeUniqueViolation { @@ -246,9 +252,9 @@ func (r *clientIdentityRepo) New(id string, meta oidc.ClientMetadata) (*oidc.Cli } func (r *clientIdentityRepo) All() ([]oidc.ClientIdentity, error) { - qt := r.dbMap.Dialect.QuotedTableForQuery("", clientIdentityTableName) + qt := r.quote(clientIdentityTableName) q := fmt.Sprintf("SELECT * FROM %s", qt) - objs, err := r.dbMap.Select(&clientIdentityModel{}, q) + objs, err := r.executor(nil).Select(&clientIdentityModel{}, q) if err != nil { return nil, err } diff --git a/db/conn.go b/db/conn.go index 7d31bd52..e6256cc5 100644 --- a/db/conn.go +++ b/db/conn.go @@ -8,7 +8,6 @@ import ( "github.com/go-gorp/gorp" - "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/repo" // Import database drivers @@ -95,13 +94,6 @@ func TransactionFactory(conn *gorp.DbMap) repo.TransactionFactory { } } -func rollback(tx *gorp.Transaction) { - err := tx.Rollback() - if err != nil { - log.Errorf("unable to rollback: %v", err) - } -} - // NewMemDB creates a new in memory sqlite3 database. func NewMemDB() *gorp.DbMap { dbMap, err := NewConnection(Config{DSN: "sqlite3://:memory:"}) diff --git a/db/connector_config.go b/db/connector_config.go index 5f5a1a17..60603d83 100644 --- a/db/connector_config.go +++ b/db/connector_config.go @@ -60,17 +60,17 @@ func (m *connectorConfigModel) ConnectorConfig() (connector.ConnectorConfig, err } func NewConnectorConfigRepo(dbm *gorp.DbMap) *ConnectorConfigRepo { - return &ConnectorConfigRepo{dbMap: dbm} + return &ConnectorConfigRepo{&db{dbm}} } type ConnectorConfigRepo struct { - dbMap *gorp.DbMap + *db } func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) { - qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName) + qt := r.quote(connectorConfigTableName) q := fmt.Sprintf("SELECT * FROM %s", qt) - objs, err := r.dbMap.Select(&connectorConfigModel{}, q) + objs, err := r.executor(nil).Select(&connectorConfigModel{}, q) if err != nil { return nil, err } @@ -93,10 +93,10 @@ func (r *ConnectorConfigRepo) All() ([]connector.ConnectorConfig, error) { } func (r *ConnectorConfigRepo) GetConnectorByID(tx repo.Transaction, id string) (connector.ConnectorConfig, error) { - qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName) + qt := r.quote(connectorConfigTableName) q := fmt.Sprintf("SELECT * FROM %s WHERE id = $1", qt) var c connectorConfigModel - if err := executor(r.dbMap, tx).SelectOne(&c, q, id); err != nil { + if err := r.executor(tx).SelectOne(&c, q, id); err != nil { if err == sql.ErrNoRows { return nil, connector.ErrorNotFound } @@ -116,19 +116,20 @@ func (r *ConnectorConfigRepo) Set(cfgs []connector.ConnectorConfig) error { insert[i] = m } - tx, err := r.dbMap.Begin() + tx, err := r.begin() if err != nil { return err } defer tx.Rollback() + exec := r.executor(tx) - qt := r.dbMap.Dialect.QuotedTableForQuery("", connectorConfigTableName) + qt := r.quote(connectorConfigTableName) q := fmt.Sprintf("DELETE FROM %s", qt) - if _, err = tx.Exec(q); err != nil { + if _, err = exec.Exec(q); err != nil { return err } - if err = tx.Insert(insert...); err != nil { + if err = exec.Insert(insert...); err != nil { return fmt.Errorf("DB insert failed %#v: %v", insert, err) } diff --git a/db/db.go b/db/db.go new file mode 100644 index 00000000..59584471 --- /dev/null +++ b/db/db.go @@ -0,0 +1,60 @@ +// Package db provides SQL implementations of dex's storage interfaces. +package db + +import ( + "github.com/go-gorp/gorp" + + "github.com/coreos/dex/db/translate" + "github.com/coreos/dex/repo" +) + +// db is the connection type passed to repos. +// +// TODO(ericchiang): Eventually just return this instead of gorp.DbMap during Conn. +// All actions should go through this type instead of dbMap. +type db struct { + dbMap *gorp.DbMap +} + +// executor returns a driver agnostic SQL executor. +// +// The expected flavor of all queries is the flavor used by github.com/lib/pq. All bind +// parameters must be unique and in sequential order (e.g. $1, $2, ...). +// +// See github.com/coreos/dex/db/translate for details on the translation. +// +// If tx is nil, a non-transaction context is provided. +func (conn *db) executor(tx repo.Transaction) gorp.SqlExecutor { + var exec gorp.SqlExecutor + if tx == nil { + exec = conn.dbMap + } else { + gorpTx, ok := tx.(*gorp.Transaction) + if !ok { + panic("wrong kind of transaction passed to a DB repo") + } + + // Check if the underlying value of the pointer is nil. + // This is not caught by the initial comparison (tx == nil). + if gorpTx == nil { + exec = conn.dbMap + } else { + exec = gorpTx + } + } + + if _, ok := conn.dbMap.Dialect.(gorp.SqliteDialect); ok { + exec = translate.NewTranslatingExecutor(exec, translate.PostgresToSQLite) + } + return exec +} + +// quote escapes a table name for a driver specific SQL query. quote uses the +// gorp's package underlying quote logic and should NOT be used on untrusted input. +func (conn *db) quote(tableName string) string { + return conn.dbMap.Dialect.QuotedTableForQuery("", tableName) +} + +func (conn *db) begin() (repo.Transaction, error) { + return conn.dbMap.Begin() +} diff --git a/db/key.go b/db/key.go index 8c4072c3..e1cf34cd 100644 --- a/db/key.go +++ b/db/key.go @@ -98,7 +98,7 @@ func NewPrivateKeySetRepo(dbm *gorp.DbMap, useOldFormat bool, secrets ...[]byte) } r := &PrivateKeySetRepo{ - dbMap: dbm, + db: &db{dbm}, useOldFormat: useOldFormat, secrets: secrets, } @@ -107,17 +107,22 @@ func NewPrivateKeySetRepo(dbm *gorp.DbMap, useOldFormat bool, secrets ...[]byte) } type PrivateKeySetRepo struct { - dbMap *gorp.DbMap + *db useOldFormat bool secrets [][]byte } func (r *PrivateKeySetRepo) Set(ks key.KeySet) error { - qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName) - _, err := r.dbMap.Exec(fmt.Sprintf("DELETE FROM %s", qt)) + qt := r.quote(keyTableName) + tx, err := r.begin() if err != nil { return err } + defer tx.Rollback() + exec := r.executor(tx) + if _, err := exec.Exec(fmt.Sprintf("DELETE FROM %s", qt)); err != nil { + return err + } pks, ok := ks.(*key.PrivateKeySet) if !ok { @@ -147,12 +152,15 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error { } b := &privateKeySetBlob{Value: v} - return r.dbMap.Insert(b) + if err := exec.Insert(b); err != nil { + return err + } + return tx.Commit() } func (r *PrivateKeySetRepo) Get() (key.KeySet, error) { - qt := r.dbMap.Dialect.QuotedTableForQuery("", keyTableName) - objs, err := r.dbMap.Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt)) + qt := r.quote(keyTableName) + objs, err := r.executor(nil).Select(&privateKeySetBlob{}, fmt.Sprintf("SELECT * FROM %s", qt)) if err != nil { return nil, err } diff --git a/db/password.go b/db/password.go index afd2f247..c2402aa8 100644 --- a/db/password.go +++ b/db/password.go @@ -34,7 +34,7 @@ type passwordInfoModel struct { func NewPasswordInfoRepo(dbm *gorp.DbMap) user.PasswordInfoRepo { return &passwordInfoRepo{ - dbMap: dbm, + db: &db{dbm}, } } @@ -49,7 +49,7 @@ func NewPasswordInfoRepoFromPasswordInfos(dbm *gorp.DbMap, infos []user.Password } type passwordInfoRepo struct { - dbMap *gorp.DbMap + *db } func (r *passwordInfoRepo) Get(tx repo.Transaction, userID string) (user.PasswordInfo, error) { @@ -101,7 +101,7 @@ func (r *passwordInfoRepo) Update(tx repo.Transaction, pw user.PasswordInfo) err } func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInfo, error) { - ex := executor(r.dbMap, tx) + ex := r.executor(tx) m, err := ex.Get(passwordInfoModel{}, id) if err != nil { @@ -122,7 +122,7 @@ func (r *passwordInfoRepo) get(tx repo.Transaction, id string) (user.PasswordInf } func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) error { - ex := executor(r.dbMap, tx) + ex := r.executor(tx) pm, err := newPasswordInfoModel(&pw) if err != nil { return err @@ -131,7 +131,7 @@ func (r *passwordInfoRepo) insert(tx repo.Transaction, pw user.PasswordInfo) err } func (r *passwordInfoRepo) update(tx repo.Transaction, pw user.PasswordInfo) error { - ex := executor(r.dbMap, tx) + ex := r.executor(tx) pm, err := newPasswordInfoModel(&pw) if err != nil { return err diff --git a/db/refresh.go b/db/refresh.go index f2dc193a..552cc1e6 100644 --- a/db/refresh.go +++ b/db/refresh.go @@ -13,6 +13,7 @@ import ( "github.com/coreos/dex/pkg/log" "github.com/coreos/dex/refresh" + "github.com/coreos/dex/repo" ) const ( @@ -29,7 +30,7 @@ func init() { } type refreshTokenRepo struct { - dbMap *gorp.DbMap + *db tokenGenerator refresh.RefreshTokenGenerator } @@ -77,15 +78,12 @@ func checkTokenPayload(payloadHash, payload []byte) error { } func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo { - return &refreshTokenRepo{ - dbMap: dbm, - tokenGenerator: refresh.DefaultRefreshTokenGenerator, - } + return NewRefreshTokenRepoWithGenerator(dbm, refresh.DefaultRefreshTokenGenerator) } func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenGenerator) refresh.RefreshTokenRepo { return &refreshTokenRepo{ - dbMap: dbm, + db: &db{dbm}, tokenGenerator: gen, } } @@ -115,7 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { ClientID: clientID, } - if err := r.dbMap.Insert(record); err != nil { + if err := r.executor(nil).Insert(record); err != nil { return "", err } @@ -151,7 +149,13 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error { return err } - record, err := r.get(nil, tokenID) + tx, err := r.begin() + if err != nil { + return err + } + defer tx.Rollback() + exec := r.executor(tx) + record, err := r.get(tx, tokenID) if err != nil { return err } @@ -164,7 +168,7 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error { return err } - deleted, err := r.dbMap.Delete(record) + deleted, err := exec.Delete(record) if err != nil { return err } @@ -172,10 +176,11 @@ func (r *refreshTokenRepo) Revoke(userID, token string) error { return refresh.ErrorInvalidToken } - return nil + return tx.Commit() } -func (r *refreshTokenRepo) get(tx *gorp.Transaction, tokenID int64) (*refreshTokenModel, error) { - ex := executor(r.dbMap, tx) + +func (r *refreshTokenRepo) get(tx repo.Transaction, tokenID int64) (*refreshTokenModel, error) { + ex := r.executor(tx) result, err := ex.Get(refreshTokenModel{}, tokenID) if err != nil { return nil, err diff --git a/db/session.go b/db/session.go index 172985fc..1eb05cfe 100644 --- a/db/session.go +++ b/db/session.go @@ -123,16 +123,16 @@ func NewSessionRepo(dbm *gorp.DbMap) *SessionRepo { } func NewSessionRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionRepo { - return &SessionRepo{dbMap: dbm, clock: clock} + return &SessionRepo{db: &db{dbm}, clock: clock} } type SessionRepo struct { - dbMap *gorp.DbMap + *db clock clockwork.Clock } func (r *SessionRepo) Get(sessionID string) (*session.Session, error) { - m, err := r.dbMap.Get(sessionModel{}, sessionID) + m, err := r.executor(nil).Get(sessionModel{}, sessionID) if err != nil { return nil, err } @@ -163,7 +163,7 @@ func (r *SessionRepo) Create(s session.Session) error { if err != nil { return err } - return r.dbMap.Insert(sm) + return r.executor(nil).Insert(sm) } func (r *SessionRepo) Update(s session.Session) error { @@ -171,7 +171,7 @@ func (r *SessionRepo) Update(s session.Session) error { if err != nil { return err } - n, err := r.dbMap.Update(sm) + n, err := r.executor(nil).Update(sm) if err != nil { return err } @@ -182,9 +182,9 @@ func (r *SessionRepo) Update(s session.Session) error { } func (r *SessionRepo) purge() error { - qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionTableName) + qt := r.quote(sessionTableName) q := fmt.Sprintf("DELETE FROM %s WHERE expires_at < $1 OR state = $2", qt) - res, err := executor(r.dbMap, nil).Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead)) + res, err := r.executor(nil).Exec(q, r.clock.Now().Unix(), string(session.SessionStateDead)) if err != nil { return err } diff --git a/db/session_key.go b/db/session_key.go index f58e7610..ff35876c 100644 --- a/db/session_key.go +++ b/db/session_key.go @@ -38,11 +38,11 @@ func NewSessionKeyRepo(dbm *gorp.DbMap) *SessionKeyRepo { } func NewSessionKeyRepoWithClock(dbm *gorp.DbMap, clock clockwork.Clock) *SessionKeyRepo { - return &SessionKeyRepo{dbMap: dbm, clock: clock} + return &SessionKeyRepo{db: &db{dbm}, clock: clock} } type SessionKeyRepo struct { - dbMap *gorp.DbMap + *db clock clockwork.Clock } @@ -53,11 +53,11 @@ func (r *SessionKeyRepo) Push(sk session.SessionKey, exp time.Duration) error { ExpiresAt: r.clock.Now().Unix() + int64(exp.Seconds()), Stale: false, } - return r.dbMap.Insert(skm) + return r.executor(nil).Insert(skm) } func (r *SessionKeyRepo) Pop(key string) (string, error) { - m, err := r.dbMap.Get(sessionKeyModel{}, key) + m, err := r.executor(nil).Get(sessionKeyModel{}, key) if err != nil { return "", err } @@ -76,9 +76,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) { return "", errors.New("invalid session key") } - qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName) + qt := r.quote(sessionKeyTableName) q := fmt.Sprintf("UPDATE %s SET stale=$1 WHERE key=$2 AND stale=$3", qt) - res, err := executor(r.dbMap, nil).Exec(q, true, key, false) + res, err := r.executor(nil).Exec(q, true, key, false) if err != nil { return "", err } @@ -94,9 +94,9 @@ func (r *SessionKeyRepo) Pop(key string) (string, error) { } func (r *SessionKeyRepo) purge() error { - qt := r.dbMap.Dialect.QuotedTableForQuery("", sessionKeyTableName) + qt := r.quote(sessionKeyTableName) q := fmt.Sprintf("DELETE FROM %s WHERE stale = $1 OR expires_at < $2", qt) - res, err := executor(r.dbMap, nil).Exec(q, true, r.clock.Now().Unix()) + res, err := r.executor(nil).Exec(q, true, r.clock.Now().Unix()) if err != nil { return err } diff --git a/db/transaction.go b/db/transaction.go deleted file mode 100644 index 19da5a13..00000000 --- a/db/transaction.go +++ /dev/null @@ -1,33 +0,0 @@ -package db - -import ( - "github.com/go-gorp/gorp" - - "github.com/coreos/dex/db/translate" - "github.com/coreos/dex/repo" -) - -func executor(dbMap *gorp.DbMap, tx repo.Transaction) gorp.SqlExecutor { - var exec gorp.SqlExecutor - if tx == nil { - exec = dbMap - } else { - gorpTx, ok := tx.(*gorp.Transaction) - if !ok { - panic("wrong kind of transaction passed to a DB repo") - } - - // Check if the underlying value of the pointer is nil. - // This is not caught by the initial comparison (tx == nil). - if gorpTx == nil { - exec = dbMap - } else { - exec = gorpTx - } - } - - if _, ok := dbMap.Dialect.(gorp.SqliteDialect); ok { - exec = translate.NewExecutor(exec, translate.PostgresToSQLite) - } - return exec -} diff --git a/db/translate/translate.go b/db/translate/translate.go index 390b95ce..ce5ddc0b 100644 --- a/db/translate/translate.go +++ b/db/translate/translate.go @@ -15,14 +15,18 @@ var ( trueRegexp = regexp.MustCompile(`\btrue\b`) ) -// PostgresToSQLite implements translation of the pq driver to sqlite3. +// PostgresToSQLite translates github.com/lib/pq flavored SQL quries to github.com/mattn/go-sqlite3's flavor. +// +// It assumes that possitional bind arguements ($1, $2, etc.) are unqiue and in sequential order. func PostgresToSQLite(query string) string { query = bindRegexp.ReplaceAllString(query, "?") query = trueRegexp.ReplaceAllString(query, "1") return query } -func NewExecutor(exec gorp.SqlExecutor, translate func(string) string) gorp.SqlExecutor { +// NewTranslatingExecutor returns an executor wrapping the existing executor. All query strings passed to +// the executor will be run through the translate function before begin passed to the driver. +func NewTranslatingExecutor(exec gorp.SqlExecutor, translate func(string) string) gorp.SqlExecutor { return &executor{exec, translate} } diff --git a/db/user.go b/db/user.go index 86f4638a..89e49b3c 100644 --- a/db/user.go +++ b/db/user.go @@ -41,7 +41,7 @@ func init() { func NewUserRepo(dbm *gorp.DbMap) user.UserRepo { return &userRepo{ - dbMap: dbm, + db: &db{dbm}, } } @@ -52,7 +52,7 @@ func NewUserRepoFromUsers(dbm *gorp.DbMap, us []user.UserWithRemoteIdentities) ( if err != nil { return nil, err } - err = repo.dbMap.Insert(um) + err = repo.executor(nil).Insert(um) for _, ri := range u.RemoteIdentities { err = repo.AddRemoteIdentity(nil, u.User.ID, ri) if err != nil { @@ -64,7 +64,7 @@ func NewUserRepoFromUsers(dbm *gorp.DbMap, us []user.UserWithRemoteIdentities) ( } type userRepo struct { - dbMap *gorp.DbMap + *db } func (r *userRepo) Get(tx repo.Transaction, userID string) (user.User, error) { @@ -106,8 +106,8 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err return user.ErrorInvalidID } - qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName) - ex := executor(r.dbMap, tx) + qt := r.quote(userTableName) + ex := r.executor(tx) result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $1 WHERE id = $2;", qt), disable, userID) if err != nil { return err @@ -220,7 +220,7 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid return err } - ex := executor(r.dbMap, tx) + ex := r.executor(tx) deleted, err := ex.Delete(rim) if err != nil { @@ -235,12 +235,12 @@ func (r *userRepo) RemoveRemoteIdentity(tx repo.Transaction, userID string, rid } func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]user.RemoteIdentity, error) { - ex := executor(r.dbMap, tx) + ex := r.executor(tx) if userID == "" { return nil, user.ErrorInvalidID } - qt := r.dbMap.Dialect.QuotedTableForQuery("", remoteIdentityMappingTableName) + qt := r.quote(remoteIdentityMappingTableName) rims, err := ex.Select(&remoteIdentityMappingModel{}, fmt.Sprintf("SELECT * FROM %s WHERE user_id = $1", qt), userID) if err != nil { @@ -271,8 +271,8 @@ func (r *userRepo) GetRemoteIdentities(tx repo.Transaction, userID string) ([]us } func (r *userRepo) GetAdminCount(tx repo.Transaction) (int, error) { - qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName) - ex := executor(r.dbMap, tx) + qt := r.quote(userTableName) + ex := r.executor(tx) i, err := ex.SelectInt(fmt.Sprintf("SELECT count(*) FROM %s WHERE admin=true;", qt)) return int(i), err } @@ -286,9 +286,9 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults if err != nil { return nil, "", err } - ex := executor(r.dbMap, tx) + ex := r.executor(tx) - qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName) + qt := r.quote(userTableName) // Ask for one more than needed so we know if there's more results, and // hence, whether a nextPageToken is necessary. @@ -336,7 +336,7 @@ func (r *userRepo) List(tx repo.Transaction, filter user.UserFilter, maxResults } func (r *userRepo) insert(tx repo.Transaction, usr user.User) error { - ex := executor(r.dbMap, tx) + ex := r.executor(tx) um, err := newUserModel(&usr) if err != nil { return err @@ -345,7 +345,7 @@ func (r *userRepo) insert(tx repo.Transaction, usr user.User) error { } func (r *userRepo) update(tx repo.Transaction, usr user.User) error { - ex := executor(r.dbMap, tx) + ex := r.executor(tx) um, err := newUserModel(&usr) if err != nil { return err @@ -355,7 +355,7 @@ func (r *userRepo) update(tx repo.Transaction, usr user.User) error { } func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) { - ex := executor(r.dbMap, tx) + ex := r.executor(tx) m, err := ex.Get(userModel{}, userID) if err != nil { @@ -376,7 +376,7 @@ func (r *userRepo) get(tx repo.Transaction, userID string) (user.User, error) { } func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.RemoteIdentity) (string, error) { - ex := executor(r.dbMap, tx) + ex := r.executor(tx) m, err := ex.Get(remoteIdentityMappingModel{}, ri.ConnectorID, ri.ID) if err != nil { @@ -397,8 +397,8 @@ func (r *userRepo) getUserIDForRemoteIdentity(tx repo.Transaction, ri user.Remot } func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, error) { - qt := r.dbMap.Dialect.QuotedTableForQuery("", userTableName) - ex := executor(r.dbMap, tx) + qt := r.quote(userTableName) + ex := r.executor(tx) var um userModel err := ex.SelectOne(&um, fmt.Sprintf("select * from %s where email = $1", qt), email) @@ -412,7 +412,7 @@ func (r *userRepo) getByEmail(tx repo.Transaction, email string) (user.User, err } func (r *userRepo) insertRemoteIdentity(tx repo.Transaction, userID string, ri user.RemoteIdentity) error { - ex := executor(r.dbMap, tx) + ex := r.executor(tx) rim, err := newRemoteIdentityMappingModel(userID, ri) if err != nil {