139 lines
3.7 KiB
Go
139 lines
3.7 KiB
Go
package seal
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
|
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
|
|
"git.wntrmute.dev/kyle/metacrypt/internal/db"
|
|
)
|
|
|
|
func setupSeal(t *testing.T) (*Manager, func()) {
|
|
t.Helper()
|
|
dir := t.TempDir()
|
|
database, err := db.Open(filepath.Join(dir, "test.db"))
|
|
if err != nil {
|
|
t.Fatalf("open db: %v", err)
|
|
}
|
|
if err := db.Migrate(database); err != nil {
|
|
t.Fatalf("migrate: %v", err)
|
|
}
|
|
b := barrier.NewAESGCMBarrier(database)
|
|
mgr := NewManager(database, b, slog.Default())
|
|
return mgr, func() { database.Close() }
|
|
}
|
|
|
|
func TestSealInitializeAndUnseal(t *testing.T) {
|
|
mgr, cleanup := setupSeal(t)
|
|
defer cleanup()
|
|
|
|
if err := mgr.CheckInitialized(); err != nil {
|
|
t.Fatalf("CheckInitialized: %v", err)
|
|
}
|
|
if mgr.State() != StateUninitialized {
|
|
t.Fatalf("state: got %v, want Uninitialized", mgr.State())
|
|
}
|
|
|
|
password := []byte("test-password-123")
|
|
// Use fast params for testing.
|
|
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
|
|
|
if err := mgr.Initialize(context.Background(), password, params); err != nil {
|
|
t.Fatalf("Initialize: %v", err)
|
|
}
|
|
if mgr.State() != StateUnsealed {
|
|
t.Fatalf("state after init: got %v, want Unsealed", mgr.State())
|
|
}
|
|
|
|
// Seal.
|
|
if err := mgr.Seal(); err != nil {
|
|
t.Fatalf("Seal: %v", err)
|
|
}
|
|
if mgr.State() != StateSealed {
|
|
t.Fatalf("state after seal: got %v, want Sealed", mgr.State())
|
|
}
|
|
|
|
// Unseal with correct password.
|
|
if err := mgr.Unseal(password); err != nil {
|
|
t.Fatalf("Unseal: %v", err)
|
|
}
|
|
if mgr.State() != StateUnsealed {
|
|
t.Fatalf("state after unseal: got %v, want Unsealed", mgr.State())
|
|
}
|
|
}
|
|
|
|
func TestSealWrongPassword(t *testing.T) {
|
|
mgr, cleanup := setupSeal(t)
|
|
defer cleanup()
|
|
mgr.CheckInitialized()
|
|
|
|
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
|
mgr.Initialize(context.Background(), []byte("correct"), params)
|
|
mgr.Seal()
|
|
|
|
err := mgr.Unseal([]byte("wrong"))
|
|
if !errors.Is(err, ErrInvalidPassword) {
|
|
t.Fatalf("expected ErrInvalidPassword, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSealDoubleInitialize(t *testing.T) {
|
|
mgr, cleanup := setupSeal(t)
|
|
defer cleanup()
|
|
mgr.CheckInitialized()
|
|
|
|
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
|
mgr.Initialize(context.Background(), []byte("password"), params)
|
|
|
|
err := mgr.Initialize(context.Background(), []byte("password"), params)
|
|
if !errors.Is(err, ErrAlreadyInitialized) {
|
|
t.Fatalf("expected ErrAlreadyInitialized, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSealCheckInitializedPersists(t *testing.T) {
|
|
dir := t.TempDir()
|
|
dbPath := filepath.Join(dir, "test.db")
|
|
|
|
// First: initialize.
|
|
database, _ := db.Open(dbPath)
|
|
db.Migrate(database)
|
|
b := barrier.NewAESGCMBarrier(database)
|
|
mgr := NewManager(database, b, slog.Default())
|
|
mgr.CheckInitialized()
|
|
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
|
mgr.Initialize(context.Background(), []byte("password"), params)
|
|
database.Close()
|
|
|
|
// Second: reopen and check.
|
|
database2, _ := db.Open(dbPath)
|
|
defer database2.Close()
|
|
b2 := barrier.NewAESGCMBarrier(database2)
|
|
mgr2 := NewManager(database2, b2, slog.Default())
|
|
mgr2.CheckInitialized()
|
|
if mgr2.State() != StateSealed {
|
|
t.Fatalf("state after reopen: got %v, want Sealed", mgr2.State())
|
|
}
|
|
}
|
|
|
|
func TestSealStateString(t *testing.T) {
|
|
tests := []struct {
|
|
want string
|
|
state ServiceState
|
|
}{
|
|
{want: "uninitialized", state: StateUninitialized},
|
|
{want: "sealed", state: StateSealed},
|
|
{want: "initializing", state: StateInitializing},
|
|
{want: "unsealed", state: StateUnsealed},
|
|
}
|
|
for _, tt := range tests {
|
|
if got := tt.state.String(); got != tt.want {
|
|
t.Errorf("State(%d).String() = %q, want %q", tt.state, got, tt.want)
|
|
}
|
|
}
|
|
}
|