124 lines
3.4 KiB
Go
124 lines
3.4 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"git.adyxax.org/adyxax/tfstated/pkg/helpers"
|
|
"git.adyxax.org/adyxax/tfstated/pkg/model"
|
|
"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
|
|
)
|
|
|
|
func (db *DB) CreateSession(sessionData *model.SessionData) (string, *model.Session, error) {
|
|
sessionBytes := scrypto.RandomBytes(32)
|
|
sessionId := base64.RawURLEncoding.EncodeToString(sessionBytes[:])
|
|
sessionHash := helpers.HashSessionId(sessionBytes, db.sessionsSalt.Bytes())
|
|
if sessionData == nil {
|
|
var err error
|
|
sessionData, err = model.NewSessionData(nil, nil)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to generate new session data: %w", err)
|
|
}
|
|
}
|
|
data, err := json.Marshal(sessionData)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to marshal session data: %w", err)
|
|
}
|
|
if _, err := db.Exec(
|
|
`INSERT INTO sessions(id, data)
|
|
VALUES (?, jsonb(?));`,
|
|
sessionHash,
|
|
data,
|
|
); err != nil {
|
|
return "", nil, fmt.Errorf("failed insert new session in database: %w", err)
|
|
}
|
|
return sessionId, &model.Session{
|
|
Id: sessionHash,
|
|
Data: sessionData,
|
|
}, nil
|
|
}
|
|
|
|
func (db *DB) DeleteExpiredSessions() error {
|
|
expires := time.Now().Add(-12 * time.Hour)
|
|
_, err := db.Exec(`DELETE FROM sessions WHERE created < ?`, expires.Unix())
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete expired session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) DeleteSession(session *model.Session) error {
|
|
_, err := db.Exec(`DELETE FROM sessions WHERE id = ?`, session.Id)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) DeleteSessions(account *model.Account) error {
|
|
_, err := db.Exec(
|
|
`DELETE FROM sessions WHERE data->'account'->>'id' = ?`,
|
|
account.Id)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete sessions: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *DB) LoadSessionById(id string) (*model.Session, error) {
|
|
sessionBytes, err := base64.RawURLEncoding.DecodeString(id)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode base64 session id: %w", err)
|
|
}
|
|
sessionHash := helpers.HashSessionId(sessionBytes, db.sessionsSalt.Bytes())
|
|
session := model.Session{
|
|
Id: sessionHash,
|
|
}
|
|
var (
|
|
created int64
|
|
updated int64
|
|
data []byte
|
|
)
|
|
err = db.QueryRow(
|
|
`SELECT created,
|
|
updated,
|
|
json_extract(data, '$')
|
|
FROM sessions
|
|
WHERE id = ?;`,
|
|
sessionHash,
|
|
).Scan(
|
|
&created,
|
|
&updated,
|
|
&data)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to load session by id %s: %w", id, err)
|
|
}
|
|
if err := json.Unmarshal(data, &session.Data); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal session data: %w", err)
|
|
}
|
|
session.Created = time.Unix(created, 0)
|
|
session.Updated = time.Unix(updated, 0)
|
|
return &session, nil
|
|
}
|
|
|
|
func (db *DB) MigrateSession(session *model.Session, account *model.Account) (string, *model.Session, error) {
|
|
if err := db.DeleteSession(session); err != nil {
|
|
return "", nil, fmt.Errorf("failed to delete session: %w", err)
|
|
}
|
|
sessionData, err := model.NewSessionData(account, session.Data.Settings)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to generate new session data: %w", err)
|
|
}
|
|
sessionId, session, err := db.CreateSession(sessionData)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to create session: %w", err)
|
|
}
|
|
return sessionId, session, nil
|
|
}
|