diff options
Diffstat (limited to '')
-rw-r--r-- | go.mod | 1 | ||||
-rw-r--r-- | go.sum | 7 | ||||
-rw-r--r-- | pkg/database/error.go | 36 | ||||
-rw-r--r-- | pkg/database/error_test.go | 6 | ||||
-rw-r--r-- | pkg/database/migrations.go | 8 | ||||
-rw-r--r-- | pkg/database/password.go | 14 | ||||
-rw-r--r-- | pkg/database/users.go | 47 | ||||
-rw-r--r-- | pkg/database/users_test.go | 111 | ||||
-rw-r--r-- | pkg/model/users.go | 17 |
9 files changed, 246 insertions, 1 deletions
@@ -8,6 +8,7 @@ require ( github.com/kr/pretty v0.2.1 // indirect github.com/mattn/go-sqlite3 v1.14.6 github.com/stretchr/testify v1.7.0 + golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b ) @@ -15,6 +15,13 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/database/error.go b/pkg/database/error.go index 1681dd5..0c19752 100644 --- a/pkg/database/error.go +++ b/pkg/database/error.go @@ -38,6 +38,40 @@ func newMigrationError(version int, err error) error { } } +// Password hash error +type PasswordError struct { + err error +} + +func (e *PasswordError) Error() string { + return fmt.Sprintf("Failed to hash password : %+v", e.err) +} +func (e *PasswordError) Unwrap() error { return e.err } + +func newPasswordError(err error) error { + return &PasswordError{ + err: err, + } +} + +// database query error +type QueryError struct { + msg string + err error +} + +func (e *QueryError) Error() string { + return fmt.Sprintf("Failed to perform query : %s", e.msg) +} +func (e *QueryError) Unwrap() error { return e.err } + +func newQueryError(msg string, err error) error { + return &QueryError{ + msg: msg, + err: err, + } +} + // database transaction error type TransactionError struct { msg string @@ -45,7 +79,7 @@ type TransactionError struct { } func (e *TransactionError) Error() string { - return fmt.Sprintf("Failed to perform transaction : %s", e.msg) + return fmt.Sprintf("Failed to perform query : %s", e.msg) } func (e *TransactionError) Unwrap() error { return e.err } diff --git a/pkg/database/error_test.go b/pkg/database/error_test.go index 8031305..fca19d5 100644 --- a/pkg/database/error_test.go +++ b/pkg/database/error_test.go @@ -9,6 +9,12 @@ func TestErrorsCoverage(t *testing.T) { migrationErr := MigrationError{} _ = migrationErr.Error() _ = migrationErr.Unwrap() + passwordError := PasswordError{} + _ = passwordError.Error() + _ = passwordError.Unwrap() + queryErr := QueryError{} + _ = queryErr.Error() + _ = queryErr.Unwrap() transactionErr := TransactionError{} _ = transactionErr.Error() _ = transactionErr.Unwrap() diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go index cec3c41..8d23ed2 100644 --- a/pkg/database/migrations.go +++ b/pkg/database/migrations.go @@ -9,6 +9,14 @@ var allMigrations = []func(tx *sql.Tx) error{ sql := ` CREATE TABLE schema_version ( version INTEGER NOT NULL + ); + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password TEXT, + email TEXT, + created_at DATE DEFAULT (datetime('now')), + last_login_at DATE DEFAULT NULL );` _, err = tx.Exec(sql) return err diff --git a/pkg/database/password.go b/pkg/database/password.go new file mode 100644 index 0000000..d24183d --- /dev/null +++ b/pkg/database/password.go @@ -0,0 +1,14 @@ +package database + +import "golang.org/x/crypto/bcrypt" + +// To allow for testing the error case (bad random is hard to trigger) +var passwordFunction = bcrypt.GenerateFromPassword + +func hashPassword(password string) (string, error) { + bytes, err := passwordFunction([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", newPasswordError(err) + } + return string(bytes), nil +} diff --git a/pkg/database/users.go b/pkg/database/users.go new file mode 100644 index 0000000..3f75f62 --- /dev/null +++ b/pkg/database/users.go @@ -0,0 +1,47 @@ +package database + +import ( + "git.adyxax.org/adyxax/trains/pkg/model" +) + +// Creates a new user in the database +// a QueryError is return if the username already exists (database constraints not met) +func (env *DBEnv) CreateUser(reg *model.UserRegistration) (*model.User, error) { + hash, err := hashPassword(reg.Password) + if err != nil { + return nil, err + } + query := ` + INSERT INTO users + (username, password, email) + VALUES + ($1, $2, $3);` + tx, err := env.db.Begin() + if err != nil { + return nil, newTransactionError("Could not Begin()", err) + } + result, err := tx.Exec( + query, + reg.Username, + hash, + reg.Email, + ) + if err != nil { + tx.Rollback() + return nil, newQueryError("Could not run database query, most likely the username already exists", err) + } + id, err := result.LastInsertId() + if err != nil { + tx.Rollback() + return nil, newTransactionError("Could not get LastInsertId, the database driver does not support this feature", err) + } + if err := tx.Commit(); err != nil { + return nil, newTransactionError("Could not commit transaction", err) + } + user := model.User{ + Id: int(id), + Username: reg.Username, + Email: reg.Email, + } + return &user, nil +} diff --git a/pkg/database/users_test.go b/pkg/database/users_test.go new file mode 100644 index 0000000..4b87d90 --- /dev/null +++ b/pkg/database/users_test.go @@ -0,0 +1,111 @@ +package database + +import ( + "reflect" + "testing" + + "git.adyxax.org/adyxax/trains/pkg/model" + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" +) + +func TestCreateUser(t *testing.T) { + // test db setup + db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on") + require.NoError(t, err) + err = db.Migrate() + require.NoError(t, err) + // a normal user + normalUser := model.UserRegistration{ + Username: "testUsername", + Password: "testPassword", + Email: "testEmail", + } + // Test another normal user + normalUser2 := model.UserRegistration{ + Username: "testUsername2", + Password: "yei*j2Ien2xa?g6bieh~i=asoo5ii7Bi", + Email: "testEmail", + } + // Test cases + testCases := []struct { + name string + db *DBEnv + input *model.UserRegistration + expected int + expectedError interface{} + }{ + {"Normal user", db, &normalUser, 1, nil}, + {"Duplicate user", db, &normalUser, 0, &QueryError{}}, + {"Normal user 2", db, &normalUser2, 2, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + valid, err := tc.db.CreateUser(tc.input) + if tc.expectedError != nil { + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError)) + require.Nil(t, valid) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expected, valid.Id) + } + }) + } + // Test for bad password + passwordFunction = func(password []byte, cost int) ([]byte, error) { return nil, newPasswordError(nil) } + valid, err := db.CreateUser(&normalUser) + passwordFunction = bcrypt.GenerateFromPassword + require.Error(t, err) + require.Nil(t, valid) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(&PasswordError{}), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(&PasswordError{})) +} + +func TestCreateUserWithSQLMock(t *testing.T) { + // Transaction begin error + dbBeginError, _, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer dbBeginError.Close() + // Transaction LastInsertId not supported + dbLastInsertIdError, mockLastInsertIdError, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer dbLastInsertIdError.Close() + mockLastInsertIdError.ExpectBegin() + mockLastInsertIdError.ExpectExec(`INSERT INTO`).WillReturnResult(sqlmock.NewErrorResult(&TransactionError{"test", nil})) + // Transaction commit error + dbCommitError, mockCommitError, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer dbCommitError.Close() + mockCommitError.ExpectBegin() + mockCommitError.ExpectExec(`INSERT INTO`).WillReturnResult(sqlmock.NewResult(1, 1)) + // Test cases + testCases := []struct { + name string + db *DBEnv + expectedError interface{} + }{ + {"begin transaction error", &DBEnv{db: dbBeginError}, &TransactionError{}}, + {"last insert id transaction error", &DBEnv{db: dbLastInsertIdError}, &TransactionError{}}, + {"commit transaction error", &DBEnv{db: dbCommitError}, &TransactionError{}}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + valid, err := tc.db.CreateUser(&model.UserRegistration{}) + if tc.expectedError != nil { + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError)) + require.Nil(t, valid) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/model/users.go b/pkg/model/users.go new file mode 100644 index 0000000..dfef67b --- /dev/null +++ b/pkg/model/users.go @@ -0,0 +1,17 @@ +package model + +import "time" + +type User struct { + Id int `json:"id"` + Username string `json:"username"` + Email string `json:"email"` + Created_at *time.Time `json:"created_at"` + Last_login_at *time.Time `json:"last_login_at"` +} + +type UserRegistration struct { + Username string `json:"username"` + Password string `json:"password"` + Email string `json:"email"` +} |