From 77904f19c6f29659d5adcf569336735b3b651912 Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Mon, 12 Apr 2021 12:19:37 +0200 Subject: Implemented the user login database function --- pkg/database/password.go | 8 +++++ pkg/database/users.go | 25 ++++++++++++++++ pkg/database/users_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++---- pkg/model/users.go | 5 ++++ 4 files changed, 106 insertions(+), 5 deletions(-) diff --git a/pkg/database/password.go b/pkg/database/password.go index d24183d..10abe4b 100644 --- a/pkg/database/password.go +++ b/pkg/database/password.go @@ -12,3 +12,11 @@ func hashPassword(password string) (string, error) { } return string(bytes), nil } + +func checkPassword(hash string, password string) error { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + if err != nil { + return newPasswordError(err) + } + return nil +} diff --git a/pkg/database/users.go b/pkg/database/users.go index 3f75f62..10e6416 100644 --- a/pkg/database/users.go +++ b/pkg/database/users.go @@ -45,3 +45,28 @@ func (env *DBEnv) CreateUser(reg *model.UserRegistration) (*model.User, error) { } return &user, nil } + +// Login logs a user in if the password matches the hash in database +// a PasswordError is return if the passwords do not match +// a QueryError is returned if the username contains invalid sql characters like % +func (env *DBEnv) Login(login *model.UserLogin) (*model.User, error) { + query := `SELECT id, password, email FROM users WHERE username = $1;` + user := model.User{Username: login.Username} + var hash string + err := env.db.QueryRow( + query, + login.Username, + ).Scan( + &user.Id, + &hash, + &user.Email, + ) + if err != nil { + return nil, newQueryError("Could not run database query", err) + } + err = checkPassword(hash, login.Password) + if err != nil { + return nil, err + } + return &user, nil +} diff --git a/pkg/database/users_test.go b/pkg/database/users_test.go index 4b87d90..b1197db 100644 --- a/pkg/database/users_test.go +++ b/pkg/database/users_test.go @@ -32,18 +32,17 @@ func TestCreateUser(t *testing.T) { // Test cases testCases := []struct { name string - db *DBEnv input *model.UserRegistration expected int expectedError interface{} }{ - {"Normal user", db, &normalUser, 1, nil}, - {"Duplicate user", db, &normalUser, 0, &QueryError{}}, - {"Normal user 2", db, &normalUser2, 2, nil}, + {"Normal user", &normalUser, 1, nil}, + {"Duplicate user", &normalUser, 0, &QueryError{}}, + {"Normal user 2", &normalUser2, 2, nil}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - valid, err := tc.db.CreateUser(tc.input) + valid, err := db.CreateUser(tc.input) if tc.expectedError != nil { require.Error(t, err) assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError)) @@ -109,3 +108,67 @@ func TestCreateUserWithSQLMock(t *testing.T) { }) } } + +func TestLogin(t *testing.T) { + user1 := model.UserRegistration{ + Username: "user1", + Password: "user1_pass", + Email: "user1", + } + user2 := model.UserRegistration{ + Username: "user2", + Password: "user2_pass", + Email: "user2", + } + // test db setup + db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on") + require.NoError(t, err) + err = db.Migrate() + require.NoError(t, err) + _, err = db.CreateUser(&user1) + require.NoError(t, err) + _, err = db.CreateUser(&user2) + require.NoError(t, err) + // successful logins + loginUser1 := model.UserLogin{ + Username: user1.Username, + Password: user1.Password, + } + loginUser2 := model.UserLogin{ + Username: user2.Username, + Password: user2.Password, + } + // a failed login + failedUser1 := model.UserLogin{ + Username: user1.Username, + Password: user2.Password, + } + // Query error + queryError := model.UserLogin{ + Username: "%", + } + // Test cases + testCases := []struct { + name string + input *model.UserLogin + expectedError interface{} + }{ + {"login user1", &loginUser1, nil}, + {"login user2", &loginUser2, nil}, + {"failed login", &failedUser1, &PasswordError{}}, + {"query error", &queryError, &QueryError{}}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + valid, err := db.Login(tc.input) + if tc.expectedError != nil { + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError)) + require.Nil(t, valid) + } else { + require.NoError(t, err) + assert.Equal(t, tc.input.Username, valid.Username) + } + }) + } +} diff --git a/pkg/model/users.go b/pkg/model/users.go index dfef67b..e4b4bf0 100644 --- a/pkg/model/users.go +++ b/pkg/model/users.go @@ -10,6 +10,11 @@ type User struct { Last_login_at *time.Time `json:"last_login_at"` } +type UserLogin struct { + Username string `json:"username"` + Password string `json:"password"` +} + type UserRegistration struct { Username string `json:"username"` Password string `json:"password"` -- cgit v1.2.3