diff options
Diffstat (limited to '')
-rw-r--r-- | pkg/database/stop.go | 17 | ||||
-rw-r--r-- | pkg/database/stop_test.go | 50 |
2 files changed, 67 insertions, 0 deletions
diff --git a/pkg/database/stop.go b/pkg/database/stop.go index 47f9ce1..0c2547c 100644 --- a/pkg/database/stop.go +++ b/pkg/database/stop.go @@ -28,6 +28,23 @@ func (env *DBEnv) GetStop(id string) (*model.Stop, error) { return &stop, nil } +func (env *DBEnv) GetStops() (stops []model.Stop, err error) { + query := `SELECT id, name FROM stops;` + rows, err := env.db.Query(query) + if err != nil { + return nil, newQueryError("Could not run database query", err) + } + defer rows.Close() + for rows.Next() { + var stop model.Stop + if err := rows.Scan(&stop.Id, &stop.Name); err != nil { + return nil, newQueryError("Could not run database query", err) + } + stops = append(stops, stop) + } + return +} + func (env *DBEnv) ReplaceAndImportStops(trainStops []model.Stop) error { pre_query := `DELETE FROM stops;` query := ` diff --git a/pkg/database/stop_test.go b/pkg/database/stop_test.go index 2922cbe..a670393 100644 --- a/pkg/database/stop_test.go +++ b/pkg/database/stop_test.go @@ -1,6 +1,7 @@ package database import ( + "fmt" "reflect" "testing" @@ -52,6 +53,55 @@ func TestGetStop(t *testing.T) { assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(&QueryError{}), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(&QueryError{})) } +func TestGetStops(t *testing.T) { + stops := []model.Stop{ + model.Stop{Id: "id1", Name: "name1"}, + model.Stop{Id: "id2", Name: "name2"}, + } + // test db setup + db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on") + require.NoError(t, err) + // error check + res, err := db.GetStops() + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(&QueryError{}), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(&QueryError{})) + // normal check + err = db.Migrate() + require.NoError(t, err) + err = db.ReplaceAndImportStops(stops) + res, err = db.GetStops() + require.NoError(t, err) + assert.Equal(t, res, stops) +} + +func TestGetStopsWithSQLMock(t *testing.T) { + // Transaction commit error + dbScanError, mockScanError, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + mockScanError.ExpectQuery(`SELECT id, name FROM stops;`).WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(nil, "b").RowError(1, fmt.Errorf("row error"))) + // Test cases + testCases := []struct { + name string + db *DBEnv + expectedError interface{} + }{ + {"query error when scanning row", &DBEnv{db: dbScanError}, &QueryError{}}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := tc.db.GetStops() + if tc.expectedError != nil { + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError)) + } else { + require.NoError(t, err) + } + }) + } +} + func TestReplaceAndImportStops(t *testing.T) { // test db setup db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on") |