diff --git a/migrator.go b/migrator.go index 8c22d5a..94fa0f2 100644 --- a/migrator.go +++ b/migrator.go @@ -75,20 +75,8 @@ func New(opts ...Option) (*Migrator, error) { // Migrate applies all available migrations func (m *Migrator) Migrate(ctx context.Context, db Conn) error { - // create migrations table if doesn't exist - _, err := db.Exec(ctx, fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s ( - id INT8 NOT NULL, - version VARCHAR(255) NOT NULL, - PRIMARY KEY (id) - ); - `, m.tableName)) - if err != nil { - return err - } - // count applied migrations - count, err := countApplied(ctx, db, m.tableName) + count, err := m.countApplied(ctx, db, m.tableName) if err != nil { return err } @@ -97,6 +85,8 @@ func (m *Migrator) Migrate(ctx context.Context, db Conn) error { return errors.New("migrator: applied migration number on db cannot be greater than the defined migration list") } + m.logger.Log("Running missing migrations...", map[string]any{"missing": len(m.migrations) - count}) + // plan migrations for idx, migration := range m.migrations[count:] { tx, err := db.Begin(ctx) @@ -123,14 +113,25 @@ func (m *Migrator) Migrate(ctx context.Context, db Conn) error { // Pending returns all pending (not yet applied) migrations func (m *Migrator) Pending(ctx context.Context, db Conn) ([]Migration, error) { - count, err := countApplied(ctx, db, m.tableName) + count, err := m.countApplied(ctx, db, m.tableName) if err != nil { return nil, err } return m.migrations[count:len(m.migrations)], nil } -func countApplied(ctx context.Context, db Conn, tableName string) (int, error) { +func (m *Migrator) countApplied(ctx context.Context, db Conn, tableName string) (int, error) { + // create migrations table if doesn't exist + if _, err := db.Exec(ctx, fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id INT8 NOT NULL, + version VARCHAR(255) NOT NULL, + PRIMARY KEY (id) + ); + `, m.tableName)); err != nil { + return 0, err + } + // count applied migrations var count int row := db.QueryRow(ctx, fmt.Sprintf("SELECT count(*) FROM %s", tableName)) diff --git a/migrator_test.go b/migrator_test.go index 8df5148..6a36fcb 100644 --- a/migrator_test.go +++ b/migrator_test.go @@ -4,6 +4,7 @@ import ( "context" _ "embed" "errors" + "fmt" "os" "testing" @@ -25,7 +26,17 @@ var migrations = []mgx.Migration{ }), } -func TestMigrate(t *testing.T) { +var _ mgx.Logger = (*TestLogger)(nil) + +type TestLogger struct { + logged bool +} + +func (t *TestLogger) Log(_ string, _ map[string]any) { + t.logged = true +} + +func connectToDatabase(t *testing.T) *pgx.Conn { // create db connection url := os.Getenv("POSTGRES") if url == "" { @@ -36,7 +47,21 @@ func TestMigrate(t *testing.T) { if err != nil { t.Fatal(err) } - // TODO: replace with your db connection + + if _, err := db.Exec(context.Background(), "DROP SCHEMA public CASCADE"); err != nil { + t.Fatal(err) + } + if _, err := db.Exec(context.Background(), "CREATE SCHEMA public"); err != nil { + t.Fatal(err) + } + return db +} + +func TestMigrate(t *testing.T) { + db := connectToDatabase(t) + defer func(db *pgx.Conn, ctx context.Context) { + _ = db.Close(ctx) + }(db, context.Background()) // create migrator migrator, err := mgx.New(mgx.Migrations(migrations...)) @@ -50,3 +75,79 @@ func TestMigrate(t *testing.T) { t.Fatal(err) } } + +func TestMigrateWithCustomLogger(t *testing.T) { + db := connectToDatabase(t) + defer func(db *pgx.Conn, ctx context.Context) { + _ = db.Close(ctx) + }(db, context.Background()) + + l := new(TestLogger) + + // create migrator + migrator, err := mgx.New(mgx.Log(l)) + if err != nil { + t.Fatal(err) + } + + // run migrator + err = migrator.Migrate(context.Background(), db) + if err != nil { + t.Fatal(err) + } + + if !l.logged { + t.Fatal("custom logger was not called") + } +} + +func TestMigrateWithCustomTableName(t *testing.T) { + db := connectToDatabase(t) + defer func(db *pgx.Conn, ctx context.Context) { + _ = db.Close(ctx) + }(db, context.Background()) + + // create migrator + tableName := "custom_table_name" + migrator, err := mgx.New(mgx.TableName(tableName)) + if err != nil { + t.Fatal(err) + } + + // run migrator + err = migrator.Migrate(context.Background(), db) + if err != nil { + t.Fatal(err) + } + + // check if table exists + var rows int + if err := db.QueryRow( + context.Background(), + "SELECT COUNT(*) FROM "+tableName, + ).Scan(&rows); err != nil { + t.Fatal(err) + } +} + +func TestPending(t *testing.T) { + db := connectToDatabase(t) + defer func(db *pgx.Conn, ctx context.Context) { + _ = db.Close(ctx) + }(db, context.Background()) + + // create migrator + migrator, err := mgx.New(mgx.Migrations(migrations...)) + if err != nil { + t.Fatal(err) + } + + pending, err := migrator.Pending(context.Background(), db) + if err == nil && len(pending) != len(migrations) { + err = fmt.Errorf("there should be %d pending migrations, only %d found", len(migrations), len(pending)) + } + + if err != nil { + t.Fatal(err) + } +}