mirror of https://github.com/dexidp/dex.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
475 lines
10 KiB
475 lines
10 KiB
package migrate |
|
|
|
import ( |
|
"bytes" |
|
"database/sql" |
|
"errors" |
|
"fmt" |
|
"io" |
|
"os" |
|
"path" |
|
"regexp" |
|
"sort" |
|
"strconv" |
|
"strings" |
|
"time" |
|
|
|
"github.com/rubenv/sql-migrate/sqlparse" |
|
"gopkg.in/gorp.v1" |
|
) |
|
|
|
// MigrationDirection selects whether migrations are applied (Up) or
// reverted (Down).
type MigrationDirection int

const (
	// Up applies migrations, running each migration's Up statements.
	Up MigrationDirection = 0
	// Down reverts migrations, running each migration's Down statements.
	Down MigrationDirection = 1
)
|
|
|
var (
	// tableName is the table that records applied migrations (see SetTable).
	tableName = "gorp_migrations"

	// schemaName, when non-empty, qualifies tableName (see SetSchema).
	schemaName = ""

	// numberPrefixRegex captures the leading run of decimal digits of a
	// migration Id; Ids without such a prefix do not match.
	numberPrefixRegex = regexp.MustCompile(`^(\d+).*$`)
)
|
|
|
// Set the name of the table used to store migration info. |
|
// |
|
// Should be called before any other call such as (Exec, ExecMax, ...). |
|
func SetTable(name string) { |
|
if name != "" { |
|
tableName = name |
|
} |
|
} |
|
|
|
// SetSchema sets the name of a schema that the migration table be referenced. |
|
func SetSchema(name string) { |
|
if name != "" { |
|
schemaName = name |
|
} |
|
} |
|
|
|
func getTableName() string { |
|
t := tableName |
|
if schemaName != "" { |
|
t = fmt.Sprintf("%s.%s", schemaName, t) |
|
} |
|
|
|
return t |
|
} |
|
|
|
// Migration is a single schema change, identified by Id, with the SQL
// statements that apply it (Up) and revert it (Down).
type Migration struct {
	// Id identifies the migration; when it starts with digits, that numeric
	// prefix determines ordering (see Less).
	Id string

	// Up holds the statements executed when migrating up.
	Up []string

	// Down holds the statements executed when migrating down.
	Down []string
}
|
|
|
func (m Migration) Less(other *Migration) bool { |
|
switch { |
|
case m.isNumeric() && other.isNumeric(): |
|
return m.VersionInt() < other.VersionInt() |
|
case m.isNumeric() && !other.isNumeric(): |
|
return true |
|
case !m.isNumeric() && other.isNumeric(): |
|
return false |
|
default: |
|
return m.Id < other.Id |
|
} |
|
} |
|
|
|
func (m Migration) isNumeric() bool { |
|
return len(m.NumberPrefixMatches()) > 0 |
|
} |
|
|
|
func (m Migration) NumberPrefixMatches() []string { |
|
return numberPrefixRegex.FindStringSubmatch(m.Id) |
|
} |
|
|
|
func (m Migration) VersionInt() int64 { |
|
v := m.NumberPrefixMatches()[1] |
|
value, err := strconv.ParseInt(v, 10, 64) |
|
if err != nil { |
|
panic(fmt.Sprintf("Could not parse %q into int64: %s", v, err)) |
|
} |
|
return value |
|
} |
|
|
|
type PlannedMigration struct { |
|
*Migration |
|
Queries []string |
|
} |
|
|
|
type byId []*Migration |
|
|
|
func (b byId) Len() int { return len(b) } |
|
func (b byId) Swap(i, j int) { b[i], b[j] = b[j], b[i] } |
|
func (b byId) Less(i, j int) bool { return b[i].Less(b[j]) } |
|
|
|
// MigrationRecord is a row of the migration table: the Id of an applied
// migration and the time it was recorded.
type MigrationRecord struct {
	// Id is stored in the "id" column.
	Id string `db:"id"`

	// AppliedAt is stored in the "applied_at" column.
	AppliedAt time.Time `db:"applied_at"`
}
|
|
|
var MigrationDialects = map[string]gorp.Dialect{ |
|
"sqlite3": gorp.SqliteDialect{}, |
|
"postgres": gorp.PostgresDialect{}, |
|
"mysql": gorp.MySQLDialect{"InnoDB", "UTF8"}, |
|
"mssql": gorp.SqlServerDialect{}, |
|
"oci8": gorp.OracleDialect{}, |
|
} |
|
|
|
type MigrationSource interface { |
|
// Finds the migrations. |
|
// |
|
// The resulting slice of migrations should be sorted by Id. |
|
FindMigrations() ([]*Migration, error) |
|
} |
|
|
|
// A hardcoded set of migrations, in-memory. |
|
type MemoryMigrationSource struct { |
|
Migrations []*Migration |
|
} |
|
|
|
var _ MigrationSource = (*MemoryMigrationSource)(nil) |
|
|
|
func (m MemoryMigrationSource) FindMigrations() ([]*Migration, error) { |
|
// Make sure migrations are sorted |
|
sort.Sort(byId(m.Migrations)) |
|
|
|
return m.Migrations, nil |
|
} |
|
|
|
// A set of migrations loaded from a directory. |
|
type FileMigrationSource struct { |
|
Dir string |
|
} |
|
|
|
var _ MigrationSource = (*FileMigrationSource)(nil) |
|
|
|
func (f FileMigrationSource) FindMigrations() ([]*Migration, error) { |
|
migrations := make([]*Migration, 0) |
|
|
|
file, err := os.Open(f.Dir) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
files, err := file.Readdir(0) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
for _, info := range files { |
|
if strings.HasSuffix(info.Name(), ".sql") { |
|
file, err := os.Open(path.Join(f.Dir, info.Name())) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
migration, err := ParseMigration(info.Name(), file) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
migrations = append(migrations, migration) |
|
} |
|
} |
|
|
|
// Make sure migrations are sorted |
|
sort.Sort(byId(migrations)) |
|
|
|
return migrations, nil |
|
} |
|
|
|
// Migrations from a bindata asset set. |
|
type AssetMigrationSource struct { |
|
// Asset should return content of file in path if exists |
|
Asset func(path string) ([]byte, error) |
|
|
|
// AssetDir should return list of files in the path |
|
AssetDir func(path string) ([]string, error) |
|
|
|
// Path in the bindata to use. |
|
Dir string |
|
} |
|
|
|
var _ MigrationSource = (*AssetMigrationSource)(nil) |
|
|
|
func (a AssetMigrationSource) FindMigrations() ([]*Migration, error) { |
|
migrations := make([]*Migration, 0) |
|
|
|
files, err := a.AssetDir(a.Dir) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
for _, name := range files { |
|
if strings.HasSuffix(name, ".sql") { |
|
file, err := a.Asset(path.Join(a.Dir, name)) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
migration, err := ParseMigration(name, bytes.NewReader(file)) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
migrations = append(migrations, migration) |
|
} |
|
} |
|
|
|
// Make sure migrations are sorted |
|
sort.Sort(byId(migrations)) |
|
|
|
return migrations, nil |
|
} |
|
|
|
// Migration parsing |
|
func ParseMigration(id string, r io.ReadSeeker) (*Migration, error) { |
|
m := &Migration{ |
|
Id: id, |
|
} |
|
|
|
up, err := sqlparse.SplitSQLStatements(r, true) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
down, err := sqlparse.SplitSQLStatements(r, false) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
m.Up = up |
|
m.Down = down |
|
|
|
return m, nil |
|
} |
|
|
|
// Execute a set of migrations |
|
// |
|
// Returns the number of applied migrations. |
|
func Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) { |
|
return ExecMax(db, dialect, m, dir, 0) |
|
} |
|
|
|
// Execute a set of migrations |
|
// |
|
// Will apply at most `max` migrations. Pass 0 for no limit (or use Exec). |
|
// |
|
// Returns the number of applied migrations. |
|
func ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) { |
|
migrations, dbMap, err := PlanMigration(db, dialect, m, dir, max) |
|
if err != nil { |
|
return 0, err |
|
} |
|
|
|
// Apply migrations |
|
applied := 0 |
|
for _, migration := range migrations { |
|
trans, err := dbMap.Begin() |
|
if err != nil { |
|
return applied, err |
|
} |
|
|
|
for _, stmt := range migration.Queries { |
|
_, err := trans.Exec(stmt) |
|
if err != nil { |
|
trans.Rollback() |
|
return applied, err |
|
} |
|
} |
|
|
|
if dir == Up { |
|
err = trans.Insert(&MigrationRecord{ |
|
Id: migration.Id, |
|
AppliedAt: time.Now(), |
|
}) |
|
if err != nil { |
|
return applied, err |
|
} |
|
} else if dir == Down { |
|
_, err := trans.Delete(&MigrationRecord{ |
|
Id: migration.Id, |
|
}) |
|
if err != nil { |
|
return applied, err |
|
} |
|
} else { |
|
panic("Not possible") |
|
} |
|
|
|
err = trans.Commit() |
|
if err != nil { |
|
return applied, err |
|
} |
|
|
|
applied++ |
|
} |
|
|
|
return applied, nil |
|
} |
|
|
|
// Plan a migration. |
|
func PlanMigration(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) ([]*PlannedMigration, *gorp.DbMap, error) { |
|
dbMap, err := getMigrationDbMap(db, dialect) |
|
if err != nil { |
|
return nil, nil, err |
|
} |
|
|
|
migrations, err := m.FindMigrations() |
|
if err != nil { |
|
return nil, nil, err |
|
} |
|
|
|
var migrationRecords []MigrationRecord |
|
_, err = dbMap.Select(&migrationRecords, fmt.Sprintf("SELECT * FROM %s", getTableName())) |
|
if err != nil { |
|
return nil, nil, err |
|
} |
|
|
|
// Sort migrations that have been run by Id. |
|
var existingMigrations []*Migration |
|
for _, migrationRecord := range migrationRecords { |
|
existingMigrations = append(existingMigrations, &Migration{ |
|
Id: migrationRecord.Id, |
|
}) |
|
} |
|
sort.Sort(byId(existingMigrations)) |
|
|
|
// Get last migration that was run |
|
record := &Migration{} |
|
if len(existingMigrations) > 0 { |
|
record = existingMigrations[len(existingMigrations)-1] |
|
} |
|
|
|
result := make([]*PlannedMigration, 0) |
|
|
|
// Add missing migrations up to the last run migration. |
|
// This can happen for example when merges happened. |
|
if len(existingMigrations) > 0 { |
|
result = append(result, ToCatchup(migrations, existingMigrations, record)...) |
|
} |
|
|
|
// Figure out which migrations to apply |
|
toApply := ToApply(migrations, record.Id, dir) |
|
toApplyCount := len(toApply) |
|
if max > 0 && max < toApplyCount { |
|
toApplyCount = max |
|
} |
|
for _, v := range toApply[0:toApplyCount] { |
|
|
|
if dir == Up { |
|
result = append(result, &PlannedMigration{ |
|
Migration: v, |
|
Queries: v.Up, |
|
}) |
|
} else if dir == Down { |
|
result = append(result, &PlannedMigration{ |
|
Migration: v, |
|
Queries: v.Down, |
|
}) |
|
} |
|
} |
|
|
|
return result, dbMap, nil |
|
} |
|
|
|
// Filter a slice of migrations into ones that should be applied. |
|
func ToApply(migrations []*Migration, current string, direction MigrationDirection) []*Migration { |
|
var index = -1 |
|
if current != "" { |
|
for index < len(migrations)-1 { |
|
index++ |
|
if migrations[index].Id == current { |
|
break |
|
} |
|
} |
|
} |
|
|
|
if direction == Up { |
|
return migrations[index+1:] |
|
} else if direction == Down { |
|
if index == -1 { |
|
return []*Migration{} |
|
} |
|
|
|
// Add in reverse order |
|
toApply := make([]*Migration, index+1) |
|
for i := 0; i < index+1; i++ { |
|
toApply[index-i] = migrations[i] |
|
} |
|
return toApply |
|
} |
|
|
|
panic("Not possible") |
|
} |
|
|
|
func ToCatchup(migrations, existingMigrations []*Migration, lastRun *Migration) []*PlannedMigration { |
|
missing := make([]*PlannedMigration, 0) |
|
for _, migration := range migrations { |
|
found := false |
|
for _, existing := range existingMigrations { |
|
if existing.Id == migration.Id { |
|
found = true |
|
break |
|
} |
|
} |
|
if !found && migration.Less(lastRun) { |
|
missing = append(missing, &PlannedMigration{Migration: migration, Queries: migration.Up}) |
|
} |
|
} |
|
return missing |
|
} |
|
|
|
func GetMigrationRecords(db *sql.DB, dialect string) ([]*MigrationRecord, error) { |
|
dbMap, err := getMigrationDbMap(db, dialect) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
var records []*MigrationRecord |
|
query := fmt.Sprintf("SELECT * FROM %s ORDER BY id ASC", getTableName()) |
|
_, err = dbMap.Select(&records, query) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return records, nil |
|
} |
|
|
|
func getMigrationDbMap(db *sql.DB, dialect string) (*gorp.DbMap, error) { |
|
d, ok := MigrationDialects[dialect] |
|
if !ok { |
|
return nil, fmt.Errorf("Unknown dialect: %s", dialect) |
|
} |
|
|
|
// When using the mysql driver, make sure that the parseTime option is |
|
// configured, otherwise it won't map time columns to time.Time. See |
|
// https://github.com/rubenv/sql-migrate/issues/2 |
|
if dialect == "mysql" { |
|
var out *time.Time |
|
err := db.QueryRow("SELECT NOW()").Scan(&out) |
|
if err != nil { |
|
if err.Error() == "sql: Scan error on column index 0: unsupported driver -> Scan pair: []uint8 -> *time.Time" { |
|
return nil, errors.New(`Cannot parse dates. |
|
|
|
Make sure that the parseTime option is supplied to your database connection. |
|
Check https://github.com/go-sql-driver/mysql#parsetime for more info.`) |
|
} else { |
|
return nil, err |
|
} |
|
} |
|
} |
|
|
|
// Create migration database map |
|
dbMap := &gorp.DbMap{Db: db, Dialect: d} |
|
dbMap.AddTableWithNameAndSchema(MigrationRecord{}, schemaName, tableName).SetKeys(false, "Id") |
|
//dbMap.TraceOn("", log.New(os.Stdout, "migrate: ", log.Lmicroseconds)) |
|
|
|
err := dbMap.CreateTablesIfNotExists() |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return dbMap, nil |
|
} |
|
|
|
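// NOTE: ExecMax already runs each migration's statements together with its
// record insert/delete inside a single transaction, so the former TODO
// ("Run migration + record insert in transaction") appears to be resolved.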
|
|
|