aboutsummaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
authorJulien Dessaux2021-04-09 23:56:16 +0200
committerJulien Dessaux2021-04-09 23:56:16 +0200
commit824226ef686d604bc3b1e8675143dabfe8df6555 (patch)
tree913e16e3ac767f89bcb668500d7c3ef53fe58423 /pkg
parentMade config tests more specific (diff)
downloadtrains-824226ef686d604bc3b1e8675143dabfe8df6555.tar.gz
trains-824226ef686d604bc3b1e8675143dabfe8df6555.tar.bz2
trains-824226ef686d604bc3b1e8675143dabfe8df6555.zip
Began implementing a database backend
Diffstat (limited to 'pkg')
-rw-r--r--pkg/database/database.go58
-rw-r--r--pkg/database/database_test.go143
-rw-r--r--pkg/database/error.go57
-rw-r--r--pkg/database/error_test.go15
-rw-r--r--pkg/database/migrations.go19
5 files changed, 292 insertions, 0 deletions
diff --git a/pkg/database/database.go b/pkg/database/database.go
new file mode 100644
index 0000000..c89a334
--- /dev/null
+++ b/pkg/database/database.go
@@ -0,0 +1,58 @@
+package database
+
+import (
+ "database/sql"
+ "time"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+// DBEnv is the struct that holds this package together
+type DBEnv struct {
+ db *sql.DB
+}
+
+// InitDB initializes database access and the connection pool
+func InitDB(dbType string, dsn string) (*DBEnv, error) {
+ db, err := sql.Open(dbType, dsn)
+ if err != nil {
+ return nil, newInitError(dsn, err)
+ }
+ db.SetMaxOpenConns(25)
+ db.SetMaxIdleConns(25)
+ db.SetConnMaxLifetime(5 * time.Minute)
+
+ if err := db.Ping(); err != nil {
+ _ = db.Close() // TODO think about handling this error case?
+ return nil, newInitError(dsn, err)
+ }
+
+ return &DBEnv{db: db}, nil
+}
+
+// Migrate performs the migrations of the database to the latest schema_version
+func (env *DBEnv) Migrate() error {
+ var currentVersion int
+ if err := env.db.QueryRow(`SELECT version FROM schema_version`).Scan(&currentVersion); err != nil {
+ currentVersion = 0
+ }
+ for version := currentVersion; version < len(migrations); version++ {
+ newVersion := version + 1
+ tx, err := env.db.Begin()
+ if err != nil {
+ return newTransactionError("Could not begin transaction", err)
+ }
+ if err := migrations[version](tx); err != nil {
+ tx.Rollback()
+ return newMigrationError(newVersion, err)
+ }
+ if _, err := tx.Exec(`DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES ($1)`, newVersion); err != nil {
+ tx.Rollback()
+ return newMigrationError(newVersion, err)
+ }
+ if err := tx.Commit(); err != nil {
+ return newTransactionError("Could not commit transaction", err)
+ }
+ }
+ return nil
+}
diff --git a/pkg/database/database_test.go b/pkg/database/database_test.go
new file mode 100644
index 0000000..aad7953
--- /dev/null
+++ b/pkg/database/database_test.go
@@ -0,0 +1,143 @@
+package database
+
+import (
+ "database/sql"
+ "os"
+ "reflect"
+ "testing"
+
+ "github.com/DATA-DOG/go-sqlmock"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestInitDB(t *testing.T) {
+ // Test cases
+ testCases := []struct {
+ name string
+ dbType string
+ dsn string
+ expectedError interface{}
+ }{
+ {"Invalid dbType", "non-existant", "test", &InitError{}},
+ {"Non existant path", "sqlite3", "/non-existant/non-existant", &InitError{}},
+ {"Working DB", "sqlite3", "file::memory:?_foreign_keys=on", nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ db, err := InitDB(tc.dbType, tc.dsn)
+ if tc.expectedError != nil {
+ require.Error(t, err)
+ assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError))
+ require.Nil(t, db)
+ } else {
+ require.NoError(t, err)
+ require.NotNil(t, db)
+ }
+ })
+ }
+}
+
+func TestMigrate(t *testing.T) {
+ badMigration := []func(tx *sql.Tx) error{
+ func(tx *sql.Tx) (err error) {
+ return newMigrationError(0, nil)
+ },
+ }
+ noSchemaVersionMigration := []func(tx *sql.Tx) error{
+ func(tx *sql.Tx) (err error) {
+ return nil
+ },
+ }
+ onlyFirstMigration := []func(tx *sql.Tx) error{migrations[0]}
+ notFromScratchMigration := []func(tx *sql.Tx) error{
+ func(tx *sql.Tx) (err error) {
+ return &MigrationError{version: 1, err: nil}
+ },
+ func(tx *sql.Tx) (err error) {
+ return nil
+ },
+ }
+ os.Remove("testfile_notFromScratch.db")
+ notFromScratchDB, err := InitDB("sqlite3", "file:testfile_notFromScratch.db?_foreign_keys=on")
+ if err != nil {
+ t.Fatalf("Failed to init testfile.db : %+v", err)
+ }
+ defer os.Remove("testfile_notFromScratch.db")
+ migrations = onlyFirstMigration
+ if err := notFromScratchDB.Migrate(); err != nil {
+ t.Fatalf("Failed to migrate testfile.db to first schema version : %+v", err)
+ }
+
+ // Test cases
+ testCases := []struct {
+ name string
+ dsn string
+ migrs []func(tx *sql.Tx) error
+ expectedError interface{}
+ }{
+ {"bad migration", "file::memory:?_foreign_keys=on", badMigration, &MigrationError{}},
+ {"no schema_version migration", "file::memory:?_foreign_keys=on", noSchemaVersionMigration, &MigrationError{}},
+ {"not from scratch", "file:testfile_notFromScratch.db?_foreign_keys=on", notFromScratchMigration, nil},
+ {"from scratch", "file::memory:?_foreign_keys=on", allMigrations, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ db, err := InitDB("sqlite3", tc.dsn)
+ require.NoError(t, err)
+ migrations = tc.migrs
+ err = db.Migrate()
+ if tc.expectedError != nil {
+ require.Error(t, err)
+ assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError))
+ } else {
+ require.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestMigrateWithSQLMock(t *testing.T) {
+ fakeMigration := []func(tx *sql.Tx) error{
+ func(tx *sql.Tx) (err error) {
+ return nil
+ },
+ }
+ // Transaction begin error
+ dbBeginError, _, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer dbBeginError.Close()
+ // Transaction commit error
+ dbCommitError, mockCommitError, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer dbCommitError.Close()
+ mockCommitError.ExpectBegin()
+ mockCommitError.ExpectExec(`DELETE FROM schema_version`).WillReturnResult(sqlmock.NewResult(1, 1))
+ // Test cases
+ testCases := []struct {
+ name string
+ db *DBEnv
+ migrs []func(tx *sql.Tx) error
+ expectedError interface{}
+ }{
+ {"begin transaction error", &DBEnv{db: dbBeginError}, fakeMigration, &TransactionError{}},
+ {"commit transaction error", &DBEnv{db: dbCommitError}, fakeMigration, &TransactionError{}},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ migrations = tc.migrs
+ err = tc.db.Migrate()
+ migrations = allMigrations
+ if tc.expectedError != nil {
+ require.Error(t, err)
+ assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError))
+ } else {
+ require.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/pkg/database/error.go b/pkg/database/error.go
new file mode 100644
index 0000000..1681dd5
--- /dev/null
+++ b/pkg/database/error.go
@@ -0,0 +1,57 @@
+package database
+
+import "fmt"
+
+// database init error
+type InitError struct {
+ path string
+ err error
+}
+
+func (e *InitError) Error() string {
+ return fmt.Sprintf("Failed to open database : %s", e.path)
+}
+func (e *InitError) Unwrap() error { return e.err }
+
+func newInitError(path string, err error) error {
+ return &InitError{
+ path: path,
+ err: err,
+ }
+}
+
+// database migration error
+type MigrationError struct {
+ version int
+ err error
+}
+
+func (e *MigrationError) Error() string {
+ return fmt.Sprintf("Failed to migrate database to version %d : %s", e.version, e.err)
+}
+func (e *MigrationError) Unwrap() error { return e.err }
+
+func newMigrationError(version int, err error) error {
+ return &MigrationError{
+ version: version,
+ err: err,
+ }
+}
+
+// database transaction error
+type TransactionError struct {
+ msg string
+ err error
+}
+
+func (e *TransactionError) Error() string {
+ return fmt.Sprintf("Failed to perform transaction : %s", e.msg)
+}
+func (e *TransactionError) Unwrap() error { return e.err }
+
+func newTransactionError(msg string, err error) error {
+ return &TransactionError{
+ msg: msg,
+ err: err,
+ }
+}
diff --git a/pkg/database/error_test.go b/pkg/database/error_test.go
new file mode 100644
index 0000000..8031305
--- /dev/null
+++ b/pkg/database/error_test.go
@@ -0,0 +1,15 @@
+package database
+
+import "testing"
+
+func TestErrorsCoverage(t *testing.T) {
+ initErr := InitError{}
+ _ = initErr.Error()
+ _ = initErr.Unwrap()
+ migrationErr := MigrationError{}
+ _ = migrationErr.Error()
+ _ = migrationErr.Unwrap()
+ transactionErr := TransactionError{}
+ _ = transactionErr.Error()
+ _ = transactionErr.Unwrap()
+}
diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go
new file mode 100644
index 0000000..cec3c41
--- /dev/null
+++ b/pkg/database/migrations.go
@@ -0,0 +1,19 @@
+package database
+
+import "database/sql"
+
+// allMigrations is the list of migrations to perform to get an up to date database
+// Order is important. Add new migrations at the end of the list.
+var allMigrations = []func(tx *sql.Tx) error{
+ func(tx *sql.Tx) (err error) {
+ sql := `
+ CREATE TABLE schema_version (
+ version INTEGER NOT NULL
+ );`
+ _, err = tx.Exec(sql)
+ return err
+ },
+}
+
+// This variable exists so that tests can override it
+var migrations = allMigrations