aboutsummaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--internal/webui/login.go4
-rw-r--r--internal/webui/utils.go4
-rw-r--r--internal/webui/utils_test.go23
-rw-r--r--internal/webui/webui.go2
4 files changed, 18 insertions, 15 deletions
diff --git a/internal/webui/login.go b/internal/webui/login.go
index 706cd29..84d477f 100644
--- a/internal/webui/login.go
+++ b/internal/webui/login.go
@@ -57,10 +57,10 @@ func loginHandler(e *env, w http.ResponseWriter, r *http.Request) error {
user, err := e.dbEnv.Login(&model.UserLogin{Username: username[0], Password: password[0]})
if err != nil {
switch e := err.(type) {
- case *database.PasswordError:
+ case database.PasswordError:
// TODO : handle in page
return e
- case *database.QueryError:
+ case database.QueryError:
return e
default:
return e
diff --git a/internal/webui/utils.go b/internal/webui/utils.go
index 28a7add..5bc1e04 100644
--- a/internal/webui/utils.go
+++ b/internal/webui/utils.go
@@ -41,8 +41,8 @@ type statusError struct {
err error
}
-func (e *statusError) Error() string { return e.err.Error() }
-func (e *statusError) Status() int { return e.code }
+func (e statusError) Error() string { return e.err.Error() }
+func (e statusError) Status() int { return e.code }
func newStatusError(code int, err error) error { return &statusError{code: code, err: err} }
type handler struct {
diff --git a/internal/webui/utils_test.go b/internal/webui/utils_test.go
index bd24dd5..6de1aad 100644
--- a/internal/webui/utils_test.go
+++ b/internal/webui/utils_test.go
@@ -10,10 +10,13 @@ import (
"testing"
"git.adyxax.org/adyxax/trains/pkg/model"
- "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
+func requireErrorTypeMatch(t *testing.T, err error, expected error) {
+ require.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(expected), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(expected))
+}
+
type NavitiaMockClient struct {
departures []model.Departure
stops []model.Stop
@@ -46,7 +49,7 @@ type httpTestExpect struct {
bodyString string
location string
setsCookie bool
- err interface{}
+ err error
}
func runHttpTest(t *testing.T, e *env, h func(e *env, w http.ResponseWriter, r *http.Request) error, tc *httpTestCase) {
@@ -64,22 +67,22 @@ func runHttpTest(t *testing.T, e *env, h func(e *env, w http.ResponseWriter, r *
err := h(e, rr, req)
if tc.expect.err != nil {
require.Error(t, err)
- assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expect.err), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expect.err))
+ requireErrorTypeMatch(t, err, tc.expect.err)
} else {
require.NoError(t, err)
- assert.Equal(t, tc.expect.code, rr.Code)
+ require.Equal(t, tc.expect.code, rr.Code)
if tc.expect.bodyString != "" {
- assert.Contains(t, rr.Body.String(), tc.expect.bodyString)
+ require.Contains(t, rr.Body.String(), tc.expect.bodyString)
}
if tc.expect.location != "" {
- assert.Contains(t, rr.HeaderMap, "Location")
- assert.Len(t, rr.HeaderMap["Location"], 1)
- assert.Equal(t, rr.HeaderMap["Location"][0], tc.expect.location)
+ require.Contains(t, rr.HeaderMap, "Location")
+ require.Len(t, rr.HeaderMap["Location"], 1)
+ require.Equal(t, rr.HeaderMap["Location"][0], tc.expect.location)
}
if tc.expect.setsCookie {
- assert.Contains(t, rr.HeaderMap, "Set-Cookie")
+ require.Contains(t, rr.HeaderMap, "Set-Cookie")
} else {
- assert.NotContains(t, rr.HeaderMap, "Set-Cookie")
+ require.NotContains(t, rr.HeaderMap, "Set-Cookie")
}
}
})
diff --git a/internal/webui/webui.go b/internal/webui/webui.go
index 63173c6..8685df5 100644
--- a/internal/webui/webui.go
+++ b/internal/webui/webui.go
@@ -26,7 +26,7 @@ func Run(c *config.Config, dbEnv *database.DBEnv) {
if stops, err := e.navitia.GetStops(); err == nil {
log.Printf("Updated trains stops data from navitia api, got %d results", len(stops))
if err = dbEnv.ReplaceAndImportStops(stops); err != nil {
- if dberr, ok := err.(*database.QueryError); ok {
+ if dberr, ok := err.(database.QueryError); ok {
log.Printf("%+v", dberr.Unwrap())
}
}