aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--go.mod3
-rw-r--r--go.sum7
-rw-r--r--pkg/database/database.go58
-rw-r--r--pkg/database/database_test.go143
-rw-r--r--pkg/database/error.go57
-rw-r--r--pkg/database/error_test.go15
-rw-r--r--pkg/database/migrations.go19
7 files changed, 301 insertions, 1 deletions
diff --git a/go.mod b/go.mod
index 255bfa4..254fdec 100644
--- a/go.mod
+++ b/go.mod
@@ -3,7 +3,10 @@ module git.adyxax.org/adyxax/trains
go 1.16
require (
+ github.com/DATA-DOG/go-sqlmock v1.5.0
+ github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/pretty v0.2.1 // indirect
+ github.com/mattn/go-sqlite3 v1.14.6
github.com/stretchr/testify v1.7.0
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
diff --git a/go.sum b/go.sum
index cd0dd5e..a268c59 100644
--- a/go.sum
+++ b/go.sum
@@ -1,10 +1,15 @@
-github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
+github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
+github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
+github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
diff --git a/pkg/database/database.go b/pkg/database/database.go
new file mode 100644
index 0000000..c89a334
--- /dev/null
+++ b/pkg/database/database.go
@@ -0,0 +1,58 @@
+package database
+
+import (
+ "database/sql"
+ "time"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+// DBEnv is the struct that holds this package together
+type DBEnv struct {
+ db *sql.DB
+}
+
+// InitDB initializes database access and the connection pool
+func InitDB(dbType string, dsn string) (*DBEnv, error) {
+ db, err := sql.Open(dbType, dsn)
+ if err != nil {
+ return nil, newInitError(dsn, err)
+ }
+ db.SetMaxOpenConns(25)
+ db.SetMaxIdleConns(25)
+ db.SetConnMaxLifetime(5 * time.Minute)
+
+ if err := db.Ping(); err != nil {
+ _ = db.Close() // TODO think about handling this error case?
+ return nil, newInitError(dsn, err)
+ }
+
+ return &DBEnv{db: db}, nil
+}
+
+// Migrate performs the migrations of the database to the latest schema_version
+func (env *DBEnv) Migrate() error {
+ var currentVersion int
+ if err := env.db.QueryRow(`SELECT version FROM schema_version`).Scan(&currentVersion); err != nil {
+ currentVersion = 0
+ }
+ for version := currentVersion; version < len(migrations); version++ {
+ newVersion := version + 1
+ tx, err := env.db.Begin()
+ if err != nil {
+ return newTransactionError("Could not begin transaction", err)
+ }
+ if err := migrations[version](tx); err != nil {
+ tx.Rollback()
+ return newMigrationError(newVersion, err)
+ }
+ if _, err := tx.Exec(`DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES ($1)`, newVersion); err != nil {
+ tx.Rollback()
+ return newMigrationError(newVersion, err)
+ }
+ if err := tx.Commit(); err != nil {
+ return newTransactionError("Could not commit transaction", err)
+ }
+ }
+ return nil
+}
diff --git a/pkg/database/database_test.go b/pkg/database/database_test.go
new file mode 100644
index 0000000..aad7953
--- /dev/null
+++ b/pkg/database/database_test.go
@@ -0,0 +1,143 @@
+package database
+
+import (
+ "database/sql"
+ "os"
+ "reflect"
+ "testing"
+
+ "github.com/DATA-DOG/go-sqlmock"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestInitDB(t *testing.T) {
+ // Test cases
+ testCases := []struct {
+ name string
+ dbType string
+ dsn string
+ expectedError interface{}
+ }{
+ {"Invalid dbType", "non-existant", "test", &InitError{}},
+ {"Non existant path", "sqlite3", "/non-existant/non-existant", &InitError{}},
+ {"Working DB", "sqlite3", "file::memory:?_foreign_keys=on", nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ db, err := InitDB(tc.dbType, tc.dsn)
+ if tc.expectedError != nil {
+ require.Error(t, err)
+ assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError))
+ require.Nil(t, db)
+ } else {
+ require.NoError(t, err)
+ require.NotNil(t, db)
+ }
+ })
+ }
+}
+
+func TestMigrate(t *testing.T) {
+ badMigration := []func(tx *sql.Tx) error{
+ func(tx *sql.Tx) (err error) {
+ return newMigrationError(0, nil)
+ },
+ }
+ noSchemaVersionMigration := []func(tx *sql.Tx) error{
+ func(tx *sql.Tx) (err error) {
+ return nil
+ },
+ }
+ onlyFirstMigration := []func(tx *sql.Tx) error{migrations[0]}
+ notFromScratchMigration := []func(tx *sql.Tx) error{
+ func(tx *sql.Tx) (err error) {
+ return &MigrationError{version: 1, err: nil}
+ },
+ func(tx *sql.Tx) (err error) {
+ return nil
+ },
+ }
+ os.Remove("testfile_notFromScratch.db")
+ notFromScratchDB, err := InitDB("sqlite3", "file:testfile_notFromScratch.db?_foreign_keys=on")
+ if err != nil {
+ t.Fatalf("Failed to init testfile.db : %+v", err)
+ }
+ defer os.Remove("testfile_notFromScratch.db")
+ migrations = onlyFirstMigration
+ if err := notFromScratchDB.Migrate(); err != nil {
+ t.Fatalf("Failed to migrate testfile.db to first schema version : %+v", err)
+ }
+
+ // Test cases
+ testCases := []struct {
+ name string
+ dsn string
+ migrs []func(tx *sql.Tx) error
+ expectedError interface{}
+ }{
+ {"bad migration", "file::memory:?_foreign_keys=on", badMigration, &MigrationError{}},
+ {"no schema_version migration", "file::memory:?_foreign_keys=on", noSchemaVersionMigration, &MigrationError{}},
+ {"not from scratch", "file:testfile_notFromScratch.db?_foreign_keys=on", notFromScratchMigration, nil},
+ {"from scratch", "file::memory:?_foreign_keys=on", allMigrations, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ db, err := InitDB("sqlite3", tc.dsn)
+ require.NoError(t, err)
+ migrations = tc.migrs
+ err = db.Migrate()
+ if tc.expectedError != nil {
+ require.Error(t, err)
+ assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError))
+ } else {
+ require.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestMigrateWithSQLMock(t *testing.T) {
+ fakeMigration := []func(tx *sql.Tx) error{
+ func(tx *sql.Tx) (err error) {
+ return nil
+ },
+ }
+ // Transaction begin error
+ dbBeginError, _, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer dbBeginError.Close()
+ // Transaction commit error
+ dbCommitError, mockCommitError, err := sqlmock.New()
+ if err != nil {
+ t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
+ }
+ defer dbCommitError.Close()
+ mockCommitError.ExpectBegin()
+ mockCommitError.ExpectExec(`DELETE FROM schema_version`).WillReturnResult(sqlmock.NewResult(1, 1))
+ // Test cases
+ testCases := []struct {
+ name string
+ db *DBEnv
+ migrs []func(tx *sql.Tx) error
+ expectedError interface{}
+ }{
+ {"begin transaction error", &DBEnv{db: dbBeginError}, fakeMigration, &TransactionError{}},
+ {"commit transaction error", &DBEnv{db: dbCommitError}, fakeMigration, &TransactionError{}},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ migrations = tc.migrs
+ err = tc.db.Migrate()
+ migrations = allMigrations
+ if tc.expectedError != nil {
+ require.Error(t, err)
+ assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError))
+ } else {
+ require.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/pkg/database/error.go b/pkg/database/error.go
new file mode 100644
index 0000000..1681dd5
--- /dev/null
+++ b/pkg/database/error.go
@@ -0,0 +1,57 @@
+package database
+
+import "fmt"
+
+// database init error
+type InitError struct {
+ path string
+ err error
+}
+
+func (e *InitError) Error() string {
+ return fmt.Sprintf("Failed to open database : %s", e.path)
+}
+func (e *InitError) Unwrap() error { return e.err }
+
+func newInitError(path string, err error) error {
+ return &InitError{
+ path: path,
+ err: err,
+ }
+}
+
+// database migration error
+type MigrationError struct {
+ version int
+ err error
+}
+
+func (e *MigrationError) Error() string {
+ return fmt.Sprintf("Failed to migrate database to version %d : %s", e.version, e.err)
+}
+func (e *MigrationError) Unwrap() error { return e.err }
+
+func newMigrationError(version int, err error) error {
+ return &MigrationError{
+ version: version,
+ err: err,
+ }
+}
+
+// database transaction error
+type TransactionError struct {
+ msg string
+ err error
+}
+
+func (e *TransactionError) Error() string {
+ return fmt.Sprintf("Failed to perform transaction : %s", e.msg)
+}
+func (e *TransactionError) Unwrap() error { return e.err }
+
+func newTransactionError(msg string, err error) error {
+ return &TransactionError{
+ msg: msg,
+ err: err,
+ }
+}
diff --git a/pkg/database/error_test.go b/pkg/database/error_test.go
new file mode 100644
index 0000000..8031305
--- /dev/null
+++ b/pkg/database/error_test.go
@@ -0,0 +1,15 @@
+package database
+
+import "testing"
+
+func TestErrorsCoverage(t *testing.T) {
+ initErr := InitError{}
+ _ = initErr.Error()
+ _ = initErr.Unwrap()
+ migrationErr := MigrationError{}
+ _ = migrationErr.Error()
+ _ = migrationErr.Unwrap()
+ transactionErr := TransactionError{}
+ _ = transactionErr.Error()
+ _ = transactionErr.Unwrap()
+}
diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go
new file mode 100644
index 0000000..cec3c41
--- /dev/null
+++ b/pkg/database/migrations.go
@@ -0,0 +1,19 @@
+package database
+
+import "database/sql"
+
+// allMigrations is the list of migrations to perform to get an up to date database
+// Order is important. Add new migrations at the end of the list.
+var allMigrations = []func(tx *sql.Tx) error{
+ func(tx *sql.Tx) (err error) {
+ sql := `
+ CREATE TABLE schema_version (
+ version INTEGER NOT NULL
+ );`
+ _, err = tx.Exec(sql)
+ return err
+ },
+}
+
+// This variable exists so that tests can override it
+var migrations = allMigrations