131 lines
3.8 KiB
Go
131 lines
3.8 KiB
Go
package engine
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"testing"
|
|
|
|
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
|
)
|
|
|
|
// mockEngine implements Engine for testing.
|
|
type mockEngine struct {
|
|
engineType EngineType
|
|
initialized bool
|
|
unsealed bool
|
|
}
|
|
|
|
func (m *mockEngine) Type() EngineType { return m.engineType }
|
|
func (m *mockEngine) Initialize(_ context.Context, _ barrier.Barrier, _ string, _ map[string]interface{}) error {
|
|
m.initialized = true
|
|
return nil
|
|
}
|
|
func (m *mockEngine) Unseal(_ context.Context, _ barrier.Barrier, _ string) error {
|
|
m.unsealed = true
|
|
return nil
|
|
}
|
|
func (m *mockEngine) Seal() error { m.unsealed = false; return nil }
|
|
func (m *mockEngine) HandleRequest(_ context.Context, _ *Request) (*Response, error) {
|
|
return &Response{Data: map[string]interface{}{"ok": true}}, nil
|
|
}
|
|
|
|
type mockBarrier struct{}
|
|
|
|
func (m *mockBarrier) Unseal(_ []byte) error { return nil }
|
|
func (m *mockBarrier) Seal() error { return nil }
|
|
func (m *mockBarrier) IsSealed() bool { return false }
|
|
func (m *mockBarrier) Get(_ context.Context, _ string) ([]byte, error) {
|
|
return nil, barrier.ErrNotFound
|
|
}
|
|
func (m *mockBarrier) Put(_ context.Context, _ string, _ []byte) error { return nil }
|
|
func (m *mockBarrier) Delete(_ context.Context, _ string) error { return nil }
|
|
func (m *mockBarrier) List(_ context.Context, _ string) ([]string, error) { return nil, nil }
|
|
|
|
func TestRegistryMountUnmount(t *testing.T) {
|
|
reg := NewRegistry(&mockBarrier{}, slog.Default())
|
|
reg.RegisterFactory(EngineTypeTransit, func() Engine {
|
|
return &mockEngine{engineType: EngineTypeTransit}
|
|
})
|
|
|
|
ctx := context.Background()
|
|
if err := reg.Mount(ctx, "default", EngineTypeTransit, nil); err != nil {
|
|
t.Fatalf("Mount: %v", err)
|
|
}
|
|
|
|
mounts := reg.ListMounts()
|
|
if len(mounts) != 1 {
|
|
t.Fatalf("ListMounts: got %d, want 1", len(mounts))
|
|
}
|
|
if mounts[0].Name != "default" {
|
|
t.Errorf("mount name: got %q, want %q", mounts[0].Name, "default")
|
|
}
|
|
|
|
// Duplicate mount should fail.
|
|
if err := reg.Mount(ctx, "default", EngineTypeTransit, nil); !errors.Is(err, ErrMountExists) {
|
|
t.Fatalf("expected ErrMountExists, got: %v", err)
|
|
}
|
|
|
|
if err := reg.Unmount(ctx, "default"); err != nil {
|
|
t.Fatalf("Unmount: %v", err)
|
|
}
|
|
|
|
mounts = reg.ListMounts()
|
|
if len(mounts) != 0 {
|
|
t.Fatalf("after unmount: got %d mounts", len(mounts))
|
|
}
|
|
}
|
|
|
|
func TestRegistryUnmountNotFound(t *testing.T) {
|
|
reg := NewRegistry(&mockBarrier{}, slog.Default())
|
|
if err := reg.Unmount(context.Background(), "nonexistent"); !errors.Is(err, ErrMountNotFound) {
|
|
t.Fatalf("expected ErrMountNotFound, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRegistryUnknownType(t *testing.T) {
|
|
reg := NewRegistry(&mockBarrier{}, slog.Default())
|
|
err := reg.Mount(context.Background(), "test", EngineTypeTransit, nil)
|
|
if err == nil {
|
|
t.Fatal("expected error for unknown engine type")
|
|
}
|
|
}
|
|
|
|
func TestRegistryHandleRequest(t *testing.T) {
|
|
reg := NewRegistry(&mockBarrier{}, slog.Default())
|
|
reg.RegisterFactory(EngineTypeTransit, func() Engine {
|
|
return &mockEngine{engineType: EngineTypeTransit}
|
|
})
|
|
|
|
ctx := context.Background()
|
|
_ = reg.Mount(ctx, "test", EngineTypeTransit, nil)
|
|
|
|
resp, err := reg.HandleRequest(ctx, "test", &Request{Operation: "encrypt"})
|
|
if err != nil {
|
|
t.Fatalf("HandleRequest: %v", err)
|
|
}
|
|
if resp.Data["ok"] != true {
|
|
t.Error("expected ok=true in response")
|
|
}
|
|
|
|
_, err = reg.HandleRequest(ctx, "nonexistent", &Request{})
|
|
if !errors.Is(err, ErrMountNotFound) {
|
|
t.Fatalf("expected ErrMountNotFound, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRegistrySealAll(t *testing.T) {
|
|
reg := NewRegistry(&mockBarrier{}, slog.Default())
|
|
reg.RegisterFactory(EngineTypeTransit, func() Engine {
|
|
return &mockEngine{engineType: EngineTypeTransit}
|
|
})
|
|
|
|
ctx := context.Background()
|
|
_ = reg.Mount(ctx, "eng1", EngineTypeTransit, nil)
|
|
_ = reg.Mount(ctx, "eng2", EngineTypeTransit, nil)
|
|
|
|
if err := reg.SealAll(); err != nil {
|
|
t.Fatalf("SealAll: %v", err)
|
|
}
|
|
}
|