Skip to content

Commit

Permalink
fix: create migration table if needed for pending
Browse files Browse the repository at this point in the history
  • Loading branch information
cking committed Feb 17, 2023
1 parent f15b7e6 commit ad16151
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 17 deletions.
31 changes: 16 additions & 15 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,8 @@ func New(opts ...Option) (*Migrator, error) {

// Migrate applies all available migrations
func (m *Migrator) Migrate(ctx context.Context, db Conn) error {
// create migrations table if doesn't exist
_, err := db.Exec(ctx, fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id INT8 NOT NULL,
version VARCHAR(255) NOT NULL,
PRIMARY KEY (id)
);
`, m.tableName))
if err != nil {
return err
}

// count applied migrations
count, err := countApplied(ctx, db, m.tableName)
count, err := m.countApplied(ctx, db, m.tableName)
if err != nil {
return err
}
Expand All @@ -97,6 +85,8 @@ func (m *Migrator) Migrate(ctx context.Context, db Conn) error {
return errors.New("migrator: applied migration number on db cannot be greater than the defined migration list")
}

m.logger.Log("Running missing migrations...", map[string]any{"missing": len(m.migrations) - count})

// plan migrations
for idx, migration := range m.migrations[count:] {
tx, err := db.Begin(ctx)
Expand All @@ -123,14 +113,25 @@ func (m *Migrator) Migrate(ctx context.Context, db Conn) error {

// Pending returns all pending (not yet applied) migrations
func (m *Migrator) Pending(ctx context.Context, db Conn) ([]Migration, error) {
count, err := countApplied(ctx, db, m.tableName)
count, err := m.countApplied(ctx, db, m.tableName)
if err != nil {
return nil, err
}
return m.migrations[count:len(m.migrations)], nil
}

func countApplied(ctx context.Context, db Conn, tableName string) (int, error) {
func (m *Migrator) countApplied(ctx context.Context, db Conn, tableName string) (int, error) {
// create migrations table if doesn't exist
if _, err := db.Exec(ctx, fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id INT8 NOT NULL,
version VARCHAR(255) NOT NULL,
PRIMARY KEY (id)
);
`, m.tableName)); err != nil {
return 0, err
}

// count applied migrations
var count int
row := db.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", tableName))
Expand Down
105 changes: 103 additions & 2 deletions migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
_ "embed"
"errors"
"fmt"
"os"
"testing"

Expand All @@ -25,7 +26,17 @@ var migrations = []mgx.Migration{
}),
}

func TestMigrate(t *testing.T) {
var _ mgx.Logger = (*TestLogger)(nil)

type TestLogger struct {
logged bool
}

func (t *TestLogger) Log(_ string, _ map[string]any) {
t.logged = true
}

func connectToDatabase(t *testing.T) *pgx.Conn {
// create db connection
url := os.Getenv("POSTGRES")
if url == "" {
Expand All @@ -36,7 +47,21 @@ func TestMigrate(t *testing.T) {
if err != nil {
t.Fatal(err)
}
// TODO: replace with your db connection

if _, err := db.Exec(context.Background(), "DROP SCHEMA public CASCADE"); err != nil {
t.Fatal(err)
}
if _, err := db.Exec(context.Background(), "CREATE SCHEMA public"); err != nil {
t.Fatal(err)
}
return db
}

func TestMigrate(t *testing.T) {
db := connectToDatabase(t)
defer func(db *pgx.Conn, ctx context.Context) {
_ = db.Close(ctx)
}(db, context.Background())

// create migrator
migrator, err := mgx.New(mgx.Migrations(migrations...))
Expand All @@ -50,3 +75,79 @@ func TestMigrate(t *testing.T) {
t.Fatal(err)
}
}

func TestMigrateWithCustomLogger(t *testing.T) {
db := connectToDatabase(t)
defer func(db *pgx.Conn, ctx context.Context) {
_ = db.Close(ctx)
}(db, context.Background())

l := new(TestLogger)

// create migrator
migrator, err := mgx.New(mgx.Log(l))
if err != nil {
t.Fatal(err)
}

// run migrator
err = migrator.Migrate(context.Background(), db)
if err != nil {
t.Fatal(err)
}

if !l.logged {
t.Fatal("custom logger was not called")
}
}

func TestMigrateWithCustomTableName(t *testing.T) {
db := connectToDatabase(t)
defer func(db *pgx.Conn, ctx context.Context) {
_ = db.Close(ctx)
}(db, context.Background())

// create migrator
tableName := "custom_table_name"
migrator, err := mgx.New(mgx.TableName(tableName))
if err != nil {
t.Fatal(err)
}

// run migrator
err = migrator.Migrate(context.Background(), db)
if err != nil {
t.Fatal(err)
}

// check if table exists
var rows int
if err := db.QueryRow(
context.Background(),
"SELECT COUNT(*) FROM "+tableName,
).Scan(&rows); err != nil {
t.Fatal(err)
}
}

func TestPending(t *testing.T) {
db := connectToDatabase(t)
defer func(db *pgx.Conn, ctx context.Context) {
_ = db.Close(ctx)
}(db, context.Background())

// create migrator
migrator, err := mgx.New(mgx.Migrations(migrations...))
if err != nil {
t.Fatal(err)
}

pending, err := migrator.Pending(context.Background(), db)
if err == nil && len(pending) != len(migrations) {
err = fmt.Errorf("there should be %d pending migrations, only %d found", len(migrations), len(pending))
}

if err != nil {
t.Fatal(err)
}
}

0 comments on commit ad16151

Please sign in to comment.