|
|
|
|
@ -10,6 +10,16 @@ import (
|
|
|
|
|
"entgo.io/ent/dialect" |
|
|
|
|
"entgo.io/ent/dialect/sql" |
|
|
|
|
"entgo.io/ent/dialect/sql/sqlgraph" |
|
|
|
|
"github.com/dexidp/dex/storage/ent/db/authcode" |
|
|
|
|
"github.com/dexidp/dex/storage/ent/db/authrequest" |
|
|
|
|
"github.com/dexidp/dex/storage/ent/db/connector" |
|
|
|
|
"github.com/dexidp/dex/storage/ent/db/devicerequest" |
|
|
|
|
"github.com/dexidp/dex/storage/ent/db/devicetoken" |
|
|
|
|
"github.com/dexidp/dex/storage/ent/db/keys" |
|
|
|
|
"github.com/dexidp/dex/storage/ent/db/oauth2client" |
|
|
|
|
"github.com/dexidp/dex/storage/ent/db/offlinesession" |
|
|
|
|
"github.com/dexidp/dex/storage/ent/db/password" |
|
|
|
|
"github.com/dexidp/dex/storage/ent/db/refreshtoken" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
// ent aliases to avoid import conflicts in user's code.
|
|
|
|
|
@ -25,36 +35,64 @@ type (
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
// OrderFunc applies an ordering on the sql selector.
|
|
|
|
|
type OrderFunc func(*sql.Selector, func(string) bool) |
|
|
|
|
type OrderFunc func(*sql.Selector) |
|
|
|
|
|
|
|
|
|
// columnChecker returns a function indicates if the column exists in the given column.
|
|
|
|
|
func columnChecker(table string) func(string) error { |
|
|
|
|
checks := map[string]func(string) bool{ |
|
|
|
|
authcode.Table: authcode.ValidColumn, |
|
|
|
|
authrequest.Table: authrequest.ValidColumn, |
|
|
|
|
connector.Table: connector.ValidColumn, |
|
|
|
|
devicerequest.Table: devicerequest.ValidColumn, |
|
|
|
|
devicetoken.Table: devicetoken.ValidColumn, |
|
|
|
|
keys.Table: keys.ValidColumn, |
|
|
|
|
oauth2client.Table: oauth2client.ValidColumn, |
|
|
|
|
offlinesession.Table: offlinesession.ValidColumn, |
|
|
|
|
password.Table: password.ValidColumn, |
|
|
|
|
refreshtoken.Table: refreshtoken.ValidColumn, |
|
|
|
|
} |
|
|
|
|
check, ok := checks[table] |
|
|
|
|
if !ok { |
|
|
|
|
return func(string) error { |
|
|
|
|
return fmt.Errorf("unknown table %q", table) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return func(column string) error { |
|
|
|
|
if !check(column) { |
|
|
|
|
return fmt.Errorf("unknown column %q for table %q", column, table) |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Asc applies the given fields in ASC order.
|
|
|
|
|
func Asc(fields ...string) OrderFunc { |
|
|
|
|
return func(s *sql.Selector, check func(string) bool) { |
|
|
|
|
return func(s *sql.Selector) { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
for _, f := range fields { |
|
|
|
|
if check(f) { |
|
|
|
|
s.OrderBy(sql.Asc(f)) |
|
|
|
|
} else { |
|
|
|
|
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) |
|
|
|
|
if err := check(f); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
} |
|
|
|
|
s.OrderBy(sql.Asc(s.C(f))) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Desc applies the given fields in DESC order.
|
|
|
|
|
func Desc(fields ...string) OrderFunc { |
|
|
|
|
return func(s *sql.Selector, check func(string) bool) { |
|
|
|
|
return func(s *sql.Selector) { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
for _, f := range fields { |
|
|
|
|
if check(f) { |
|
|
|
|
s.OrderBy(sql.Desc(f)) |
|
|
|
|
} else { |
|
|
|
|
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) |
|
|
|
|
if err := check(f); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
} |
|
|
|
|
s.OrderBy(sql.Desc(s.C(f))) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// AggregateFunc applies an aggregation step on the group-by traversal/selector.
|
|
|
|
|
type AggregateFunc func(*sql.Selector, func(string) bool) string |
|
|
|
|
type AggregateFunc func(*sql.Selector) string |
|
|
|
|
|
|
|
|
|
// As is a pseudo aggregation function for renaming another other functions with custom names. For example:
|
|
|
|
|
//
|
|
|
|
|
@ -63,23 +101,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string
|
|
|
|
|
// Scan(ctx, &v)
|
|
|
|
|
//
|
|
|
|
|
func As(fn AggregateFunc, end string) AggregateFunc { |
|
|
|
|
return func(s *sql.Selector, check func(string) bool) string { |
|
|
|
|
return sql.As(fn(s, check), end) |
|
|
|
|
return func(s *sql.Selector) string { |
|
|
|
|
return sql.As(fn(s), end) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Count applies the "count" aggregation function on each group.
|
|
|
|
|
func Count() AggregateFunc { |
|
|
|
|
return func(s *sql.Selector, _ func(string) bool) string { |
|
|
|
|
return func(s *sql.Selector) string { |
|
|
|
|
return sql.Count("*") |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Max applies the "max" aggregation function on the given field of each group.
|
|
|
|
|
func Max(field string) AggregateFunc { |
|
|
|
|
return func(s *sql.Selector, check func(string) bool) string { |
|
|
|
|
if !check(field) { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) |
|
|
|
|
return func(s *sql.Selector) string { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
if err := check(field); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
return "" |
|
|
|
|
} |
|
|
|
|
return sql.Max(s.C(field)) |
|
|
|
|
@ -88,9 +127,10 @@ func Max(field string) AggregateFunc {
|
|
|
|
|
|
|
|
|
|
// Mean applies the "mean" aggregation function on the given field of each group.
|
|
|
|
|
func Mean(field string) AggregateFunc { |
|
|
|
|
return func(s *sql.Selector, check func(string) bool) string { |
|
|
|
|
if !check(field) { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) |
|
|
|
|
return func(s *sql.Selector) string { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
if err := check(field); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
return "" |
|
|
|
|
} |
|
|
|
|
return sql.Avg(s.C(field)) |
|
|
|
|
@ -99,9 +139,10 @@ func Mean(field string) AggregateFunc {
|
|
|
|
|
|
|
|
|
|
// Min applies the "min" aggregation function on the given field of each group.
|
|
|
|
|
func Min(field string) AggregateFunc { |
|
|
|
|
return func(s *sql.Selector, check func(string) bool) string { |
|
|
|
|
if !check(field) { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) |
|
|
|
|
return func(s *sql.Selector) string { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
if err := check(field); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
return "" |
|
|
|
|
} |
|
|
|
|
return sql.Min(s.C(field)) |
|
|
|
|
@ -110,9 +151,10 @@ func Min(field string) AggregateFunc {
|
|
|
|
|
|
|
|
|
|
// Sum applies the "sum" aggregation function on the given field of each group.
|
|
|
|
|
func Sum(field string) AggregateFunc { |
|
|
|
|
return func(s *sql.Selector, check func(string) bool) string { |
|
|
|
|
if !check(field) { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) |
|
|
|
|
return func(s *sql.Selector) string { |
|
|
|
|
check := columnChecker(s.TableName()) |
|
|
|
|
if err := check(field); err != nil { |
|
|
|
|
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) |
|
|
|
|
return "" |
|
|
|
|
} |
|
|
|
|
return sql.Sum(s.C(field)) |
|
|
|
|
|