|
|
|
|
@ -7,6 +7,7 @@ import (
|
|
|
|
|
"errors" |
|
|
|
|
"fmt" |
|
|
|
|
"reflect" |
|
|
|
|
"sync" |
|
|
|
|
|
|
|
|
|
"entgo.io/ent" |
|
|
|
|
"entgo.io/ent/dialect/sql" |
|
|
|
|
@ -69,42 +70,38 @@ func NewTxContext(parent context.Context, tx *Tx) context.Context {
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// OrderFunc applies an ordering on the sql selector.
|
|
|
|
|
// Deprecated: Use Asc/Desc functions or the package builders instead.
|
|
|
|
|
type OrderFunc func(*sql.Selector) |
|
|
|
|
|
|
|
|
|
// columnChecker returns a function indicates if the column exists in the given column.
|
|
|
|
|
func columnChecker(table string) func(string) error { |
|
|
|
|
checks := map[string]func(string) bool{ |
|
|
|
|
authcode.Table: authcode.ValidColumn, |
|
|
|
|
authrequest.Table: authrequest.ValidColumn, |
|
|
|
|
connector.Table: connector.ValidColumn, |
|
|
|
|
devicerequest.Table: devicerequest.ValidColumn, |
|
|
|
|
devicetoken.Table: devicetoken.ValidColumn, |
|
|
|
|
keys.Table: keys.ValidColumn, |
|
|
|
|
oauth2client.Table: oauth2client.ValidColumn, |
|
|
|
|
offlinesession.Table: offlinesession.ValidColumn, |
|
|
|
|
password.Table: password.ValidColumn, |
|
|
|
|
refreshtoken.Table: refreshtoken.ValidColumn, |
|
|
|
|
} |
|
|
|
|
check, ok := checks[table] |
|
|
|
|
if !ok { |
|
|
|
|
return func(string) error { |
|
|
|
|
return fmt.Errorf("unknown table %q", table) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return func(column string) error { |
|
|
|
|
if !check(column) { |
|
|
|
|
return fmt.Errorf("unknown column %q for table %q", column, table) |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
var ( |
|
|
|
|
initCheck sync.Once |
|
|
|
|
columnCheck sql.ColumnCheck |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
// columnChecker checks if the column exists in the given table.
|
|
|
|
|
func checkColumn(table, column string) error { |
|
|
|
|
initCheck.Do(func() { |
|
|
|
|
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ |
|
|
|
|
authcode.Table: authcode.ValidColumn, |
|
|
|
|
authrequest.Table: authrequest.ValidColumn, |
|
|
|
|
connector.Table: connector.ValidColumn, |
|
|
|
|
devicerequest.Table: devicerequest.ValidColumn, |
|
|
|
|
devicetoken.Table: devicetoken.ValidColumn, |
|
|
|
|
keys.Table: keys.ValidColumn, |
|
|
|
|
oauth2client.Table: oauth2client.ValidColumn, |
|
|
|
|
offlinesession.Table: offlinesession.ValidColumn, |
|
|
|
|
password.Table: password.ValidColumn, |
|
|
|
|
refreshtoken.Table: refreshtoken.ValidColumn, |
|
|
|
|
}) |
|
|
|
|
}) |
|
|
|
|
return columnCheck(table, column) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Asc applies the given fields in ASC order.
|
|
|
|
|
func Asc(fields ...string) OrderFunc { |
|
|
|
|
func Asc(fields ...string) func(*sql.Selector) { |
|
|
|
|
return func(s *sql.Selector) { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
for _, f := range fields { |
|
|
|
|
if err := check(f); err != nil { |
|
|
|
|
if err := checkColumn(s.TableName(), f); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
} |
|
|
|
|
s.OrderBy(sql.Asc(s.C(f))) |
|
|
|
|
@ -113,11 +110,10 @@ func Asc(fields ...string) OrderFunc {
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Desc applies the given fields in DESC order.
|
|
|
|
|
func Desc(fields ...string) OrderFunc { |
|
|
|
|
func Desc(fields ...string) func(*sql.Selector) { |
|
|
|
|
return func(s *sql.Selector) { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
for _, f := range fields { |
|
|
|
|
if err := check(f); err != nil { |
|
|
|
|
if err := checkColumn(s.TableName(), f); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
} |
|
|
|
|
s.OrderBy(sql.Desc(s.C(f))) |
|
|
|
|
@ -149,8 +145,7 @@ func Count() AggregateFunc {
|
|
|
|
|
// Max applies the "max" aggregation function on the given field of each group.
|
|
|
|
|
func Max(field string) AggregateFunc { |
|
|
|
|
return func(s *sql.Selector) string { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
if err := check(field); err != nil { |
|
|
|
|
if err := checkColumn(s.TableName(), field); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
return "" |
|
|
|
|
} |
|
|
|
|
@ -161,8 +156,7 @@ func Max(field string) AggregateFunc {
|
|
|
|
|
// Mean applies the "mean" aggregation function on the given field of each group.
|
|
|
|
|
func Mean(field string) AggregateFunc { |
|
|
|
|
return func(s *sql.Selector) string { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
if err := check(field); err != nil { |
|
|
|
|
if err := checkColumn(s.TableName(), field); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
return "" |
|
|
|
|
} |
|
|
|
|
@ -173,8 +167,7 @@ func Mean(field string) AggregateFunc {
|
|
|
|
|
// Min applies the "min" aggregation function on the given field of each group.
|
|
|
|
|
func Min(field string) AggregateFunc { |
|
|
|
|
return func(s *sql.Selector) string { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
if err := check(field); err != nil { |
|
|
|
|
if err := checkColumn(s.TableName(), field); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
return "" |
|
|
|
|
} |
|
|
|
|
@ -185,8 +178,7 @@ func Min(field string) AggregateFunc {
|
|
|
|
|
// Sum applies the "sum" aggregation function on the given field of each group.
|
|
|
|
|
func Sum(field string) AggregateFunc { |
|
|
|
|
return func(s *sql.Selector) string { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
if err := check(field); err != nil { |
|
|
|
|
if err := checkColumn(s.TableName(), field); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
return "" |
|
|
|
|
} |
|
|
|
|
|