From cf7f851a5f47bab8a979135c834dae9ead63a042 Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Tue, 7 Sep 2021 17:32:47 +0200 Subject: Added logic to fetch train stops data if missing when starting the webserver --- internal/webui/webui.go | 14 ++++++++++++++ pkg/database/migrations.go | 2 +- pkg/database/train_stop.go | 11 ++++++++++- pkg/database/train_stop_test.go | 21 +++++++++++++++++++++ pkg/navitia_api_client/train_stops.go | 8 ++++---- pkg/navitia_api_client/train_stops_test.go | 6 ++++-- 6 files changed, 54 insertions(+), 8 deletions(-) diff --git a/internal/webui/webui.go b/internal/webui/webui.go index 6ce5bb4..1a251f4 100644 --- a/internal/webui/webui.go +++ b/internal/webui/webui.go @@ -19,6 +19,20 @@ func Run(c *config.Config, dbEnv *database.DBEnv) { http.Handle("/login", handler{&e, loginHandler}) http.Handle("/static/", http.FileServer(http.FS(staticFS))) + if i, err := dbEnv.CountTrainStops(); err == nil && i == 0 { + log.Printf("No trains stops data found, updating...") + if trainStops, err := e.navitia.GetTrainStops(); err == nil { + log.Printf("Updated trains stops data from navitia api, got %d results", len(trainStops)) + if err = dbEnv.ReplaceAndImportTrainStops(trainStops); err != nil { + if dberr, ok := err.(*database.QueryError); ok { + log.Printf("%+v", dberr.Unwrap()) + } + } + } else { + log.Printf("Failed to get trains stops data from navitia api : %+v", err) + } + } + listenStr := c.Address + ":" + c.Port log.Printf("Starting webui on %s", listenStr) log.Fatal(http.ListenAndServe(listenStr, nil)) diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go index e977960..7c74bcf 100644 --- a/pkg/database/migrations.go +++ b/pkg/database/migrations.go @@ -26,7 +26,7 @@ var allMigrations = []func(tx *sql.Tx) error{ ); CREATE TABLE train_stops ( id TEXT PRIMARY KEY, - name TEXT NOT NULL UNIQUE + name TEXT NOT NULL );` _, err = tx.Exec(sql) return err diff --git a/pkg/database/train_stop.go b/pkg/database/train_stop.go index 452f6c3..ee7be00 100644 --- a/pkg/database/train_stop.go +++ b/pkg/database/train_stop.go @@ -4,6 +4,15 @@ import ( "git.adyxax.org/adyxax/trains/pkg/model" ) +func (env *DBEnv) CountTrainStops() (i int, err error) { + query := `SELECT count(*) from train_stops;` + err = env.db.QueryRow(query).Scan(&i) + if err != nil { + return 0, newQueryError("Could not run database query: most likely the schema is corrupted", err) + } + return +} + func (env *DBEnv) ReplaceAndImportTrainStops(trainStops []model.TrainStop) error { pre_query := `DELETE FROM train_stops;` query := ` @@ -28,7 +37,7 @@ func (env *DBEnv) ReplaceAndImportTrainStops(trainStops []model.TrainStop) error ) if err != nil { tx.Rollback() - return newQueryError("Could not run database query: ", err) + return newQueryError("Could not run database query", err) } } if err := tx.Commit(); err != nil { diff --git a/pkg/database/train_stop_test.go b/pkg/database/train_stop_test.go index f593d51..b3f9459 100644 --- a/pkg/database/train_stop_test.go +++ b/pkg/database/train_stop_test.go @@ -10,6 +10,27 @@ import ( "github.com/stretchr/testify/require" ) +func TestCountTrainStops(t *testing.T) { + trainStops := []model.TrainStop{ + model.TrainStop{Id: "id1", Name: "name1"}, + model.TrainStop{Id: "id2", Name: "name2"}, + } + // test db setup + db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on") + require.NoError(t, err) + // check sql error + i, err := db.CountTrainStops() + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(&QueryError{}), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(&QueryError{})) + // normal check + err = db.Migrate() + require.NoError(t, err) + err = db.ReplaceAndImportTrainStops(trainStops) + i, err = db.CountTrainStops() + require.NoError(t, err) + assert.Equal(t, i, len(trainStops)) +} + func TestReplaceAndImportTrainStops(t *testing.T) { // test db setup db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on") diff --git a/pkg/navitia_api_client/train_stops.go b/pkg/navitia_api_client/train_stops.go index c8a8ff2..a31ccf7 100644 --- a/pkg/navitia_api_client/train_stops.go +++ b/pkg/navitia_api_client/train_stops.go @@ -3,7 +3,6 @@ package navitia_api_client import ( "encoding/json" "fmt" - "log" "net/http" "git.adyxax.org/adyxax/trains/pkg/model" @@ -22,7 +21,7 @@ type TrainStopsResponse struct { Codes []interface{} `json:"codes"` Links []interface{} `json:"links"` Coord interface{} `json:"coord"` - Label interface{} `json:"label"` + Label string `json:"label"` Timezone interface{} `json:"timezone"` AdministrativeRegion interface{} `json:"administrative_regions"` } `json:"stop_areas"` @@ -53,10 +52,11 @@ func getTrainStopsPage(c *NavitiaClient, i int) (trainStops []model.TrainStop, e return nil, newJsonDecodeError("GetTrainStops ", err) } for i := 0; i < len(data.StopAreas); i++ { - trainStops = append(trainStops, model.TrainStop{data.StopAreas[i].ID, data.StopAreas[i].Name}) + if data.StopAreas[i].Label != "" { + trainStops = append(trainStops, model.TrainStop{data.StopAreas[i].ID, data.StopAreas[i].Label}) + } } if data.Pagination.ItemsOnPage+data.Pagination.ItemsPerPage*data.Pagination.StartPage < data.Pagination.TotalResult { - log.Printf("pagination %d\n", i) tss, err := getTrainStopsPage(c, i+1) if err != nil { return nil, err diff --git a/pkg/navitia_api_client/train_stops_test.go b/pkg/navitia_api_client/train_stops_test.go index 93fc43f..9e1c982 100644 --- a/pkg/navitia_api_client/train_stops_test.go +++ b/pkg/navitia_api_client/train_stops_test.go @@ -76,7 +76,8 @@ func TestGetTrainStops(t *testing.T) { if err != nil { t.Fatalf("could not get train stops : %s", err) } - if len(trainStops) != 4 { + // 4 records but one is empty (navitia api quirk) + if len(trainStops) != 3 { t.Fatalf("did not decode train stops properly, got %d train stops when expected 4", len(trainStops)) } // normal request in multiple pages @@ -90,7 +91,8 @@ func TestGetTrainStops(t *testing.T) { if err != nil { t.Fatalf("could not get train stops : %+v", err) } - if len(trainStops) != 12 { + // 12 records but one is empty (navitia api quirk) + if len(trainStops) != 11 { t.Fatalf("did not decode train stops properly, got %d train stops when expected 4", len(trainStops)) } // failing request in multiple pages with last one missing -- cgit v1.2.3