diff --git a/.gitignore b/.gitignore index 6cd06de..7048a4e 100644 --- a/.gitignore +++ b/.gitignore @@ -25,5 +25,6 @@ testdata/tern.conf /tern /tmp +.vscode .idea/* dist diff --git a/go.sum b/go.sum index 2e24d5c..6a33c74 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,12 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg= +github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/main.go b/main.go index 8a89d9e..0727a41 100644 --- a/main.go +++ b/main.go @@ -95,12 +95,13 @@ type Config struct { } var cliOptions struct { - destinationVersion string - currentVersion string - migrationsPath string - configPaths []string - editNewMigration bool - outputFile string // used for gengen or print-migrations + destinationVersion string + currentVersion string + migrationsPath string + configPaths []string + editNewMigration bool + outputFile string // used for gengen or print-migrations + cockroachDbCompatible bool connString string host string @@ -187,6 +188,7 @@ The word "last": Run: Migrate, } cmdMigrate.Flags().StringVarP(&cliOptions.destinationVersion, "destination", "d", "last", "destination migration version") + cmdMigrate.Flags().BoolVar(&cliOptions.cockroachDbCompatible, "cockroachdb", false, "CockroachDB compatibility flag avoiding advisory locks (default is false)") addConfigFlagsToCommand(cmdMigrate) cmdCode := &cobra.Command{ @@ -507,7 +509,9 @@ func Migrate(cmd *cobra.Command, args []string) { config, conn := loadConfigAndConnectToDB(ctx) defer conn.Close(ctx) - migrator, err := migrate.NewMigrator(ctx, conn, config.VersionTable) + migOpts := migrate.MigratorOptions{ CockroachDbCompatible: cliOptions.cockroachDbCompatible } + + migrator, err := migrate.NewMigratorEx(ctx, conn, config.VersionTable, &migOpts) if err != nil { fmt.Fprintf(os.Stderr, "Error initializing migrator:\n %v\n", err) os.Exit(1) @@ -612,7 +616,9 @@ func Gengen(cmd *cobra.Command, args []string) { os.Exit(1) } - migrator, err := migrate.NewMigrator(context.Background(), nil, config.VersionTable) + migOpts := migrate.MigratorOptions{ CockroachDbCompatible: cliOptions.cockroachDbCompatible } + + migrator, err := migrate.NewMigratorEx(context.Background(), nil, config.VersionTable, &migOpts) if err != nil { fmt.Fprintf(os.Stderr, "Error initializing migrator:\n %v\n", err) os.Exit(1) @@ -863,7 +869,9 @@ func Status(cmd *cobra.Command, args []string) { config, conn := loadConfigAndConnectToDB(ctx) defer conn.Close(ctx) - migrator, err := migrate.NewMigrator(ctx, conn, config.VersionTable) + migOpts := migrate.MigratorOptions{ CockroachDbCompatible: cliOptions.cockroachDbCompatible } + + migrator, err := migrate.NewMigratorEx(ctx, conn, config.VersionTable, &migOpts) if err != nil { fmt.Fprintf(os.Stderr, "Error initializing migrator:\n %v\n", err) os.Exit(1) @@ -1307,7 +1315,8 @@ func PrintMigrations(cmd *cobra.Command, args []string) { fmt.Fprintf(os.Stderr, "Error connecting to database:\n %v\n", err) os.Exit(1) } - migrator, err = migrate.NewMigrator(ctx, conn, config.VersionTable) + migOpts := migrate.MigratorOptions{ CockroachDbCompatible: cliOptions.cockroachDbCompatible } + migrator, err = migrate.NewMigratorEx(ctx, conn, config.VersionTable, &migOpts) if err != nil { fmt.Fprintf(os.Stderr, "Error initializing migrator:\n %v\n", err) os.Exit(1) @@ -1321,7 +1330,8 @@ func PrintMigrations(cmd *cobra.Command, args []string) { } currentVersion = int32(n) - migrator, err = migrate.NewMigrator(ctx, nil, config.VersionTable) + migOpts := migrate.MigratorOptions{ CockroachDbCompatible: cliOptions.cockroachDbCompatible } + migrator, err = migrate.NewMigratorEx(ctx, nil, config.VersionTable, &migOpts) if err != nil { fmt.Fprintf(os.Stderr, "Error initializing migrator:\n %v\n", err) os.Exit(1) diff --git a/migrate/migrate.go b/migrate/migrate.go index b72d7f2..0e4a127 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -8,6 +8,7 @@ import ( "io/fs" "path/filepath" "regexp" + "runtime" "strconv" "strings" "text/template" @@ -21,6 +22,7 @@ import ( var ( migrationPattern = regexp.MustCompile(`\A(\d+)_.+\.sql\z`) disableTxPattern = regexp.MustCompile(`(?m)^---- tern: disable-tx ----$`) + disableRowLocks = regexp.MustCompile(`(?m)^---- tern: disable-row-locks ----$`) ) const ( @@ -131,10 +133,18 @@ func (m *Migration) irreversible() bool { type MigratorOptions struct { // DisableTx causes the Migrator not to run migrations in a transaction. DisableTx bool + CockroachDbCompatible bool } type Migrator struct { conn *pgx.Conn + + // optionally, used for locking mechanisms + // instead of advisory locks on the primary conn + lockingConn *pgx.Conn + // optionally, Tx for the locking mechanism + lockingTx pgx.Tx + versionTable string options *MigratorOptions Migrations []*Migration @@ -142,23 +152,42 @@ type Migrator struct { Data map[string]interface{} // Data available to use in migrations } -// NewMigrator initializes a new Migrator. It is highly recommended that versionTable be schema qualified. func NewMigrator(ctx context.Context, conn *pgx.Conn, versionTable string) (m *Migrator, err error) { return NewMigratorEx(ctx, conn, versionTable, &MigratorOptions{}) } -// NewMigratorEx initializes a new Migrator. It is highly recommended that versionTable be schema qualified. +// NewMigrator initializes a new Migrator. It is highly recommended that versionTable be schema qualified. func NewMigratorEx(ctx context.Context, conn *pgx.Conn, versionTable string, opts *MigratorOptions) (m *Migrator, err error) { m = &Migrator{conn: conn, versionTable: versionTable, options: opts} + m.Migrations = make([]*Migration, 0) + m.Data = make(map[string]interface{}) + + if opts.CockroachDbCompatible { + m.lockingConn, err = pgx.ConnectConfig(ctx, conn.Config().Copy()) + if err != nil { + // try anyways and leave lockingconn nil? either way there's a failure + // For now, be explicit and block so it's clear + return + } + + // Migrator is the owner of this connection. Instead of requiring the user of Migrator + // to close (a usage change) let go manage the runtime + // Once go compat moves to 1.24 can replace with AddCleanup https://pkg.go.dev/runtime@master#AddCleanup + runtime.SetFinalizer(m.lockingConn, func(c *pgx.Conn) { + if err := c.Close(ctx); err != nil { + fmt.Println("trying to close lockingConn:", err.Error()) + } + }) + } + // This is a bit of a kludge for the gengen command. A migrator without a conn is normally not allowed. However, the // gengen command doesn't call any of the methods that require a conn. Potentially, we could refactor Migrator to // split out the migration loading and parsing from the actual migration execution. if conn != nil { err = m.ensureSchemaVersionTableExists(ctx) } - m.Migrations = make([]*Migration, 0) - m.Data = make(map[string]interface{}) + return } @@ -325,9 +354,6 @@ func (m *Migrator) AppendMigration(name, upSQL, downSQL string) { // Migrate runs pending migrations // It calls m.OnStart when it begins a migration func (m *Migrator) Migrate(ctx context.Context) error { - if err := m.validate(); err != nil { - return err - } return m.MigrateTo(ctx, m.highestSequenceNum()) } @@ -347,6 +373,56 @@ func (m *Migrator) validate() error { return nil } +func (m *Migrator) acquireLock(ctx context.Context) error { + if m.lockingConn != nil { + return m.acquireCustomLock(ctx) + } + + return acquireAdvisoryLock(ctx, m.conn) +} + +func (m *Migrator) releaseLock(ctx context.Context) error { + if m.lockingConn != nil { + return m.releaseCustomLock(ctx) + } + + return releaseAdvisoryLock(ctx, m.conn) +} + +var ErrLockNonRecursive = errors.New("lock is nonrecursive") + +// CockroachDB Compatible Locking Mechanism +func (m *Migrator) acquireCustomLock(ctx context.Context) (err error) { + query := fmt.Sprintf("select * from %s_lock for update nowait", m.versionTable) + + if m.lockingTx != nil { + return ErrLockNonRecursive + } + + m.lockingTx, err = m.lockingConn.Begin(ctx) + if err != nil { + return + } + + if _, err := m.lockingTx.Exec(ctx, query); err != nil { + return err + } + + return nil +} + +// CockroachDB Compatible Locking Mechanism +func (m *Migrator) releaseCustomLock(ctx context.Context) error { + err := m.lockingTx.Commit(ctx) + if err != nil { + return err + } + + m.lockingTx = nil + + return nil +} + // Lock to ensure multiple migrations cannot occur simultaneously const lockNum = int64(9628173550095224) // arbitrary random number @@ -366,12 +442,12 @@ func (m *Migrator) MigrateTo(ctx context.Context, targetVersion int32) (err erro return err } - err = acquireAdvisoryLock(ctx, m.conn) + err = m.acquireLock(ctx) if err != nil { return err } defer func() { - unlockErr := releaseAdvisoryLock(ctx, m.conn) + unlockErr := m.releaseLock(ctx) if err == nil && unlockErr != nil { err = unlockErr } @@ -456,7 +532,7 @@ func (m *Migrator) MigrateTo(ctx context.Context, targetVersion int32) (err erro m.conn.Exec(ctx, "reset all") // Add one to the version - _, err = m.conn.Exec(ctx, "update "+m.versionTable+" set version=$1", sequence) + _, err = m.conn.Exec(ctx, "update "+m.versionTable+" set version=$1 where version >= 0", sequence) if err != nil { return err } @@ -475,17 +551,30 @@ func (m *Migrator) MigrateTo(ctx context.Context, targetVersion int32) (err erro } func (m *Migrator) GetCurrentVersion(ctx context.Context) (v int32, err error) { - err = m.conn.QueryRow(ctx, "select version from "+m.versionTable).Scan(&v) - return v, err + query := "select version from "+m.versionTable+" where version >= 0" + + if m.lockingTx != nil { + err = m.lockingTx.QueryRow(ctx, query).Scan(&v) + } else { + err = m.conn.QueryRow(ctx, query).Scan(&v) + } + + return } func (m *Migrator) ensureSchemaVersionTableExists(ctx context.Context) (err error) { - err = acquireAdvisoryLock(ctx, m.conn) + if m.lockingConn != nil { + // solve the bootstrap problem needing the table + // to lock and needing a lock to create the table + return m.createIfNotExistsVersionTable(ctx) + } + + err = m.acquireLock(ctx) if err != nil { return err } defer func() { - unlockErr := releaseAdvisoryLock(ctx, m.conn) + unlockErr := m.releaseLock(ctx) if err == nil && unlockErr != nil { err = unlockErr } @@ -495,13 +584,34 @@ func (m *Migrator) ensureSchemaVersionTableExists(ctx context.Context) (err erro return err } - _, err = m.conn.Exec(ctx, fmt.Sprintf(` - create table if not exists %s(version int4 not null); + return m.createIfNotExistsVersionTable(ctx) +} + +// Not Thread Safe / Lock Safe +func (m *Migrator) createIfNotExistsVersionTable(ctx context.Context) error { + _, err := m.conn.Exec(ctx, fmt.Sprintf(` + create table if not exists %s(version int4 not null primary key); + + with initial(version) as (values (0)) + insert into %s(version) + select * from initial + where 0=(select count(*) from %s); + `, m.versionTable, m.versionTable, m.versionTable)) + if err != nil { + return err + } + + if m.options.CockroachDbCompatible { + _, err = m.conn.Exec(ctx, fmt.Sprintf(` + create table if not exists %s_lock(lock boolean not null primary key default true); + + with initial(lock) as (values (true)) + insert into %s_lock(lock) + select * from initial + where 0=(select count(*) from %s) + `, m.versionTable, m.versionTable, m.versionTable)) + } - insert into %s(version) - select 0 - where 0=(select count(*) from %s); - `, m.versionTable, m.versionTable, m.versionTable)) return err } @@ -545,13 +655,26 @@ func (m *Migrator) doSQLMigration(ctx context.Context, migration *Migration, dir } // Execute the migration for _, statement := range sqlStatements { - if _, err := m.conn.Exec(ctx, statement); err != nil { - if err, ok := err.(*pgconn.PgError); ok { - return MigrationPgError{MigrationName: migration.Name, Sql: statement, PgError: err} - } + if err := m.sqlExecMigration(ctx, migration, statement); err != nil { return err } } + return nil +} +func (m *Migrator) sqlExecMigration(ctx context.Context, migration *Migration, statement string) error { + if disableRowLocks.MatchString(statement) && m.lockingTx != nil { + m.releaseLock(ctx) + defer m.acquireLock(ctx) + } + + if _, err := m.conn.Exec(ctx, statement); err != nil { + if err, ok := err.(*pgconn.PgError); ok { + return MigrationPgError{MigrationName: migration.Name, Sql: statement, PgError: err} + } + return err + } + + return nil } diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go index cb16d34..bc42f6c 100644 --- a/migrate/migrate_test.go +++ b/migrate/migrate_test.go @@ -46,7 +46,7 @@ func prepareDatabase(t testing.TB) { func currentVersion(t testing.TB, conn *pgx.Conn) int32 { var n int32 - err := conn.QueryRow(context.Background(), "select version from "+versionTable).Scan(&n) + err := conn.QueryRow(context.Background(), "select version from "+versionTable+" where version >= 0").Scan(&n) assert.NoError(t, err) return n } @@ -70,7 +70,7 @@ func tableExists(t testing.TB, conn *pgx.Conn, tableName string) bool { func createEmptyMigrator(t testing.TB, conn *pgx.Conn) *migrate.Migrator { var err error - m, err := migrate.NewMigrator(context.Background(), conn, versionTable) + m, err := migrate.NewMigrator(context.Background(), conn, versionTable, &migrate.MigratorOptions{}) assert.NoError(t, err) return m } @@ -156,7 +156,7 @@ func TestNewMigrator(t *testing.T) { defer conn.Close(context.Background()) // Initial run - m, err := migrate.NewMigrator(context.Background(), conn, versionTable) + m, err := migrate.NewMigrator(context.Background(), conn, versionTable, &migrate.MigratorOptions{}) assert.NoError(t, err) // Creates version table @@ -164,7 +164,7 @@ func TestNewMigrator(t *testing.T) { require.True(t, schemaVersionExists) // Succeeds when version table is already created - m, err = migrate.NewMigrator(context.Background(), conn, versionTable) + m, err = migrate.NewMigrator(context.Background(), conn, versionTable, &migrate.MigratorOptions{}) assert.NoError(t, err) initialVersion, err := m.GetCurrentVersion(context.Background()) @@ -326,7 +326,7 @@ func TestLoadMigrationsNoForward(t *testing.T) { conn := connectConn(t) defer conn.Close(context.Background()) - m, err := migrate.NewMigrator(context.Background(), conn, versionTable) + m, err := migrate.NewMigrator(context.Background(), conn, versionTable, &migrate.MigratorOptions{}) assert.NoError(t, err) m.Data = map[string]interface{}{"prefix": "foo"} @@ -476,12 +476,7 @@ func TestMigrateToBoundaries(t *testing.T) { require.EqualError(t, err, "destination version 4 is outside the valid versions of 0 to 3") // When schema version says it is negative - mustExec(t, conn, "update "+versionTable+" set version=-1") - err = m.MigrateTo(context.Background(), int32(1)) - require.EqualError(t, err, "current version -1 is outside the valid versions of 0 to 3") - - // When schema version says it is negative - mustExec(t, conn, "update "+versionTable+" set version=4") + mustExec(t, conn, "update "+versionTable+" set version=4 where version >= 0") err = m.MigrateTo(context.Background(), int32(1)) require.EqualError(t, err, "current version 4 is outside the valid versions of 0 to 3") } @@ -506,7 +501,7 @@ func TestMigrateToDisableTxInTx(t *testing.T) { tx, err := conn.Begin(ctx) assert.NoError(t, err) - m, err := migrate.NewMigratorEx(ctx, conn, versionTable, &migrate.MigratorOptions{DisableTx: true}) + m, err := migrate.NewMigrator(ctx, conn, versionTable, &migrate.MigratorOptions{DisableTx: true}) assert.NoError(t, err) m.AppendMigration("Create t1", "create table t1(id serial);", "drop table t1;") m.AppendMigration("Create t2", "create table t2(id serial);", "drop table t2;") @@ -531,7 +526,7 @@ func TestMigrateToDisableTx(t *testing.T) { conn := connectConn(t) defer conn.Close(context.Background()) - m, err := migrate.NewMigratorEx(context.Background(), conn, versionTable, &migrate.MigratorOptions{DisableTx: true}) + m, err := migrate.NewMigrator(context.Background(), conn, versionTable, &migrate.MigratorOptions{DisableTx: true}) assert.NoError(t, err) m.AppendMigration("Create t1", "create table t1(id serial);", "drop table t1;") m.AppendMigration("Create t2", "create table t2(id serial);", "drop table t2;") @@ -559,7 +554,7 @@ func TestMigrateToDisableTxInMigration(t *testing.T) { conn := connectConn(t) defer conn.Close(context.Background()) - m, err := migrate.NewMigratorEx(context.Background(), conn, versionTable, &migrate.MigratorOptions{}) + m, err := migrate.NewMigrator(context.Background(), conn, versionTable, &migrate.MigratorOptions{}) assert.NoError(t, err) m.AppendMigration( "Create t1", @@ -580,7 +575,7 @@ func TestMigrationDisableFuncTx(t *testing.T) { t.Run("with DisableFuncTx false Migrator runs function in a transaction ", func(t *testing.T) { var inTxn bool - m, err := migrate.NewMigrator(context.Background(), conn, versionTable) + m, err := migrate.NewMigrator(context.Background(), conn, versionTable, &migrate.MigratorOptions{}) assert.NoError(t, err) m.Migrations = []*migrate.Migration{ { @@ -610,7 +605,7 @@ func TestMigrationDisableFuncTx(t *testing.T) { t.Run("with DisableFuncTx true Migrator runs function outside transaction ", func(t *testing.T) { var inTxn bool - m, err := migrate.NewMigrator(context.Background(), conn, versionTable) + m, err := migrate.NewMigrator(context.Background(), conn, versionTable, &migrate.MigratorOptions{}) assert.NoError(t, err) m.Migrations = []*migrate.Migration{ { @@ -676,7 +671,7 @@ func Example_onStartMigrationProgressLogging() { } var m *migrate.Migrator - m, err = migrate.NewMigrator(context.Background(), conn, "schema_version") + m, err = migrate.NewMigrator(context.Background(), conn, "schema_version", &migrate.MigratorOptions{}) if err != nil { fmt.Printf("Unable to create migrator: %v", err) return