diff options
Diffstat (limited to 'pkg/database')
-rw-r--r-- | pkg/database/error.go | 36 | ||||
-rw-r--r-- | pkg/database/error_test.go | 6 | ||||
-rw-r--r-- | pkg/database/migrations.go | 8 | ||||
-rw-r--r-- | pkg/database/password.go | 14 | ||||
-rw-r--r-- | pkg/database/users.go | 47 | ||||
-rw-r--r-- | pkg/database/users_test.go | 111 |
6 files changed, 221 insertions, 1 deletions
diff --git a/pkg/database/error.go b/pkg/database/error.go index 1681dd5..0c19752 100644 --- a/pkg/database/error.go +++ b/pkg/database/error.go @@ -38,6 +38,40 @@ func newMigrationError(version int, err error) error { } } +// Password hash error +type PasswordError struct { + err error +} + +func (e *PasswordError) Error() string { + return fmt.Sprintf("Failed to hash password : %+v", e.err) +} +func (e *PasswordError) Unwrap() error { return e.err } + +func newPasswordError(err error) error { + return &PasswordError{ + err: err, + } +} + +// database query error +type QueryError struct { + msg string + err error +} + +func (e *QueryError) Error() string { + return fmt.Sprintf("Failed to perform query : %s", e.msg) +} +func (e *QueryError) Unwrap() error { return e.err } + +func newQueryError(msg string, err error) error { + return &QueryError{ + msg: msg, + err: err, + } +} + // database transaction error type TransactionError struct { msg string @@ -45,7 +79,7 @@ type TransactionError struct { } func (e *TransactionError) Error() string { - return fmt.Sprintf("Failed to perform transaction : %s", e.msg) + return fmt.Sprintf("Failed to perform query : %s", e.msg) } func (e *TransactionError) Unwrap() error { return e.err } diff --git a/pkg/database/error_test.go b/pkg/database/error_test.go index 8031305..fca19d5 100644 --- a/pkg/database/error_test.go +++ b/pkg/database/error_test.go @@ -9,6 +9,12 @@ func TestErrorsCoverage(t *testing.T) { migrationErr := MigrationError{} _ = migrationErr.Error() _ = migrationErr.Unwrap() + passwordError := PasswordError{} + _ = passwordError.Error() + _ = passwordError.Unwrap() + queryErr := QueryError{} + _ = queryErr.Error() + _ = queryErr.Unwrap() transactionErr := TransactionError{} _ = transactionErr.Error() _ = transactionErr.Unwrap() diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go index cec3c41..8d23ed2 100644 --- a/pkg/database/migrations.go +++ b/pkg/database/migrations.go @@ -9,6 +9,14 @@ var allMigrations = []func(tx *sql.Tx) error{ sql := ` CREATE TABLE schema_version ( version INTEGER NOT NULL + ); + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password TEXT, + email TEXT, + created_at DATE DEFAULT (datetime('now')), + last_login_at DATE DEFAULT NULL );` _, err = tx.Exec(sql) return err diff --git a/pkg/database/password.go b/pkg/database/password.go new file mode 100644 index 0000000..d24183d --- /dev/null +++ b/pkg/database/password.go @@ -0,0 +1,14 @@ +package database + +import "golang.org/x/crypto/bcrypt" + +// To allow for testing the error case (bad random is hard to trigger) +var passwordFunction = bcrypt.GenerateFromPassword + +func hashPassword(password string) (string, error) { + bytes, err := passwordFunction([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", newPasswordError(err) + } + return string(bytes), nil +} diff --git a/pkg/database/users.go b/pkg/database/users.go new file mode 100644 index 0000000..3f75f62 --- /dev/null +++ b/pkg/database/users.go @@ -0,0 +1,47 @@ +package database + +import ( + "git.adyxax.org/adyxax/trains/pkg/model" +) + +// Creates a new user in the database +// a QueryError is return if the username already exists (database constraints not met) +func (env *DBEnv) CreateUser(reg *model.UserRegistration) (*model.User, error) { + hash, err := hashPassword(reg.Password) + if err != nil { + return nil, err + } + query := ` + INSERT INTO users + (username, password, email) + VALUES + ($1, $2, $3);` + tx, err := env.db.Begin() + if err != nil { + return nil, newTransactionError("Could not Begin()", err) + } + result, err := tx.Exec( + query, + reg.Username, + hash, + reg.Email, + ) + if err != nil { + tx.Rollback() + return nil, newQueryError("Could not run database query, most likely the username already exists", err) + } + id, err := result.LastInsertId() + if err != nil { + tx.Rollback() + return nil, newTransactionError("Could not get LastInsertId, the database driver does not support this feature", err) + } + if err := tx.Commit(); err != nil { + return nil, newTransactionError("Could not commit transaction", err) + } + user := model.User{ + Id: int(id), + Username: reg.Username, + Email: reg.Email, + } + return &user, nil +} diff --git a/pkg/database/users_test.go b/pkg/database/users_test.go new file mode 100644 index 0000000..4b87d90 --- /dev/null +++ b/pkg/database/users_test.go @@ -0,0 +1,111 @@ +package database + +import ( + "reflect" + "testing" + + "git.adyxax.org/adyxax/trains/pkg/model" + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" +) + +func TestCreateUser(t *testing.T) { + // test db setup + db, err := InitDB("sqlite3", "file::memory:?_foreign_keys=on") + require.NoError(t, err) + err = db.Migrate() + require.NoError(t, err) + // a normal user + normalUser := model.UserRegistration{ + Username: "testUsername", + Password: "testPassword", + Email: "testEmail", + } + // Test another normal user + normalUser2 := model.UserRegistration{ + Username: "testUsername2", + Password: "yei*j2Ien2xa?g6bieh~i=asoo5ii7Bi", + Email: "testEmail", + } + // Test cases + testCases := []struct { + name string + db *DBEnv + input *model.UserRegistration + expected int + expectedError interface{} + }{ + {"Normal user", db, &normalUser, 1, nil}, + {"Duplicate user", db, &normalUser, 0, &QueryError{}}, + {"Normal user 2", db, &normalUser2, 2, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + valid, err := tc.db.CreateUser(tc.input) + if tc.expectedError != nil { + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError)) + require.Nil(t, valid) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expected, valid.Id) + } + }) + } + // Test for bad password + passwordFunction = func(password []byte, cost int) ([]byte, error) { return nil, newPasswordError(nil) } + valid, err := db.CreateUser(&normalUser) + passwordFunction = bcrypt.GenerateFromPassword + require.Error(t, err) + require.Nil(t, valid) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(&PasswordError{}), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(&PasswordError{})) +} + +func TestCreateUserWithSQLMock(t *testing.T) { + // Transaction begin error + dbBeginError, _, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer dbBeginError.Close() + // Transaction LastInsertId not supported + dbLastInsertIdError, mockLastInsertIdError, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer dbLastInsertIdError.Close() + mockLastInsertIdError.ExpectBegin() + mockLastInsertIdError.ExpectExec(`INSERT INTO`).WillReturnResult(sqlmock.NewErrorResult(&TransactionError{"test", nil})) + // Transaction commit error + dbCommitError, mockCommitError, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer dbCommitError.Close() + mockCommitError.ExpectBegin() + mockCommitError.ExpectExec(`INSERT INTO`).WillReturnResult(sqlmock.NewResult(1, 1)) + // Test cases + testCases := []struct { + name string + db *DBEnv + expectedError interface{} + }{ + {"begin transaction error", &DBEnv{db: dbBeginError}, &TransactionError{}}, + {"last insert id transaction error", &DBEnv{db: dbLastInsertIdError}, &TransactionError{}}, + {"commit transaction error", &DBEnv{db: dbCommitError}, &TransactionError{}}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + valid, err := tc.db.CreateUser(&model.UserRegistration{}) + if tc.expectedError != nil { + require.Error(t, err) + assert.Equalf(t, reflect.TypeOf(err), reflect.TypeOf(tc.expectedError), "Invalid error type. Got %s but expected %s", reflect.TypeOf(err), reflect.TypeOf(tc.expectedError)) + require.Nil(t, valid) + } else { + require.NoError(t, err) + } + }) + } +} |