diff options
Diffstat (limited to 'navitia_api_client')
-rw-r--r-- | navitia_api_client/client.go | 10 | ||||
-rw-r--r-- | navitia_api_client/client_test.go | 1 | ||||
-rw-r--r-- | navitia_api_client/departures.go | 16 | ||||
-rw-r--r-- | navitia_api_client/departures_test.go | 9 |
4 files changed, 35 insertions, 1 deletions
diff --git a/navitia_api_client/client.go b/navitia_api_client/client.go index d08ca19..aef8d0e 100644 --- a/navitia_api_client/client.go +++ b/navitia_api_client/client.go @@ -3,12 +3,21 @@ package navitia_api_client import ( "fmt" "net/http" + "sync" "time" ) type Client struct { baseURL string httpClient *http.Client + + mutex sync.Mutex + cache map[string]cachedResult +} + +type cachedResult struct { + ts time.Time + result interface{} } func NewClient(token string) *Client { @@ -17,5 +26,6 @@ func NewClient(token string) *Client { httpClient: &http.Client{ Timeout: time.Minute, }, + cache: make(map[string]cachedResult), } } diff --git a/navitia_api_client/client_test.go b/navitia_api_client/client_test.go index ec51a06..332b222 100644 --- a/navitia_api_client/client_test.go +++ b/navitia_api_client/client_test.go @@ -13,6 +13,7 @@ func NewTestClient(ts *httptest.Server) *Client { return &Client{ baseURL: fmt.Sprintf(ts.URL), httpClient: ts.Client(), + cache: make(map[string]cachedResult), } } diff --git a/navitia_api_client/departures.go b/navitia_api_client/departures.go index a6dcf53..7738940 100644 --- a/navitia_api_client/departures.go +++ b/navitia_api_client/departures.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "time" ) type DeparturesResponse struct { @@ -43,7 +44,16 @@ type DeparturesResponse struct { } func (c *Client) GetDepartures() (departures *DeparturesResponse, err error) { - req, err := http.NewRequest("GET", fmt.Sprintf("%s/coverage/sncf/stop_areas/stop_area:SNCF:87723502/departures", c.baseURL), nil) + request := fmt.Sprintf("%s/coverage/sncf/stop_areas/stop_area:SNCF:87723502/departures", c.baseURL) + start := time.Now() + c.mutex.Lock() + defer c.mutex.Unlock() + if cachedResult, ok := c.cache[request]; ok { + if start.Sub(cachedResult.ts) < 60*1000*1000*1000 { + return cachedResult.result.(*DeparturesResponse), nil + } + } + req, err := http.NewRequest("GET", request, nil) if err != nil { return nil, err } @@ -55,5 +65,9 @@ func (c *Client) GetDepartures() (departures *DeparturesResponse, err error) { if err = json.NewDecoder(resp.Body).Decode(&departures); err != nil { return nil, err } + c.cache[request] = cachedResult{ + ts: start, + result: departures, + } return } diff --git a/navitia_api_client/departures_test.go b/navitia_api_client/departures_test.go index 2c87429..a1658d2 100644 --- a/navitia_api_client/departures_test.go +++ b/navitia_api_client/departures_test.go @@ -45,4 +45,13 @@ func TestGetDepartures(t *testing.T) { if len(departures.Departures) != 10 { t.Fatalf("did not decode normal-crepieux departures properly, got %d departures when expected 10", len(departures.Departures)) } + // test the cache (assuming the test takes less than 60 seconds (and it really should) it will be accurate) + ts.Close() + departures, err = client.GetDepartures() + if err != nil { + t.Fatalf("could not get normal-crepieux departures : %s", err) + } + if len(departures.Departures) != 10 { + t.Fatalf("did not decode normal-crepieux departures properly, got %d departures when expected 10", len(departures.Departures)) + } } |