aboutsummaryrefslogtreecommitdiff
path: root/pkg/database
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/database')
-rw-r--r--pkg/database/migrations.go2
-rw-r--r--pkg/database/train_stop.go11
-rw-r--r--pkg/database/train_stop_test.go21
3 files changed, 32 insertions, 2 deletions
diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go
index e977960..7c74bcf 100644
--- a/pkg/database/migrations.go
+++ b/pkg/database/migrations.go
@@ -26,7 +26,7 @@ var allMigrations = []func(tx *sql.Tx) error{
);
CREATE TABLE train_stops (
id TEXT PRIMARY KEY,
- name TEXT NOT NULL UNIQUE
+ name TEXT NOT NULL
);`
_, err = tx.Exec(sql)
return err
diff --git a/pkg/database/train_stop.go b/pkg/database/train_stop.go
index 452f6c3..ee7be00 100644
--- a/pkg/database/train_stop.go
+++ b/pkg/database/train_stop.go
@@ -4,6 +4,15 @@ import (
"git.adyxax.org/adyxax/trains/pkg/model"
)
+func (env *DBEnv) CountTrainStops() (i int, err error) {
+ query := `SELECT count(*) from train_stops;`
+ err = env.db.QueryRow(query).Scan(&i)
+ if err != nil {
+ return 0, newQueryError("Could not run database query: most likely the schema is corrupted", err)
+ }
+ return
+}
+
func (env *DBEnv) ReplaceAndImportTrainStops(trainStops []model.TrainStop) error {
pre_query := `DELETE FROM train_stops;`
query := `
@@ -28,7 +37,7 @@ func (env *DBEnv) ReplaceAndImportTrainStops(trainStops []model.TrainStop) error
)
if err != nil {
tx.Rollback()
- return newQueryError("Could not run database query: ", err)
+ return newQueryError("Could not run database query", err)
}
}
if err := tx.Commit(); err != nil {
diff --git a/pkg/database/train_stop_test.go b/pkg/database/train_stop_test.go
index f593d51..b3f9459 100644
--- a/pkg/database/train_stop_test.go
+++ b/pkg/database/train_stop_test.go
@@ -10,6 +10,27 @@ import (
"github.com/stretchr/testify/require"
)
+func TestCountTrainStops(t *testing.T) {
+ trainStops := []model.TrainStop{
+ model.TrainStop{Id: "id1", Name: "name1"},
+ model.TrainStop{Id: "id2", Name: "name2"},
+ }
+ // test db setup
+ db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on")
+ require.NoError(t, err)
+ // check sql error
+ i, err := db.CountTrainStops()
+ require.Error(t, err)
+ assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(&QueryError{}), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(&QueryError{}))
+ // normal check
+ err = db.Migrate()
+ require.NoError(t, err)
+ err = db.ReplaceAndImportTrainStops(trainStops)
+ i, err = db.CountTrainStops()
+ require.NoError(t, err)
+ assert.Equal(t, i, len(trainStops))
+}
+
func TestReplaceAndImportTrainStops(t *testing.T) {
// test db setup
db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on")