2024-09-30 00:58:49 +02:00
package database
import (
"database/sql"
2024-10-14 23:54:49 +02:00
"errors"
2024-10-03 00:13:09 +02:00
"fmt"
2024-10-14 23:54:49 +02:00
"slices"
2024-09-30 00:58:49 +02:00
)
2024-10-05 09:54:35 +02:00
// returns true in case of successful deletion
func ( db * DB ) DeleteState ( name string ) ( bool , error ) {
result , err := db . Exec ( ` DELETE FROM states WHERE name = ?; ` , name )
if err != nil {
return false , err
}
n , err := result . RowsAffected ( )
if err != nil {
return false , err
}
return n == 1 , nil
2024-10-02 08:26:53 +02:00
}
2024-09-30 00:58:49 +02:00
// GetState returns the decrypted data of the most recent version (highest
// versions.id) of the state called name.
//
// An empty, non-nil slice and a nil error are returned in these cases:
//   - no state with that name exists;
//   - the state has no versions (the JOIN yields no row);
//   - the latest version's data scans as nil (e.g. SQL NULL).
func (db *DB) GetState(name string) ([]byte, error) {
	const query = `SELECT versions.data FROM versions JOIN states ON states.id = versions.state_id WHERE states.name = ? ORDER BY versions.id DESC LIMIT 1;`

	var ciphertext []byte
	switch err := db.QueryRow(query, name).Scan(&ciphertext); {
	case errors.Is(err, sql.ErrNoRows):
		return []byte{}, nil
	case err != nil:
		return nil, err
	case ciphertext == nil:
		return []byte{}, nil
	}
	return db.dataEncryptionKey.DecryptAES256(ciphertext)
}
2024-10-03 00:13:09 +02:00
// SetState encrypts data with the database's data encryption key and stores
// it as a new version of the state called name. If no state with that name
// exists yet, it creates the state row first. All writes happen inside a
// single transaction.
//
// When lockID is non-empty, it must equal the lock ID recorded on the state
// (the 'ID' field of its lock column). Otherwise nothing is written, and
// SetState returns true along with a non-nil error. In every other case the
// returned bool is false, so it returns true only in case of id mismatch.
func (db *DB) SetState(name string, data []byte, lockID string) (bool, error) {
	encryptedData, err := db.dataEncryptionKey.EncryptAES256(data)
	if err != nil {
		return false, err
	}
	tx, err := db.Begin()
	if err != nil {
		return false, err
	}
	// Always roll back on exit. After a successful Commit, Rollback is a
	// harmless no-op: it returns sql.ErrTxDone, which is ignored. Unlike a
	// rollback guarded by `if err != nil`, this still releases the
	// transaction when a panic unwinds through here, or when an error path
	// returns a value other than the outer err variable.
	defer func() { _ = tx.Rollback() }()

	var (
		stateID  int64
		lockData []byte
	)
	err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE name = ?;`, name).Scan(&stateID, &lockData)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// First write for this name: create the state row. lockData stays
		// nil here, so a non-empty lockID is reported as a mismatch below,
		// and this insert is rolled back with the rest of the transaction.
		var result sql.Result
		if result, err = tx.ExecContext(db.ctx, `INSERT INTO states(name) VALUES (?)`, name); err != nil {
			return false, err
		}
		if stateID, err = result.LastInsertId(); err != nil {
			return false, err
		}
	case err != nil:
		return false, err
	}

	// NOTE(review): an empty lockID skips this check entirely. A caller that
	// sends no lock ID can therefore write over a state that another client
	// currently holds locked. Confirm this is intended.
	if lockID != "" && !slices.Equal([]byte(lockID), lockData) {
		return true, fmt.Errorf("failed to update state, lock ID does not match")
	}
	if _, err = tx.ExecContext(db.ctx, `INSERT INTO versions(state_id, data) VALUES (?, ?);`, stateID, encryptedData); err != nil {
		return false, err
	}
	// TODO delete old states
	return false, tx.Commit()
}