package database

import (
	"database/sql"
	"os"
	"reflect"
	"testing"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// assertErrorType fails the test unless err is non-nil and its dynamic type
// is exactly the dynamic type of expected.
//
// A nil err aborts the calling test immediately (require); a type mismatch is
// recorded as a failure but lets the test continue (assert).
func assertErrorType(t *testing.T, expected interface{}, err error) {
	t.Helper()

	require.Error(t, err)
	// testify takes the expected value first and the actual value second.
	assert.IsTypef(t, expected, err,
		"Invalid error type. Got %s but expected %s",
		reflect.TypeOf(err), reflect.TypeOf(expected))
}

// TestInitDB checks that InitDB returns an *InitError and a nil handle for an
// unknown driver or an unusable path, and a non-nil handle for an in-memory
// SQLite database.
func TestInitDB(t *testing.T) {
	testCases := []struct {
		name          string
		dbType        string
		dsn           string
		expectedError interface{} // nil means InitDB must succeed
	}{
		{"Invalid dbType", "non-existant", "test", &InitError{}},
		{"Non existant path", "sqlite3", "/non-existant/non-existant", &InitError{}},
		{"Working DB", "sqlite3", "file::memory:?_foreign_keys=on", nil},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			db, err := InitDB(tc.dbType, tc.dsn)

			if tc.expectedError == nil {
				require.NoError(t, err)
				require.NotNil(t, db)
				return
			}

			assertErrorType(t, tc.expectedError, err)
			require.Nil(t, db)
		})
	}
}

// TestMigrate runs Migrate against real SQLite databases with several
// replacement migration lists and checks the type of the returned error.
//
// The test swaps the package-level migrations variable; it is reset to
// allMigrations when the test returns, including on an early t.Fatalf.
func TestMigrate(t *testing.T) {
	const (
		memoryDSN          = "file::memory:?_foreign_keys=on"
		notFromScratchFile = "testfile_notFromScratch.db"
		notFromScratchDSN  = "file:" + notFromScratchFile + "?_foreign_keys=on"
	)

	// Make sure later tests never observe one of the stub migration lists.
	defer func() { migrations = allMigrations }()

	// A single migration that always fails.
	badMigration := []func(tx *sql.Tx) error{
		func(*sql.Tx) error { return newMigrationError(0, nil) },
	}
	// A single migration that succeeds without doing anything; the matching
	// test case expects Migrate to report a *MigrationError for it.
	noSchemaVersionMigration := []func(tx *sql.Tx) error{
		func(*sql.Tx) error { return nil },
	}
	onlyFirstMigration := []func(tx *sql.Tx) error{migrations[0]}
	// The first entry returns an error whenever it is invoked. The matching
	// test case runs on a database already migrated to the first schema
	// version and expects success.
	notFromScratchMigration := []func(tx *sql.Tx) error{
		func(*sql.Tx) error { return &MigrationError{version: 1, err: nil} },
		func(*sql.Tx) error { return nil },
	}

	// Best-effort removal of a file left behind by a previous run; the error
	// is ignored on purpose because the file normally does not exist.
	os.Remove(notFromScratchFile)
	notFromScratchDB, err := InitDB("sqlite3", notFromScratchDSN)
	if err != nil {
		t.Fatalf("Failed to init testfile.db : %+v", err)
	}
	// NOTE(review): the handles returned by InitDB are never closed in this
	// test; confirm whether DBEnv exposes a way to close them.
	defer os.Remove(notFromScratchFile)

	// Bring the on-disk database to the first schema version so the
	// "not from scratch" case starts from an already-migrated state.
	migrations = onlyFirstMigration
	if err := notFromScratchDB.Migrate(); err != nil {
		t.Fatalf("Failed to migrate testfile.db to first schema version : %+v", err)
	}

	testCases := []struct {
		name          string
		dsn           string
		migrs         []func(tx *sql.Tx) error
		expectedError interface{} // nil means Migrate must succeed
	}{
		{"bad migration", memoryDSN, badMigration, &MigrationError{}},
		{"no schema_version migration", memoryDSN, noSchemaVersionMigration, &MigrationError{}},
		{"not from scratch", notFromScratchDSN, notFromScratchMigration, nil},
		{"from scratch", memoryDSN, allMigrations, nil},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			db, err := InitDB("sqlite3", tc.dsn)
			require.NoError(t, err)

			migrations = tc.migrs
			err = db.Migrate()

			if tc.expectedError == nil {
				require.NoError(t, err)
				return
			}
			assertErrorType(t, tc.expectedError, err)
		})
	}
}

// TestMigrateWithSQLMock checks that Migrate reports a *TransactionError when
// the transaction cannot be started or cannot be committed, using sqlmock
// stub connections instead of a real database.
func TestMigrateWithSQLMock(t *testing.T) {
	fakeMigration := []func(tx *sql.Tx) error{
		func(*sql.Tx) error { return nil },
	}

	// Transaction begin error: no expectation is registered on this mock, so
	// the Begin call issued by Migrate is not expected.
	dbBeginError, _, err := sqlmock.New()
	if err != nil {
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	defer dbBeginError.Close()

	// Transaction commit error: Begin and the DELETE are expected, but no
	// Commit expectation is registered.
	dbCommitError, mockCommitError, err := sqlmock.New()
	if err != nil {
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	defer dbCommitError.Close()
	mockCommitError.ExpectBegin()
	mockCommitError.ExpectExec(`DELETE FROM schema_version`).WillReturnResult(sqlmock.NewResult(1, 1))

	testCases := []struct {
		name          string
		db            *DBEnv
		migrs         []func(tx *sql.Tx) error
		expectedError interface{} // nil means Migrate must succeed
	}{
		{"begin transaction error", &DBEnv{db: dbBeginError}, fakeMigration, &TransactionError{}},
		{"commit transaction error", &DBEnv{db: dbCommitError}, fakeMigration, &TransactionError{}},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			migrations = tc.migrs
			// Put the real migration list back even if Migrate panics.
			defer func() { migrations = allMigrations }()

			// Local err: do not write to the variable of the enclosing test.
			err := tc.db.Migrate()

			if tc.expectedError == nil {
				require.NoError(t, err)
				return
			}
			assertErrorType(t, tc.expectedError, err)
		})
	}
}