diff options
Diffstat (limited to '')
-rw-r--r-- | pkg/database/migrations.go | 2 | ||||
-rw-r--r-- | pkg/database/train_stop.go | 11 | ||||
-rw-r--r-- | pkg/database/train_stop_test.go | 21 |
3 files changed, 32 insertions, 2 deletions
diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go index e977960..7c74bcf 100644 --- a/pkg/database/migrations.go +++ b/pkg/database/migrations.go @@ -26,7 +26,7 @@ var allMigrations = []func(tx *sql.Tx) error{ ); CREATE TABLE train_stops ( id TEXT PRIMARY KEY, - name TEXT NOT NULL UNIQUE + name TEXT NOT NULL );` _, err = tx.Exec(sql) return err diff --git a/pkg/database/train_stop.go b/pkg/database/train_stop.go index 452f6c3..ee7be00 100644 --- a/pkg/database/train_stop.go +++ b/pkg/database/train_stop.go @@ -4,6 +4,15 @@ import ( "git.adyxax.org/adyxax/trains/pkg/model" ) +func (env *DBEnv) CountTrainStops() (i int, err error) { + query := `SELECT count(*) from train_stops;` + err = env.db.QueryRow(query).Scan(&i) + if err != nil { + return 0, newQueryError("Could not run database query: most likely the schema is corrupted", err) + } + return +} + func (env *DBEnv) ReplaceAndImportTrainStops(trainStops []model.TrainStop) error { pre_query := `DELETE FROM train_stops;` query := ` @@ -28,7 +37,7 @@ func (env *DBEnv) ReplaceAndImportTrainStops(trainStops []model.TrainStop) error ) if err != nil { tx.Rollback() - return newQueryError("Could not run database query: ", err) + return newQueryError("Could not run database query", err) } } if err := tx.Commit(); err != nil { diff --git a/pkg/database/train_stop_test.go b/pkg/database/train_stop_test.go index f593d51..b3f9459 100644 --- a/pkg/database/train_stop_test.go +++ b/pkg/database/train_stop_test.go @@ -10,6 +10,27 @@ import ( "github.com/stretchr/testify/require" ) +func TestCountTrainStops(t *testing.T) { + trainStops := []model.TrainStop{ + model.TrainStop{Id: "id1", Name: "name1"}, + model.TrainStop{Id: "id2", Name: "name2"}, + } + // test db setup + db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on") + require.NoError(t, err) + // check sql error + i, err := db.CountTrainStops() + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(&QueryError{}), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(&QueryError{})) + // normal check + err = db.Migrate() + require.NoError(t, err) + err = db.ReplaceAndImportTrainStops(trainStops) + i, err = db.CountTrainStops() + require.NoError(t, err) + assert.Equal(t, i, len(trainStops)) +} + func TestReplaceAndImportTrainStops(t *testing.T) { // test db setup db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on") |