You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
448 lines
9.9 KiB
448 lines
9.9 KiB
package migrate |
|
|
|
import ( |
|
"context" |
|
"errors" |
|
"fmt" |
|
"os" |
|
"path/filepath" |
|
"regexp" |
|
"time" |
|
|
|
"github.com/uptrace/bun" |
|
) |
|
|
|
type MigratorOption func(m *Migrator) |
|
|
|
func WithTableName(table string) MigratorOption { |
|
return func(m *Migrator) { |
|
m.table = table |
|
} |
|
} |
|
|
|
func WithLocksTableName(table string) MigratorOption { |
|
return func(m *Migrator) { |
|
m.locksTable = table |
|
} |
|
} |
|
|
|
// WithMarkAppliedOnSuccess sets the migrator to only mark migrations as applied/unapplied |
|
// when their up/down is successful |
|
func WithMarkAppliedOnSuccess(enabled bool) MigratorOption { |
|
return func(m *Migrator) { |
|
m.markAppliedOnSuccess = enabled |
|
} |
|
} |
|
|
|
type Migrator struct { |
|
db *bun.DB |
|
migrations *Migrations |
|
|
|
ms MigrationSlice |
|
|
|
table string |
|
locksTable string |
|
markAppliedOnSuccess bool |
|
} |
|
|
|
func NewMigrator(db *bun.DB, migrations *Migrations, opts ...MigratorOption) *Migrator { |
|
m := &Migrator{ |
|
db: db, |
|
migrations: migrations, |
|
|
|
ms: migrations.ms, |
|
|
|
table: "bun_migrations", |
|
locksTable: "bun_migration_locks", |
|
} |
|
for _, opt := range opts { |
|
opt(m) |
|
} |
|
return m |
|
} |
|
|
|
func (m *Migrator) DB() *bun.DB { |
|
return m.db |
|
} |
|
|
|
// MigrationsWithStatus returns migrations with status in ascending order. |
|
func (m *Migrator) MigrationsWithStatus(ctx context.Context) (MigrationSlice, error) { |
|
sorted, _, err := m.migrationsWithStatus(ctx) |
|
return sorted, err |
|
} |
|
|
|
func (m *Migrator) migrationsWithStatus(ctx context.Context) (MigrationSlice, int64, error) { |
|
sorted := m.migrations.Sorted() |
|
|
|
applied, err := m.AppliedMigrations(ctx) |
|
if err != nil { |
|
return nil, 0, err |
|
} |
|
|
|
appliedMap := migrationMap(applied) |
|
for i := range sorted { |
|
m1 := &sorted[i] |
|
if m2, ok := appliedMap[m1.Name]; ok { |
|
m1.ID = m2.ID |
|
m1.GroupID = m2.GroupID |
|
m1.MigratedAt = m2.MigratedAt |
|
} |
|
} |
|
|
|
return sorted, applied.LastGroupID(), nil |
|
} |
|
|
|
func (m *Migrator) Init(ctx context.Context) error { |
|
if _, err := m.db.NewCreateTable(). |
|
Model((*Migration)(nil)). |
|
ModelTableExpr(m.table). |
|
IfNotExists(). |
|
Exec(ctx); err != nil { |
|
return err |
|
} |
|
if _, err := m.db.NewCreateTable(). |
|
Model((*migrationLock)(nil)). |
|
ModelTableExpr(m.locksTable). |
|
IfNotExists(). |
|
Exec(ctx); err != nil { |
|
return err |
|
} |
|
return nil |
|
} |
|
|
|
func (m *Migrator) Reset(ctx context.Context) error { |
|
if _, err := m.db.NewDropTable(). |
|
Model((*Migration)(nil)). |
|
ModelTableExpr(m.table). |
|
IfExists(). |
|
Exec(ctx); err != nil { |
|
return err |
|
} |
|
if _, err := m.db.NewDropTable(). |
|
Model((*migrationLock)(nil)). |
|
ModelTableExpr(m.locksTable). |
|
IfExists(). |
|
Exec(ctx); err != nil { |
|
return err |
|
} |
|
return m.Init(ctx) |
|
} |
|
|
|
// Migrate runs unapplied migrations. If a migration fails, migrate immediately exits. |
|
func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) { |
|
cfg := newMigrationConfig(opts) |
|
|
|
if err := m.validate(); err != nil { |
|
return nil, err |
|
} |
|
|
|
migrations, lastGroupID, err := m.migrationsWithStatus(ctx) |
|
if err != nil { |
|
return nil, err |
|
} |
|
migrations = migrations.Unapplied() |
|
|
|
group := new(MigrationGroup) |
|
if len(migrations) == 0 { |
|
return group, nil |
|
} |
|
group.ID = lastGroupID + 1 |
|
|
|
for i := range migrations { |
|
migration := &migrations[i] |
|
migration.GroupID = group.ID |
|
|
|
if !m.markAppliedOnSuccess { |
|
if err := m.MarkApplied(ctx, migration); err != nil { |
|
return group, err |
|
} |
|
} |
|
|
|
group.Migrations = migrations[:i+1] |
|
|
|
if !cfg.nop && migration.Up != nil { |
|
if err := migration.Up(ctx, m.db); err != nil { |
|
return group, err |
|
} |
|
} |
|
|
|
if m.markAppliedOnSuccess { |
|
if err := m.MarkApplied(ctx, migration); err != nil { |
|
return group, err |
|
} |
|
} |
|
} |
|
|
|
return group, nil |
|
} |
|
|
|
func (m *Migrator) Rollback(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) { |
|
cfg := newMigrationConfig(opts) |
|
|
|
if err := m.validate(); err != nil { |
|
return nil, err |
|
} |
|
|
|
migrations, err := m.MigrationsWithStatus(ctx) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
lastGroup := migrations.LastGroup() |
|
|
|
for i := len(lastGroup.Migrations) - 1; i >= 0; i-- { |
|
migration := &lastGroup.Migrations[i] |
|
|
|
if !m.markAppliedOnSuccess { |
|
if err := m.MarkUnapplied(ctx, migration); err != nil { |
|
return lastGroup, err |
|
} |
|
} |
|
|
|
if !cfg.nop && migration.Down != nil { |
|
if err := migration.Down(ctx, m.db); err != nil { |
|
return lastGroup, err |
|
} |
|
} |
|
|
|
if m.markAppliedOnSuccess { |
|
if err := m.MarkUnapplied(ctx, migration); err != nil { |
|
return lastGroup, err |
|
} |
|
} |
|
} |
|
|
|
return lastGroup, nil |
|
} |
|
|
|
type goMigrationConfig struct { |
|
packageName string |
|
goTemplate string |
|
} |
|
|
|
type GoMigrationOption func(cfg *goMigrationConfig) |
|
|
|
func WithPackageName(name string) GoMigrationOption { |
|
return func(cfg *goMigrationConfig) { |
|
cfg.packageName = name |
|
} |
|
} |
|
|
|
func WithGoTemplate(template string) GoMigrationOption { |
|
return func(cfg *goMigrationConfig) { |
|
cfg.goTemplate = template |
|
} |
|
} |
|
|
|
// CreateGoMigration creates a Go migration file. |
|
func (m *Migrator) CreateGoMigration( |
|
ctx context.Context, name string, opts ...GoMigrationOption, |
|
) (*MigrationFile, error) { |
|
cfg := &goMigrationConfig{ |
|
packageName: "migrations", |
|
goTemplate: goTemplate, |
|
} |
|
for _, opt := range opts { |
|
opt(cfg) |
|
} |
|
|
|
name, err := m.genMigrationName(name) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
fname := name + ".go" |
|
fpath := filepath.Join(m.migrations.getDirectory(), fname) |
|
content := fmt.Sprintf(cfg.goTemplate, cfg.packageName) |
|
|
|
if err := os.WriteFile(fpath, []byte(content), 0o644); err != nil { |
|
return nil, err |
|
} |
|
|
|
mf := &MigrationFile{ |
|
Name: fname, |
|
Path: fpath, |
|
Content: content, |
|
} |
|
return mf, nil |
|
} |
|
|
|
// CreateTxSQLMigration creates transactional up and down SQL migration files. |
|
func (m *Migrator) CreateTxSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) { |
|
name, err := m.genMigrationName(name) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
up, err := m.createSQL(ctx, name+".up.tx.sql", true) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
down, err := m.createSQL(ctx, name+".down.tx.sql", true) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return []*MigrationFile{up, down}, nil |
|
} |
|
|
|
// CreateSQLMigrations creates up and down SQL migration files. |
|
func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) { |
|
name, err := m.genMigrationName(name) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
up, err := m.createSQL(ctx, name+".up.sql", false) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
down, err := m.createSQL(ctx, name+".down.sql", false) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return []*MigrationFile{up, down}, nil |
|
} |
|
|
|
func (m *Migrator) createSQL(ctx context.Context, fname string, transactional bool) (*MigrationFile, error) { |
|
fpath := filepath.Join(m.migrations.getDirectory(), fname) |
|
|
|
template := sqlTemplate |
|
if transactional { |
|
template = transactionalSQLTemplate |
|
} |
|
|
|
if err := os.WriteFile(fpath, []byte(template), 0o644); err != nil { |
|
return nil, err |
|
} |
|
|
|
mf := &MigrationFile{ |
|
Name: fname, |
|
Path: fpath, |
|
Content: goTemplate, |
|
} |
|
return mf, nil |
|
} |
|
|
|
var nameRE = regexp.MustCompile(`^[0-9a-z_\-]+$`) |
|
|
|
func (m *Migrator) genMigrationName(name string) (string, error) { |
|
const timeFormat = "20060102150405" |
|
|
|
if name == "" { |
|
return "", errors.New("migrate: migration name can't be empty") |
|
} |
|
if !nameRE.MatchString(name) { |
|
return "", fmt.Errorf("migrate: invalid migration name: %q", name) |
|
} |
|
|
|
version := time.Now().UTC().Format(timeFormat) |
|
return fmt.Sprintf("%s_%s", version, name), nil |
|
} |
|
|
|
// MarkApplied marks the migration as applied (completed). |
|
func (m *Migrator) MarkApplied(ctx context.Context, migration *Migration) error { |
|
_, err := m.db.NewInsert().Model(migration). |
|
ModelTableExpr(m.table). |
|
Exec(ctx) |
|
return err |
|
} |
|
|
|
// MarkUnapplied marks the migration as unapplied (new). |
|
func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) error { |
|
_, err := m.db.NewDelete(). |
|
Model(migration). |
|
ModelTableExpr(m.table). |
|
Where("id = ?", migration.ID). |
|
Exec(ctx) |
|
return err |
|
} |
|
|
|
func (m *Migrator) TruncateTable(ctx context.Context) error { |
|
_, err := m.db.NewTruncateTable().TableExpr(m.table).Exec(ctx) |
|
return err |
|
} |
|
|
|
// MissingMigrations returns applied migrations that can no longer be found. |
|
func (m *Migrator) MissingMigrations(ctx context.Context) (MigrationSlice, error) { |
|
applied, err := m.AppliedMigrations(ctx) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
existing := migrationMap(m.migrations.ms) |
|
for i := len(applied) - 1; i >= 0; i-- { |
|
m := &applied[i] |
|
if _, ok := existing[m.Name]; ok { |
|
applied = append(applied[:i], applied[i+1:]...) |
|
} |
|
} |
|
|
|
return applied, nil |
|
} |
|
|
|
// AppliedMigrations selects applied (applied) migrations in descending order. |
|
func (m *Migrator) AppliedMigrations(ctx context.Context) (MigrationSlice, error) { |
|
var ms MigrationSlice |
|
if err := m.db.NewSelect(). |
|
ColumnExpr("*"). |
|
Model(&ms). |
|
ModelTableExpr(m.table). |
|
Scan(ctx); err != nil { |
|
return nil, err |
|
} |
|
return ms, nil |
|
} |
|
|
|
func (m *Migrator) formattedTableName(db *bun.DB) string { |
|
return db.Formatter().FormatQuery(m.table) |
|
} |
|
|
|
func (m *Migrator) validate() error { |
|
if len(m.ms) == 0 { |
|
return errors.New("migrate: there are no migrations") |
|
} |
|
return nil |
|
} |
|
|
|
//------------------------------------------------------------------------------ |
|
|
|
type migrationLock struct { |
|
ID int64 `bun:",pk,autoincrement"` |
|
TableName string `bun:",unique"` |
|
} |
|
|
|
func (m *Migrator) Lock(ctx context.Context) error { |
|
lock := &migrationLock{ |
|
TableName: m.formattedTableName(m.db), |
|
} |
|
if _, err := m.db.NewInsert(). |
|
Model(lock). |
|
ModelTableExpr(m.locksTable). |
|
Exec(ctx); err != nil { |
|
return fmt.Errorf("migrate: migrations table is already locked (%w)", err) |
|
} |
|
return nil |
|
} |
|
|
|
func (m *Migrator) Unlock(ctx context.Context) error { |
|
tableName := m.formattedTableName(m.db) |
|
_, err := m.db.NewDelete(). |
|
Model((*migrationLock)(nil)). |
|
ModelTableExpr(m.locksTable). |
|
Where("? = ?", bun.Ident("table_name"), tableName). |
|
Exec(ctx) |
|
return err |
|
} |
|
|
|
func migrationMap(ms MigrationSlice) map[string]*Migration { |
|
mp := make(map[string]*Migration) |
|
for i := range ms { |
|
m := &ms[i] |
|
mp[m.Name] = m |
|
} |
|
return mp |
|
}
|
|
|