diff options
Diffstat (limited to 'pkg/database')
-rw-r--r-- | pkg/database/database.go | 58 | ||||
-rw-r--r-- | pkg/database/database_test.go | 143 | ||||
-rw-r--r-- | pkg/database/error.go | 57 | ||||
-rw-r--r-- | pkg/database/error_test.go | 15 | ||||
-rw-r--r-- | pkg/database/migrations.go | 19 |
5 files changed, 292 insertions, 0 deletions
diff --git a/pkg/database/database.go b/pkg/database/database.go new file mode 100644 index 0000000..c89a334 --- /dev/null +++ b/pkg/database/database.go @@ -0,0 +1,58 @@ +package database + +import ( + "database/sql" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +// DBEnv is the struct that holds this package together +type DBEnv struct { + db *sql.DB +} + +// InitDB initializes database access and the connection pool +func InitDB(dbType string, dsn string) (*DBEnv, error) { + db, err := sql.Open(dbType, dsn) + if err != nil { + return nil, newInitError(dsn, err) + } + db.SetMaxOpenConns(25) + db.SetMaxIdleConns(25) + db.SetConnMaxLifetime(5 * time.Minute) + + if err := db.Ping(); err != nil { + _ = db.Close() // TODO think about handling this error case? + return nil, newInitError(dsn, err) + } + + return &DBEnv{db: db}, nil +} + +// Migrate performs the migrations of the database to the latest schema_version +func (env *DBEnv) Migrate() error { + var currentVersion int + if err := env.db.QueryRow(`SELECT version FROM schema_version`).Scan(¤tVersion); err != nil { + currentVersion = 0 + } + for version := currentVersion; version < len(migrations); version++ { + newVersion := version + 1 + tx, err := env.db.Begin() + if err != nil { + return newTransactionError("Could not begin transaction", err) + } + if err := migrations[version](tx); err != nil { + tx.Rollback() + return newMigrationError(newVersion, err) + } + if _, err := tx.Exec(`DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES ($1)`, newVersion); err != nil { + tx.Rollback() + return newMigrationError(newVersion, err) + } + if err := tx.Commit(); err != nil { + return newTransactionError("Could not commit transaction", err) + } + } + return nil +} diff --git a/pkg/database/database_test.go b/pkg/database/database_test.go new file mode 100644 index 0000000..aad7953 --- /dev/null +++ b/pkg/database/database_test.go @@ -0,0 +1,143 @@ +package database + +import ( + "database/sql" + "os" + "reflect" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInitDB(t *testing.T) { + // Test cases + testCases := []struct { + name string + dbType string + dsn string + expectedError interface{} + }{ + {"Invalid dbType", "non-existant", "test", &InitError{}}, + {"Non existant path", "sqlite3", "/non-existant/non-existant", &InitError{}}, + {"Working DB", "sqlite3", "file::memory:?_foreign_keys=on", nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + db, err := InitDB(tc.dbType, tc.dsn) + if tc.expectedError != nil { + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError)) + require.Nil(t, db) + } else { + require.NoError(t, err) + require.NotNil(t, db) + } + }) + } +} + +func TestMigrate(t *testing.T) { + badMigration := []func(tx *sql.Tx) error{ + func(tx *sql.Tx) (err error) { + return newMigrationError(0, nil) + }, + } + noSchemaVersionMigration := []func(tx *sql.Tx) error{ + func(tx *sql.Tx) (err error) { + return nil + }, + } + onlyFirstMigration := []func(tx *sql.Tx) error{migrations[0]} + notFromScratchMigration := []func(tx *sql.Tx) error{ + func(tx *sql.Tx) (err error) { + return &MigrationError{version: 1, err: nil} + }, + func(tx *sql.Tx) (err error) { + return nil + }, + } + os.Remove("testfile_notFromScratch.db") + notFromScratchDB, err := InitDB("sqlite3", "file:testfile_notFromScratch.db?_foreign_keys=on") + if err != nil { + t.Fatalf("Failed to init testfile.db : %+v", err) + } + defer os.Remove("testfile_notFromScratch.db") + migrations = onlyFirstMigration + if err := notFromScratchDB.Migrate(); err != nil { + t.Fatalf("Failed to migrate testfile.db to first schema version : %+v", err) + } + + // Test cases + testCases := []struct { + name string + dsn string + migrs []func(tx *sql.Tx) error + expectedError interface{} + }{ + {"bad migration", "file::memory:?_foreign_keys=on", badMigration, &MigrationError{}}, + {"no schema_version migration", "file::memory:?_foreign_keys=on", noSchemaVersionMigration, &MigrationError{}}, + {"not from scratch", "file:testfile_notFromScratch.db?_foreign_keys=on", notFromScratchMigration, nil}, + {"from scratch", "file::memory:?_foreign_keys=on", allMigrations, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + db, err := InitDB("sqlite3", tc.dsn) + require.NoError(t, err) + migrations = tc.migrs + err = db.Migrate() + if tc.expectedError != nil { + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError)) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestMigrateWithSQLMock(t *testing.T) { + fakeMigration := []func(tx *sql.Tx) error{ + func(tx *sql.Tx) (err error) { + return nil + }, + } + // Transaction begin error + dbBeginError, _, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer dbBeginError.Close() + // Transaction commit error + dbCommitError, mockCommitError, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer dbCommitError.Close() + mockCommitError.ExpectBegin() + mockCommitError.ExpectExec(`DELETE FROM schema_version`).WillReturnResult(sqlmock.NewResult(1, 1)) + // Test cases + testCases := []struct { + name string + db *DBEnv + migrs []func(tx *sql.Tx) error + expectedError interface{} + }{ + {"begin transaction error", &DBEnv{db: dbBeginError}, fakeMigration, &TransactionError{}}, + {"commit transaction error", &DBEnv{db: dbCommitError}, fakeMigration, &TransactionError{}}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + migrations = tc.migrs + err = tc.db.Migrate() + migrations = allMigrations + if tc.expectedError != nil { + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError)) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/database/error.go b/pkg/database/error.go new file mode 100644 index 0000000..1681dd5 --- /dev/null +++ b/pkg/database/error.go @@ -0,0 +1,57 @@ +package database + +import "fmt" + +// database init error +type InitError struct { + path string + err error +} + +func (e *InitError) Error() string { + return fmt.Sprintf("Failed to open database : %s", e.path) +} +func (e *InitError) Unwrap() error { return e.err } + +func newInitError(path string, err error) error { + return &InitError{ + path: path, + err: err, + } +} + +// database migration error +type MigrationError struct { + version int + err error +} + +func (e *MigrationError) Error() string { + return fmt.Sprintf("Failed to migrate database to version %d : %s", e.version, e.err) +} +func (e *MigrationError) Unwrap() error { return e.err } + +func newMigrationError(version int, err error) error { + return &MigrationError{ + version: version, + err: err, + } +} + +// database transaction error +type TransactionError struct { + msg string + err error +} + +func (e *TransactionError) Error() string { + return fmt.Sprintf("Failed to perform transaction : %s", e.msg) +} +func (e *TransactionError) Unwrap() error { return e.err } + +func newTransactionError(msg string, err error) error { + return &TransactionError{ + msg: msg, + err: err, + } +} diff --git a/pkg/database/error_test.go b/pkg/database/error_test.go new file mode 100644 index 0000000..8031305 --- /dev/null +++ b/pkg/database/error_test.go @@ -0,0 +1,15 @@ +package database + +import "testing" + +func TestErrorsCoverage(t *testing.T) { + initErr := InitError{} + _ = initErr.Error() + _ = initErr.Unwrap() + migrationErr := MigrationError{} + _ = migrationErr.Error() + _ = migrationErr.Unwrap() + transactionErr := TransactionError{} + _ = transactionErr.Error() + _ = transactionErr.Unwrap() +} diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go new file mode 100644 index 0000000..cec3c41 --- /dev/null +++ b/pkg/database/migrations.go @@ -0,0 +1,19 @@ +package database + +import "database/sql" + +// allMigrations is the list of migrations to perform to get an up to date database +// Order is important. Add new migrations at the end of the list. +var allMigrations = []func(tx *sql.Tx) error{ + func(tx *sql.Tx) (err error) { + sql := ` + CREATE TABLE schema_version ( + version INTEGER NOT NULL + );` + _, err = tx.Exec(sql) + return err + }, +} + +// This variable exists so that tests can override it +var migrations = allMigrations |