aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJulien Dessaux2021-09-07 17:32:47 +0200
committerJulien Dessaux2021-09-07 17:32:47 +0200
commitcf7f851a5f47bab8a979135c834dae9ead63a042 (patch)
treedc88a7650ae4ae769afcd3ac35d3b10680a08d84
parentAdded navitia code to fetch train stops names (diff)
downloadtrains-cf7f851a5f47bab8a979135c834dae9ead63a042.tar.gz
trains-cf7f851a5f47bab8a979135c834dae9ead63a042.tar.bz2
trains-cf7f851a5f47bab8a979135c834dae9ead63a042.zip
Added logic to fetch train stops data if missing when starting the webserver
-rw-r--r--internal/webui/webui.go14
-rw-r--r--pkg/database/migrations.go2
-rw-r--r--pkg/database/train_stop.go11
-rw-r--r--pkg/database/train_stop_test.go21
-rw-r--r--pkg/navitia_api_client/train_stops.go8
-rw-r--r--pkg/navitia_api_client/train_stops_test.go6
6 files changed, 54 insertions, 8 deletions
diff --git a/internal/webui/webui.go b/internal/webui/webui.go
index 6ce5bb4..1a251f4 100644
--- a/internal/webui/webui.go
+++ b/internal/webui/webui.go
@@ -19,6 +19,20 @@ func Run(c *config.Config, dbEnv *database.DBEnv) {
http.Handle("/login", handler{&e, loginHandler})
http.Handle("/static/", http.FileServer(http.FS(staticFS)))
+ if i, err := dbEnv.CountTrainStops(); err == nil && i == 0 {
+ log.Printf("No trains stops data found, updating...")
+ if trainStops, err := e.navitia.GetTrainStops(); err == nil {
+ log.Printf("Updated trains stops data from navitia api, got %d results", len(trainStops))
+ if err = dbEnv.ReplaceAndImportTrainStops(trainStops); err != nil {
+ if dberr, ok := err.(*database.QueryError); ok {
+ log.Printf("%+v", dberr.Unwrap())
+ }
+ }
+ } else {
+ log.Printf("Failed to get trains stops data from navitia api : %+v", err)
+ }
+ }
+
listenStr := c.Address + ":" + c.Port
log.Printf("Starting webui on %s", listenStr)
log.Fatal(http.ListenAndServe(listenStr, nil))
diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go
index e977960..7c74bcf 100644
--- a/pkg/database/migrations.go
+++ b/pkg/database/migrations.go
@@ -26,7 +26,7 @@ var allMigrations = []func(tx *sql.Tx) error{
);
CREATE TABLE train_stops (
id TEXT PRIMARY KEY,
- name TEXT NOT NULL UNIQUE
+ name TEXT NOT NULL
);`
_, err = tx.Exec(sql)
return err
diff --git a/pkg/database/train_stop.go b/pkg/database/train_stop.go
index 452f6c3..ee7be00 100644
--- a/pkg/database/train_stop.go
+++ b/pkg/database/train_stop.go
@@ -4,6 +4,15 @@ import (
"git.adyxax.org/adyxax/trains/pkg/model"
)
+func (env *DBEnv) CountTrainStops() (i int, err error) {
+ query := `SELECT count(*) from train_stops;`
+ err = env.db.QueryRow(query).Scan(&i)
+ if err != nil {
+ return 0, newQueryError("Could not run database query: most likely the schema is corrupted", err)
+ }
+ return
+}
+
func (env *DBEnv) ReplaceAndImportTrainStops(trainStops []model.TrainStop) error {
pre_query := `DELETE FROM train_stops;`
query := `
@@ -28,7 +37,7 @@ func (env *DBEnv) ReplaceAndImportTrainStops(trainStops []model.TrainStop) error
)
if err != nil {
tx.Rollback()
- return newQueryError("Could not run database query: ", err)
+ return newQueryError("Could not run database query", err)
}
}
if err := tx.Commit(); err != nil {
diff --git a/pkg/database/train_stop_test.go b/pkg/database/train_stop_test.go
index f593d51..b3f9459 100644
--- a/pkg/database/train_stop_test.go
+++ b/pkg/database/train_stop_test.go
@@ -10,6 +10,27 @@ import (
"github.com/stretchr/testify/require"
)
+func TestCountTrainStops(t *testing.T) {
+ trainStops := []model.TrainStop{
+ model.TrainStop{Id: "id1", Name: "name1"},
+ model.TrainStop{Id: "id2", Name: "name2"},
+ }
+ // test db setup
+ db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on")
+ require.NoError(t, err)
+ // check sql error
+ i, err := db.CountTrainStops()
+ require.Error(t, err)
+ assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(&QueryError{}), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(&QueryError{}))
+ // normal check
+ err = db.Migrate()
+ require.NoError(t, err)
+ err = db.ReplaceAndImportTrainStops(trainStops)
+ i, err = db.CountTrainStops()
+ require.NoError(t, err)
+ assert.Equal(t, i, len(trainStops))
+}
+
func TestReplaceAndImportTrainStops(t *testing.T) {
// test db setup
db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on")
diff --git a/pkg/navitia_api_client/train_stops.go b/pkg/navitia_api_client/train_stops.go
index c8a8ff2..a31ccf7 100644
--- a/pkg/navitia_api_client/train_stops.go
+++ b/pkg/navitia_api_client/train_stops.go
@@ -3,7 +3,6 @@ package navitia_api_client
import (
"encoding/json"
"fmt"
- "log"
"net/http"
"git.adyxax.org/adyxax/trains/pkg/model"
@@ -22,7 +21,7 @@ type TrainStopsResponse struct {
Codes []interface{} `json:"codes"`
Links []interface{} `json:"links"`
Coord interface{} `json:"coord"`
- Label interface{} `json:"label"`
+ Label string `json:"label"`
Timezone interface{} `json:"timezone"`
AdministrativeRegion interface{} `json:"administrative_regions"`
} `json:"stop_areas"`
@@ -53,10 +52,11 @@ func getTrainStopsPage(c *NavitiaClient, i int) (trainStops []model.TrainStop, e
return nil, newJsonDecodeError("GetTrainStops ", err)
}
for i := 0; i < len(data.StopAreas); i++ {
- trainStops = append(trainStops, model.TrainStop{data.StopAreas[i].ID, data.StopAreas[i].Name})
+ if data.StopAreas[i].Label != "" {
+ trainStops = append(trainStops, model.TrainStop{data.StopAreas[i].ID, data.StopAreas[i].Label})
+ }
}
if data.Pagination.ItemsOnPage+data.Pagination.ItemsPerPage*data.Pagination.StartPage < data.Pagination.TotalResult {
- log.Printf("pagination %d\n", i)
tss, err := getTrainStopsPage(c, i+1)
if err != nil {
return nil, err
diff --git a/pkg/navitia_api_client/train_stops_test.go b/pkg/navitia_api_client/train_stops_test.go
index 93fc43f..9e1c982 100644
--- a/pkg/navitia_api_client/train_stops_test.go
+++ b/pkg/navitia_api_client/train_stops_test.go
@@ -76,7 +76,8 @@ func TestGetTrainStops(t *testing.T) {
if err != nil {
t.Fatalf("could not get train stops : %s", err)
}
- if len(trainStops) != 4 {
+ // 4 records but one is empty (navitia api quirk)
+ if len(trainStops) != 3 {
t.Fatalf("did not decode train stops properly, got %d train stops when expected 4", len(trainStops))
}
// normal request in multiple pages
@@ -90,7 +91,8 @@ func TestGetTrainStops(t *testing.T) {
if err != nil {
t.Fatalf("could not get train stops : %+v", err)
}
- if len(trainStops) != 12 {
+ // 12 records but one is empty (navitia api quirk)
+ if len(trainStops) != 11 {
t.Fatalf("did not decode train stops properly, got %d train stops when expected 4", len(trainStops))
}
// failing request in multiple pages with last one missing