mirror of https://github.com/dexidp/dex.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
223 lines
5.6 KiB
223 lines
5.6 KiB
package db |
|
|
|
import ( |
|
"fmt" |
|
"os" |
|
"sort" |
|
"strconv" |
|
"testing" |
|
|
|
"github.com/go-gorp/gorp" |
|
"github.com/kylelemons/godebug/pretty" |
|
) |
|
|
|
func initDB(dsn string) *gorp.DbMap { |
|
c, err := NewConnection(Config{DSN: dsn}) |
|
if err != nil { |
|
panic(fmt.Sprintf("error making db connection: %q", err)) |
|
} |
|
if _, err := c.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", migrationTable)); err != nil { |
|
panic(fmt.Sprintf("failed to drop migration table: %v", err)) |
|
} |
|
if err = c.DropTablesIfExists(); err != nil { |
|
panic(fmt.Sprintf("Unable to drop database tables: %v", err)) |
|
} |
|
|
|
return c |
|
} |
|
|
|
// TestGetPlannedMigrations is a sanity check, ensuring that at least one |
|
// migration can be found. |
|
func TestGetPlannedMigrations(t *testing.T) { |
|
dsn := os.Getenv("DEX_TEST_DSN") |
|
if dsn == "" { |
|
t.Skip("Test will not run without DEX_TEST_DSN environment variable.") |
|
return |
|
} |
|
dbMap := initDB(dsn) |
|
ms, err := GetPlannedMigrations(dbMap) |
|
if err != nil { |
|
pwd, err := os.Getwd() |
|
t.Logf("pwd: %v", pwd) |
|
t.Fatalf("unexpected err: %q", err) |
|
} |
|
|
|
if len(ms) == 0 { |
|
t.Fatalf("expected non-empty migrations") |
|
} |
|
} |
|
|
|
func TestMigrateClientMetadata(t *testing.T) { |
|
// oldClientModel exists to model what the client model looked like at |
|
// migration time. Without using this, the test fails because there's no |
|
// columns for the new fields. |
|
type oldClientModel struct { |
|
ID string `db:"id"` |
|
Secret []byte `db:"secret"` |
|
Metadata string `db:"metadata"` |
|
DexAdmin bool `db:"dex_admin"` |
|
} |
|
register(table{ |
|
name: clientTableName, |
|
model: oldClientModel{}, |
|
autoinc: false, |
|
pkey: []string{"id"}, |
|
}) |
|
|
|
dsn := os.Getenv("DEX_TEST_DSN") |
|
if dsn == "" { |
|
t.Skip("Test will not run without DEX_TEST_DSN environment variable.") |
|
return |
|
} |
|
dbMap := initDB(dsn) |
|
|
|
nMigrations := 9 |
|
n, err := MigrateMaxMigrations(dbMap, nMigrations) |
|
if err != nil { |
|
t.Fatalf("failed to perform initial migration: %v", err) |
|
} |
|
if n != nMigrations { |
|
t.Fatalf("expected to perform %d migrations, got %d", nMigrations, n) |
|
} |
|
|
|
tests := []struct { |
|
before string |
|
after string |
|
}{ |
|
// only update rows without a "redirect_uris" key |
|
{ |
|
`{"redirectURLs":["foo"]}`, |
|
`{"redirectURLs" : ["foo"], "redirect_uris" : ["foo"]}`, |
|
}, |
|
{ |
|
`{"redirectURLs":["foo","bar"]}`, |
|
`{"redirectURLs" : ["foo","bar"], "redirect_uris" : ["foo","bar"]}`, |
|
}, |
|
{ |
|
`{"redirect_uris":["foo"],"another_field":8}`, |
|
`{"redirect_uris":["foo"],"another_field":8}`, |
|
}, |
|
{ |
|
`{"redirectURLs" : ["foo"], "redirect_uris" : ["foo"]}`, |
|
`{"redirectURLs" : ["foo"], "redirect_uris" : ["foo"]}`, |
|
}, |
|
} |
|
|
|
for i, tt := range tests { |
|
model := &oldClientModel{ |
|
ID: strconv.Itoa(i), |
|
Secret: []byte("verysecret"), |
|
Metadata: tt.before, |
|
} |
|
if err := dbMap.Insert(model); err != nil { |
|
t.Fatalf("could not insert model: %v", err) |
|
} |
|
} |
|
|
|
n, err = MigrateMaxMigrations(dbMap, 1) |
|
if err != nil { |
|
t.Fatalf("failed to perform initial migration: %v", err) |
|
} |
|
if n != 1 { |
|
t.Fatalf("expected to perform 1 migration, got %d", n) |
|
} |
|
|
|
for i, tt := range tests { |
|
id := strconv.Itoa(i) |
|
m, err := dbMap.Get(oldClientModel{}, id) |
|
if err != nil { |
|
t.Errorf("case %d: failed to get model: %v", i, err) |
|
continue |
|
} |
|
cim, ok := m.(*oldClientModel) |
|
if !ok { |
|
t.Errorf("case %d: unrecognized model type: %T", i, m) |
|
continue |
|
} |
|
|
|
if cim.Metadata != tt.after { |
|
t.Errorf("case %d: want=%q, got=%q", i, tt.after, cim.Metadata) |
|
} |
|
} |
|
} |
|
|
|
func TestMigrationNumber11(t *testing.T) { |
|
dsn := os.Getenv("DEX_TEST_DSN") |
|
if dsn == "" { |
|
t.Skip("Test will not run without DEX_TEST_DSN environment variable.") |
|
return |
|
} |
|
|
|
tests := []struct { |
|
sqlStmt string |
|
wantEmails []string |
|
wantError bool |
|
}{ |
|
{ |
|
sqlStmt: `INSERT INTO authd_user |
|
(id, email, email_verified, display_name, admin, created_at) |
|
VALUES |
|
(1, 'Foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())), |
|
(2, 'Bar@example.com', TRUE, 'foo', FALSE, extract(epoch from now())) |
|
;`, |
|
wantEmails: []string{"foo@example.com", "bar@example.com"}, |
|
wantError: false, |
|
}, |
|
{ |
|
sqlStmt: `INSERT INTO authd_user |
|
(id, email, email_verified, display_name, admin, created_at) |
|
VALUES |
|
(1, 'Foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())), |
|
(2, 'foo@example.com', TRUE, 'foo', FALSE, extract(epoch from now())), |
|
(3, 'bar@example.com', TRUE, 'foo', FALSE, extract(epoch from now())) |
|
;`, |
|
wantError: true, |
|
}, |
|
} |
|
migrateN := func(dbMap *gorp.DbMap, n int) error { |
|
nPerformed, err := MigrateMaxMigrations(dbMap, n) |
|
if err == nil && n != nPerformed { |
|
err = fmt.Errorf("expected to perform %d migrations, performed %d", n, nPerformed) |
|
} |
|
return err |
|
} |
|
|
|
for i, tt := range tests { |
|
err := func() error { |
|
dbMap := initDB(dsn) |
|
|
|
nMigrations := 10 |
|
if err := migrateN(dbMap, nMigrations); err != nil { |
|
return fmt.Errorf("failed to perform initial migration: %v", err) |
|
} |
|
if _, err := dbMap.Exec(tt.sqlStmt); err != nil { |
|
return fmt.Errorf("failed to insert users: %v", err) |
|
} |
|
if err := migrateN(dbMap, 1); err != nil { |
|
if tt.wantError { |
|
return nil |
|
} |
|
return fmt.Errorf("failed to perform migration: %v", err) |
|
} |
|
|
|
if tt.wantError { |
|
return fmt.Errorf("expected an error when migrating") |
|
} |
|
|
|
var gotEmails []string |
|
if _, err := dbMap.Select(&gotEmails, `SELECT email FROM authd_user;`); err != nil { |
|
return fmt.Errorf("could not get user emails: %v", err) |
|
} |
|
|
|
sort.Strings(tt.wantEmails) |
|
sort.Strings(gotEmails) |
|
if diff := pretty.Compare(tt.wantEmails, gotEmails); diff != "" { |
|
return fmt.Errorf("wantEmails != gotEmails: %s", diff) |
|
} |
|
return nil |
|
}() |
|
if err != nil { |
|
t.Errorf("case %d: %v", i, err) |
|
} |
|
} |
|
}
|
|
|