aboutsummaryrefslogtreecommitdiff
path: root/pkg/database/database_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/database/database_test.go')
-rw-r--r--pkg/database/database_test.go31
1 files changed, 11 insertions, 20 deletions
diff --git a/pkg/database/database_test.go b/pkg/database/database_test.go
index aad7953..69f0012 100644
--- a/pkg/database/database_test.go
+++ b/pkg/database/database_test.go
@@ -3,11 +3,9 @@ package database
import (
"database/sql"
"os"
- "reflect"
"testing"
"github.com/DATA-DOG/go-sqlmock"
- "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -17,7 +15,7 @@ func TestInitDB(t *testing.T) {
name string
dbType string
dsn string
- expectedError interface{}
+ expectedError error
}{
{"Invalid dbType", "non-existant", "test", &InitError{}},
{"Non existant path", "sqlite3", "/non-existant/non-existant", &InitError{}},
@@ -28,7 +26,7 @@ func TestInitDB(t *testing.T) {
db, err := InitDB(tc.dbType, tc.dsn)
if tc.expectedError != nil {
require.Error(t, err)
- assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError))
+ requireErrorTypeMatch(t, err, tc.expectedError)
require.Nil(t, db)
} else {
require.NoError(t, err)
@@ -60,21 +58,18 @@ func TestMigrate(t *testing.T) {
}
os.Remove("testfile_notFromScratch.db")
notFromScratchDB, err := InitDB("sqlite3", "file:testfile_notFromScratch.db?_foreign_keys=on")
- if err != nil {
- t.Fatalf("Failed to init testfile.db : %+v", err)
- }
+ require.NoError(t, err, "Failed to init testfile.db : %+v", err)
defer os.Remove("testfile_notFromScratch.db")
migrations = onlyFirstMigration
- if err := notFromScratchDB.Migrate(); err != nil {
- t.Fatalf("Failed to migrate testfile.db to first schema version : %+v", err)
- }
+ err = notFromScratchDB.Migrate()
+ require.NoError(t, err, "Failed to migrate testfile.db to first schema version : %+v", err)
// Test cases
testCases := []struct {
name string
dsn string
migrs []func(tx *sql.Tx) error
- expectedError interface{}
+ expectedError error
}{
{"bad migration", "file::memory:?_foreign_keys=on", badMigration, &MigrationError{}},
{"no schema_version migration", "file::memory:?_foreign_keys=on", noSchemaVersionMigration, &MigrationError{}},
@@ -89,7 +84,7 @@ func TestMigrate(t *testing.T) {
err = db.Migrate()
if tc.expectedError != nil {
require.Error(t, err)
- assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError))
+ requireErrorTypeMatch(t, err, tc.expectedError)
} else {
require.NoError(t, err)
}
@@ -105,15 +100,11 @@ func TestMigrateWithSQLMock(t *testing.T) {
}
// Transaction begin error
dbBeginError, _, err := sqlmock.New()
- if err != nil {
- t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
- }
+ require.NoError(t, err, "an error '%s' was not expected when opening a stub database connection", err)
defer dbBeginError.Close()
// Transaction commit error
dbCommitError, mockCommitError, err := sqlmock.New()
- if err != nil {
- t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
- }
+ require.NoError(t, err, "an error '%s' was not expected when opening a stub database connection", err)
defer dbCommitError.Close()
mockCommitError.ExpectBegin()
mockCommitError.ExpectExec(`DELETE FROM schema_version`).WillReturnResult(sqlmock.NewResult(1, 1))
@@ -122,7 +113,7 @@ func TestMigrateWithSQLMock(t *testing.T) {
name string
db *DBEnv
migrs []func(tx *sql.Tx) error
- expectedError interface{}
+ expectedError error
}{
{"begin transaction error", &DBEnv{db: dbBeginError}, fakeMigration, &TransactionError{}},
{"commit transaction error", &DBEnv{db: dbCommitError}, fakeMigration, &TransactionError{}},
@@ -134,7 +125,7 @@ func TestMigrateWithSQLMock(t *testing.T) {
migrations = allMigrations
if tc.expectedError != nil {
require.Error(t, err)
- assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError))
+ requireErrorTypeMatch(t, err, tc.expectedError)
} else {
require.NoError(t, err)
}