Implement Phase 1: core framework, operational tooling, and runbook
Core packages: crypto (Argon2id/AES-256-GCM), config (TOML/viper), db (SQLite/migrations), barrier (encrypted storage), seal (state machine with rate-limited unseal), auth (MCIAS integration with token cache), policy (priority-based ACL engine), engine (interface + registry). Server: HTTPS with TLS 1.2+, REST API, auth/admin middleware, htmx web UI (init, unseal, login, dashboard pages). CLI: cobra/viper subcommands (server, init, status, snapshot) with env var override support (METACRYPT_ prefix). Operational tooling: Dockerfile (multi-stage, non-root), docker-compose, hardened systemd units (service + daily backup timer), install script, backup script with retention pruning, production config examples. Runbook covering installation, configuration, daily operations, backup/restore, monitoring, troubleshooting, and security procedures. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
125
internal/auth/auth.go
Normal file
125
internal/auth/auth.go
Normal file
@@ -0,0 +1,125 @@
|
||||
// Package auth provides MCIAS authentication integration with token caching.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
mcias "git.wntrmute.dev/kyle/mcias/clients/go"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("auth: invalid credentials")
|
||||
ErrInvalidToken = errors.New("auth: invalid token")
|
||||
)
|
||||
|
||||
const tokenCacheTTL = 30 * time.Second
|
||||
|
||||
// TokenInfo holds validated token information.
|
||||
type TokenInfo struct {
|
||||
Username string
|
||||
Roles []string
|
||||
IsAdmin bool
|
||||
}
|
||||
|
||||
// cachedClaims holds a cached token validation result.
|
||||
type cachedClaims struct {
|
||||
info *TokenInfo
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// Authenticator provides MCIAS-backed authentication.
|
||||
type Authenticator struct {
|
||||
client *mcias.Client
|
||||
|
||||
mu sync.RWMutex
|
||||
cache map[string]*cachedClaims // keyed by SHA-256(token)
|
||||
}
|
||||
|
||||
// NewAuthenticator creates a new authenticator with the given MCIAS client.
|
||||
func NewAuthenticator(client *mcias.Client) *Authenticator {
|
||||
return &Authenticator{
|
||||
client: client,
|
||||
cache: make(map[string]*cachedClaims),
|
||||
}
|
||||
}
|
||||
|
||||
// Login authenticates a user via MCIAS and returns the token.
|
||||
func (a *Authenticator) Login(username, password, totpCode string) (token string, expiresAt string, err error) {
|
||||
tok, exp, err := a.client.Login(username, password, totpCode)
|
||||
if err != nil {
|
||||
var authErr *mcias.MciasAuthError
|
||||
if errors.As(err, &authErr) {
|
||||
return "", "", ErrInvalidCredentials
|
||||
}
|
||||
return "", "", err
|
||||
}
|
||||
return tok, exp, nil
|
||||
}
|
||||
|
||||
// ValidateToken validates a bearer token, using a short-lived cache.
|
||||
func (a *Authenticator) ValidateToken(token string) (*TokenInfo, error) {
|
||||
key := tokenHash(token)
|
||||
|
||||
// Check cache.
|
||||
a.mu.RLock()
|
||||
cached, ok := a.cache[key]
|
||||
a.mu.RUnlock()
|
||||
if ok && time.Now().Before(cached.expiresAt) {
|
||||
return cached.info, nil
|
||||
}
|
||||
|
||||
// Validate with MCIAS.
|
||||
claims, err := a.client.ValidateToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !claims.Valid {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
info := &TokenInfo{
|
||||
Username: claims.Sub,
|
||||
Roles: claims.Roles,
|
||||
IsAdmin: hasAdminRole(claims.Roles),
|
||||
}
|
||||
|
||||
// Cache the result.
|
||||
a.mu.Lock()
|
||||
a.cache[key] = &cachedClaims{
|
||||
info: info,
|
||||
expiresAt: time.Now().Add(tokenCacheTTL),
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// Logout invalidates a token via MCIAS. The client must have the token set.
|
||||
func (a *Authenticator) Logout(client *mcias.Client) error {
|
||||
return client.Logout()
|
||||
}
|
||||
|
||||
// ClearCache removes all cached token validations.
|
||||
func (a *Authenticator) ClearCache() {
|
||||
a.mu.Lock()
|
||||
a.cache = make(map[string]*cachedClaims)
|
||||
a.mu.Unlock()
|
||||
}
|
||||
|
||||
func tokenHash(token string) string {
|
||||
h := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
func hasAdminRole(roles []string) bool {
|
||||
for _, r := range roles {
|
||||
if r == "admin" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
52
internal/auth/auth_test.go
Normal file
52
internal/auth/auth_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTokenHash(t *testing.T) {
|
||||
h1 := tokenHash("token-abc")
|
||||
h2 := tokenHash("token-abc")
|
||||
h3 := tokenHash("token-def")
|
||||
|
||||
if h1 != h2 {
|
||||
t.Error("same input should produce same hash")
|
||||
}
|
||||
if h1 == h3 {
|
||||
t.Error("different inputs should produce different hashes")
|
||||
}
|
||||
if len(h1) != 64 { // SHA-256 hex
|
||||
t.Errorf("hash length: got %d, want 64", len(h1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasAdminRole(t *testing.T) {
|
||||
if !hasAdminRole([]string{"user", "admin"}) {
|
||||
t.Error("should detect admin role")
|
||||
}
|
||||
if hasAdminRole([]string{"user", "operator"}) {
|
||||
t.Error("should not detect admin role when absent")
|
||||
}
|
||||
if hasAdminRole(nil) {
|
||||
t.Error("nil roles should not be admin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuthenticator(t *testing.T) {
|
||||
a := NewAuthenticator(nil)
|
||||
if a == nil {
|
||||
t.Fatal("NewAuthenticator returned nil")
|
||||
}
|
||||
if a.cache == nil {
|
||||
t.Error("cache should be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCache(t *testing.T) {
|
||||
a := NewAuthenticator(nil)
|
||||
a.cache["test"] = &cachedClaims{info: &TokenInfo{Username: "test"}}
|
||||
a.ClearCache()
|
||||
if len(a.cache) != 0 {
|
||||
t.Error("cache should be empty after clear")
|
||||
}
|
||||
}
|
||||
167
internal/barrier/barrier.go
Normal file
167
internal/barrier/barrier.go
Normal file
@@ -0,0 +1,167 @@
|
||||
// Package barrier provides an encrypted storage barrier backed by SQLite.
|
||||
package barrier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSealed = errors.New("barrier: sealed")
|
||||
ErrNotFound = errors.New("barrier: entry not found")
|
||||
)
|
||||
|
||||
// Barrier is the encrypted storage barrier interface.
|
||||
type Barrier interface {
|
||||
// Unseal opens the barrier with the given master encryption key.
|
||||
Unseal(mek []byte) error
|
||||
// Seal closes the barrier and zeroizes the key material.
|
||||
Seal() error
|
||||
// IsSealed returns true if the barrier is sealed.
|
||||
IsSealed() bool
|
||||
|
||||
// Get retrieves and decrypts a value by path.
|
||||
Get(ctx context.Context, path string) ([]byte, error)
|
||||
// Put encrypts and stores a value at the given path.
|
||||
Put(ctx context.Context, path string, value []byte) error
|
||||
// Delete removes an entry by path.
|
||||
Delete(ctx context.Context, path string) error
|
||||
// List returns paths with the given prefix.
|
||||
List(ctx context.Context, prefix string) ([]string, error)
|
||||
}
|
||||
|
||||
// AESGCMBarrier implements Barrier using AES-256-GCM encryption.
|
||||
type AESGCMBarrier struct {
|
||||
db *sql.DB
|
||||
mu sync.RWMutex
|
||||
mek []byte // nil when sealed
|
||||
}
|
||||
|
||||
// NewAESGCMBarrier creates a new AES-GCM barrier backed by the given database.
|
||||
func NewAESGCMBarrier(db *sql.DB) *AESGCMBarrier {
|
||||
return &AESGCMBarrier{db: db}
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) Unseal(mek []byte) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
k := make([]byte, len(mek))
|
||||
copy(k, mek)
|
||||
b.mek = k
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) Seal() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.mek != nil {
|
||||
crypto.Zeroize(b.mek)
|
||||
b.mek = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) IsSealed() bool {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return b.mek == nil
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) Get(ctx context.Context, path string) ([]byte, error) {
|
||||
b.mu.RLock()
|
||||
mek := b.mek
|
||||
b.mu.RUnlock()
|
||||
if mek == nil {
|
||||
return nil, ErrSealed
|
||||
}
|
||||
|
||||
var encrypted []byte
|
||||
err := b.db.QueryRowContext(ctx,
|
||||
"SELECT value FROM barrier_entries WHERE path = ?", path).Scan(&encrypted)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("barrier: get %q: %w", path, err)
|
||||
}
|
||||
|
||||
plaintext, err := crypto.Decrypt(mek, encrypted)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("barrier: decrypt %q: %w", path, err)
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) Put(ctx context.Context, path string, value []byte) error {
|
||||
b.mu.RLock()
|
||||
mek := b.mek
|
||||
b.mu.RUnlock()
|
||||
if mek == nil {
|
||||
return ErrSealed
|
||||
}
|
||||
|
||||
encrypted, err := crypto.Encrypt(mek, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("barrier: encrypt %q: %w", path, err)
|
||||
}
|
||||
|
||||
_, err = b.db.ExecContext(ctx, `
|
||||
INSERT INTO barrier_entries (path, value) VALUES (?, ?)
|
||||
ON CONFLICT(path) DO UPDATE SET value = excluded.value, updated_at = datetime('now')`,
|
||||
path, encrypted)
|
||||
if err != nil {
|
||||
return fmt.Errorf("barrier: put %q: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) Delete(ctx context.Context, path string) error {
|
||||
b.mu.RLock()
|
||||
mek := b.mek
|
||||
b.mu.RUnlock()
|
||||
if mek == nil {
|
||||
return ErrSealed
|
||||
}
|
||||
|
||||
_, err := b.db.ExecContext(ctx,
|
||||
"DELETE FROM barrier_entries WHERE path = ?", path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("barrier: delete %q: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) List(ctx context.Context, prefix string) ([]string, error) {
|
||||
b.mu.RLock()
|
||||
mek := b.mek
|
||||
b.mu.RUnlock()
|
||||
if mek == nil {
|
||||
return nil, ErrSealed
|
||||
}
|
||||
|
||||
rows, err := b.db.QueryContext(ctx,
|
||||
"SELECT path FROM barrier_entries WHERE path LIKE ?",
|
||||
prefix+"%")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("barrier: list %q: %w", prefix, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var paths []string
|
||||
for rows.Next() {
|
||||
var p string
|
||||
if err := rows.Scan(&p); err != nil {
|
||||
return nil, fmt.Errorf("barrier: list scan: %w", err)
|
||||
}
|
||||
// Strip the prefix and return just the next segment.
|
||||
remainder := strings.TrimPrefix(p, prefix)
|
||||
paths = append(paths, remainder)
|
||||
}
|
||||
return paths, rows.Err()
|
||||
}
|
||||
159
internal/barrier/barrier_test.go
Normal file
159
internal/barrier/barrier_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package barrier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/db"
|
||||
)
|
||||
|
||||
func setupBarrier(t *testing.T) (*AESGCMBarrier, func()) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
database, err := db.Open(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
if err := db.Migrate(database); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
b := NewAESGCMBarrier(database)
|
||||
return b, func() { database.Close() }
|
||||
}
|
||||
|
||||
func TestBarrierSealUnseal(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
|
||||
if !b.IsSealed() {
|
||||
t.Fatal("new barrier should be sealed")
|
||||
}
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
if err := b.Unseal(mek); err != nil {
|
||||
t.Fatalf("Unseal: %v", err)
|
||||
}
|
||||
if b.IsSealed() {
|
||||
t.Fatal("barrier should be unsealed")
|
||||
}
|
||||
|
||||
if err := b.Seal(); err != nil {
|
||||
t.Fatalf("Seal: %v", err)
|
||||
}
|
||||
if !b.IsSealed() {
|
||||
t.Fatal("barrier should be sealed after Seal()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierPutGet(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
b.Unseal(mek)
|
||||
|
||||
data := []byte("test value")
|
||||
if err := b.Put(ctx, "test/path", data); err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
|
||||
got, err := b.Get(ctx, "test/path")
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if string(got) != string(data) {
|
||||
t.Fatalf("Get: got %q, want %q", got, data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierGetNotFound(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
b.Unseal(mek)
|
||||
|
||||
_, err := b.Get(ctx, "nonexistent")
|
||||
if err != ErrNotFound {
|
||||
t.Fatalf("expected ErrNotFound, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierDelete(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
b.Unseal(mek)
|
||||
|
||||
b.Put(ctx, "test/delete-me", []byte("data"))
|
||||
if err := b.Delete(ctx, "test/delete-me"); err != nil {
|
||||
t.Fatalf("Delete: %v", err)
|
||||
}
|
||||
_, err := b.Get(ctx, "test/delete-me")
|
||||
if err != ErrNotFound {
|
||||
t.Fatalf("expected ErrNotFound after delete, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierList(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
b.Unseal(mek)
|
||||
|
||||
b.Put(ctx, "engine/ca/default/config", []byte("cfg"))
|
||||
b.Put(ctx, "engine/ca/default/dek", []byte("key"))
|
||||
b.Put(ctx, "engine/transit/main/config", []byte("cfg"))
|
||||
|
||||
paths, err := b.List(ctx, "engine/ca/")
|
||||
if err != nil {
|
||||
t.Fatalf("List: %v", err)
|
||||
}
|
||||
if len(paths) != 2 {
|
||||
t.Fatalf("List: got %d paths, want 2", len(paths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierSealedOperations(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
if _, err := b.Get(ctx, "test"); err != ErrSealed {
|
||||
t.Fatalf("Get when sealed: expected ErrSealed, got: %v", err)
|
||||
}
|
||||
if err := b.Put(ctx, "test", []byte("data")); err != ErrSealed {
|
||||
t.Fatalf("Put when sealed: expected ErrSealed, got: %v", err)
|
||||
}
|
||||
if err := b.Delete(ctx, "test"); err != ErrSealed {
|
||||
t.Fatalf("Delete when sealed: expected ErrSealed, got: %v", err)
|
||||
}
|
||||
if _, err := b.List(ctx, "test"); err != ErrSealed {
|
||||
t.Fatalf("List when sealed: expected ErrSealed, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierOverwrite(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
b.Unseal(mek)
|
||||
|
||||
b.Put(ctx, "test/overwrite", []byte("v1"))
|
||||
b.Put(ctx, "test/overwrite", []byte("v2"))
|
||||
|
||||
got, _ := b.Get(ctx, "test/overwrite")
|
||||
if string(got) != "v2" {
|
||||
t.Fatalf("overwrite: got %q, want %q", got, "v2")
|
||||
}
|
||||
}
|
||||
101
internal/config/config.go
Normal file
101
internal/config/config.go
Normal file
@@ -0,0 +1,101 @@
|
||||
// Package config provides TOML configuration loading and validation.
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
)
|
||||
|
||||
// Config is the top-level configuration for Metacrypt.
|
||||
type Config struct {
|
||||
Server ServerConfig `toml:"server"`
|
||||
Database DatabaseConfig `toml:"database"`
|
||||
MCIAS MCIASConfig `toml:"mcias"`
|
||||
Seal SealConfig `toml:"seal"`
|
||||
Log LogConfig `toml:"log"`
|
||||
}
|
||||
|
||||
// ServerConfig holds HTTP/gRPC server settings.
|
||||
type ServerConfig struct {
|
||||
ListenAddr string `toml:"listen_addr"`
|
||||
GRPCAddr string `toml:"grpc_addr"`
|
||||
TLSCert string `toml:"tls_cert"`
|
||||
TLSKey string `toml:"tls_key"`
|
||||
}
|
||||
|
||||
// DatabaseConfig holds SQLite database settings.
|
||||
type DatabaseConfig struct {
|
||||
Path string `toml:"path"`
|
||||
}
|
||||
|
||||
// MCIASConfig holds MCIAS integration settings.
|
||||
type MCIASConfig struct {
|
||||
ServerURL string `toml:"server_url"`
|
||||
CACert string `toml:"ca_cert"`
|
||||
}
|
||||
|
||||
// SealConfig holds Argon2id parameters for the seal process.
|
||||
type SealConfig struct {
|
||||
Argon2Time uint32 `toml:"argon2_time"`
|
||||
Argon2Memory uint32 `toml:"argon2_memory"`
|
||||
Argon2Threads uint8 `toml:"argon2_threads"`
|
||||
}
|
||||
|
||||
// LogConfig holds logging settings.
|
||||
type LogConfig struct {
|
||||
Level string `toml:"level"`
|
||||
}
|
||||
|
||||
// Load reads and parses a TOML config file.
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config: read file: %w", err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := toml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("config: parse: %w", err)
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// Validate checks required fields and applies defaults.
|
||||
func (c *Config) Validate() error {
|
||||
if c.Server.ListenAddr == "" {
|
||||
return fmt.Errorf("config: server.listen_addr is required")
|
||||
}
|
||||
if c.Server.TLSCert == "" {
|
||||
return fmt.Errorf("config: server.tls_cert is required")
|
||||
}
|
||||
if c.Server.TLSKey == "" {
|
||||
return fmt.Errorf("config: server.tls_key is required")
|
||||
}
|
||||
if c.Database.Path == "" {
|
||||
return fmt.Errorf("config: database.path is required")
|
||||
}
|
||||
if c.MCIAS.ServerURL == "" {
|
||||
return fmt.Errorf("config: mcias.server_url is required")
|
||||
}
|
||||
|
||||
// Apply defaults for seal parameters.
|
||||
if c.Seal.Argon2Time == 0 {
|
||||
c.Seal.Argon2Time = 3
|
||||
}
|
||||
if c.Seal.Argon2Memory == 0 {
|
||||
c.Seal.Argon2Memory = 128 * 1024
|
||||
}
|
||||
if c.Seal.Argon2Threads == 0 {
|
||||
c.Seal.Argon2Threads = 4
|
||||
}
|
||||
|
||||
if c.Log.Level == "" {
|
||||
c.Log.Level = "info"
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
64
internal/config/config_test.go
Normal file
64
internal/config/config_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadValid(t *testing.T) {
|
||||
content := `
|
||||
[server]
|
||||
listen_addr = ":8443"
|
||||
tls_cert = "cert.pem"
|
||||
tls_key = "key.pem"
|
||||
|
||||
[database]
|
||||
path = "test.db"
|
||||
|
||||
[mcias]
|
||||
server_url = "https://mcias.example.com"
|
||||
`
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
os.WriteFile(path, []byte(content), 0600)
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if cfg.Server.ListenAddr != ":8443" {
|
||||
t.Errorf("ListenAddr: got %q", cfg.Server.ListenAddr)
|
||||
}
|
||||
if cfg.Seal.Argon2Time != 3 {
|
||||
t.Errorf("Argon2Time default: got %d, want 3", cfg.Seal.Argon2Time)
|
||||
}
|
||||
if cfg.Seal.Argon2Memory != 128*1024 {
|
||||
t.Errorf("Argon2Memory default: got %d", cfg.Seal.Argon2Memory)
|
||||
}
|
||||
if cfg.Log.Level != "info" {
|
||||
t.Errorf("Log.Level default: got %q", cfg.Log.Level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMissingRequired(t *testing.T) {
|
||||
content := `
|
||||
[server]
|
||||
listen_addr = ":8443"
|
||||
`
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
os.WriteFile(path, []byte(content), 0600)
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing required fields")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMissingFile(t *testing.T) {
|
||||
_, err := Load("/nonexistent/path.toml")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file")
|
||||
}
|
||||
}
|
||||
137
internal/crypto/crypto.go
Normal file
137
internal/crypto/crypto.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// Package crypto provides Argon2id KDF, AES-256-GCM encryption, and key helpers.
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
const (
|
||||
// KeySize is the size of AES-256 keys in bytes.
|
||||
KeySize = 32
|
||||
// NonceSize is the size of AES-GCM nonces in bytes.
|
||||
NonceSize = 12
|
||||
// SaltSize is the size of Argon2id salts in bytes.
|
||||
SaltSize = 32
|
||||
|
||||
// BarrierVersion is the version byte prefix for encrypted barrier entries.
|
||||
BarrierVersion byte = 0x01
|
||||
|
||||
// Default Argon2id parameters.
|
||||
DefaultArgon2Time = 3
|
||||
DefaultArgon2Memory = 128 * 1024 // 128 MiB in KiB
|
||||
DefaultArgon2Threads = 4
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCiphertext = errors.New("crypto: invalid ciphertext")
|
||||
ErrDecryptionFailed = errors.New("crypto: decryption failed")
|
||||
)
|
||||
|
||||
// Argon2Params holds Argon2id KDF parameters.
|
||||
type Argon2Params struct {
|
||||
Time uint32
|
||||
Memory uint32 // in KiB
|
||||
Threads uint8
|
||||
}
|
||||
|
||||
// DefaultArgon2Params returns the default Argon2id parameters.
|
||||
func DefaultArgon2Params() Argon2Params {
|
||||
return Argon2Params{
|
||||
Time: DefaultArgon2Time,
|
||||
Memory: DefaultArgon2Memory,
|
||||
Threads: DefaultArgon2Threads,
|
||||
}
|
||||
}
|
||||
|
||||
// DeriveKey derives a 256-bit key from password and salt using Argon2id.
|
||||
func DeriveKey(password []byte, salt []byte, params Argon2Params) []byte {
|
||||
return argon2.IDKey(password, salt, params.Time, params.Memory, params.Threads, KeySize)
|
||||
}
|
||||
|
||||
// GenerateKey generates a random 256-bit key.
|
||||
func GenerateKey() ([]byte, error) {
|
||||
key := make([]byte, KeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return nil, fmt.Errorf("crypto: generate key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GenerateSalt generates a random salt for Argon2id.
|
||||
func GenerateSalt() ([]byte, error) {
|
||||
salt := make([]byte, SaltSize)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return nil, fmt.Errorf("crypto: generate salt: %w", err)
|
||||
}
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext with AES-256-GCM using the given key.
|
||||
// Returns: [version byte][12-byte nonce][ciphertext+tag]
|
||||
func Encrypt(key, plaintext []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new cipher: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new gcm: %w", err)
|
||||
}
|
||||
nonce := make([]byte, NonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("crypto: generate nonce: %w", err)
|
||||
}
|
||||
ciphertext := gcm.Seal(nil, nonce, plaintext, nil)
|
||||
|
||||
// Format: [version][nonce][ciphertext+tag]
|
||||
result := make([]byte, 1+NonceSize+len(ciphertext))
|
||||
result[0] = BarrierVersion
|
||||
copy(result[1:1+NonceSize], nonce)
|
||||
copy(result[1+NonceSize:], ciphertext)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext produced by Encrypt.
|
||||
func Decrypt(key, data []byte) ([]byte, error) {
|
||||
if len(data) < 1+NonceSize+aes.BlockSize {
|
||||
return nil, ErrInvalidCiphertext
|
||||
}
|
||||
if data[0] != BarrierVersion {
|
||||
return nil, fmt.Errorf("crypto: unsupported version: %d", data[0])
|
||||
}
|
||||
nonce := data[1 : 1+NonceSize]
|
||||
ciphertext := data[1+NonceSize:]
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new cipher: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new gcm: %w", err)
|
||||
}
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, ErrDecryptionFailed
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// Zeroize overwrites a byte slice with zeros.
|
||||
func Zeroize(b []byte) {
|
||||
for i := range b {
|
||||
b[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
// ConstantTimeEqual compares two byte slices in constant time.
|
||||
func ConstantTimeEqual(a, b []byte) bool {
|
||||
return subtle.ConstantTimeCompare(a, b) == 1
|
||||
}
|
||||
132
internal/crypto/crypto_test.go
Normal file
132
internal/crypto/crypto_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey: %v", err)
|
||||
}
|
||||
if len(key) != KeySize {
|
||||
t.Fatalf("key length: got %d, want %d", len(key), KeySize)
|
||||
}
|
||||
// Should be random (not all zeros).
|
||||
if bytes.Equal(key, make([]byte, KeySize)) {
|
||||
t.Fatal("key is all zeros")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSalt(t *testing.T) {
|
||||
salt, err := GenerateSalt()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSalt: %v", err)
|
||||
}
|
||||
if len(salt) != SaltSize {
|
||||
t.Fatalf("salt length: got %d, want %d", len(salt), SaltSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("hello, metacrypt!")
|
||||
|
||||
ciphertext, err := Encrypt(key, plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
// Version byte should be present.
|
||||
if ciphertext[0] != BarrierVersion {
|
||||
t.Fatalf("version byte: got %d, want %d", ciphertext[0], BarrierVersion)
|
||||
}
|
||||
|
||||
decrypted, err := Decrypt(key, ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(plaintext, decrypted) {
|
||||
t.Fatalf("roundtrip failed: got %q, want %q", decrypted, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWrongKey(t *testing.T) {
|
||||
key1, _ := GenerateKey()
|
||||
key2, _ := GenerateKey()
|
||||
plaintext := []byte("secret data")
|
||||
|
||||
ciphertext, _ := Encrypt(key1, plaintext)
|
||||
_, err := Decrypt(key2, ciphertext)
|
||||
if err != ErrDecryptionFailed {
|
||||
t.Fatalf("expected ErrDecryptionFailed, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptInvalidCiphertext(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
_, err := Decrypt(key, []byte("short"))
|
||||
if err != ErrInvalidCiphertext {
|
||||
t.Fatalf("expected ErrInvalidCiphertext, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveKey(t *testing.T) {
|
||||
password := []byte("test-password")
|
||||
salt, _ := GenerateSalt()
|
||||
params := Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
||||
|
||||
key := DeriveKey(password, salt, params)
|
||||
if len(key) != KeySize {
|
||||
t.Fatalf("derived key length: got %d, want %d", len(key), KeySize)
|
||||
}
|
||||
|
||||
// Same inputs should produce same output.
|
||||
key2 := DeriveKey(password, salt, params)
|
||||
if !bytes.Equal(key, key2) {
|
||||
t.Fatal("determinism: same inputs produced different keys")
|
||||
}
|
||||
|
||||
// Different password should produce different output.
|
||||
key3 := DeriveKey([]byte("different"), salt, params)
|
||||
if bytes.Equal(key, key3) {
|
||||
t.Fatal("different passwords produced same key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestZeroize(t *testing.T) {
|
||||
data := []byte{1, 2, 3, 4, 5}
|
||||
Zeroize(data)
|
||||
for i, b := range data {
|
||||
if b != 0 {
|
||||
t.Fatalf("byte %d not zeroed: %d", i, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstantTimeEqual(t *testing.T) {
|
||||
a := []byte("hello")
|
||||
b := []byte("hello")
|
||||
c := []byte("world")
|
||||
|
||||
if !ConstantTimeEqual(a, b) {
|
||||
t.Fatal("equal slices reported as not equal")
|
||||
}
|
||||
if ConstantTimeEqual(a, c) {
|
||||
t.Fatal("different slices reported as equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptProducesDifferentCiphertext(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("same data")
|
||||
|
||||
ct1, _ := Encrypt(key, plaintext)
|
||||
ct2, _ := Encrypt(key, plaintext)
|
||||
|
||||
if bytes.Equal(ct1, ct2) {
|
||||
t.Fatal("two encryptions of same plaintext produced identical ciphertext (nonce reuse)")
|
||||
}
|
||||
}
|
||||
43
internal/db/db.go
Normal file
43
internal/db/db.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Package db provides SQLite database access and migrations.
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// Open opens or creates a SQLite database at the given path with secure
|
||||
// file permissions (0600) and WAL mode enabled.
|
||||
func Open(path string) (*sql.DB, error) {
|
||||
// Ensure the file has restrictive permissions if it doesn't exist yet.
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: create file: %w", err)
|
||||
}
|
||||
f.Close()
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: open: %w", err)
|
||||
}
|
||||
|
||||
// Enable WAL mode and foreign keys.
|
||||
pragmas := []string{
|
||||
"PRAGMA journal_mode=WAL",
|
||||
"PRAGMA foreign_keys=ON",
|
||||
"PRAGMA busy_timeout=5000",
|
||||
}
|
||||
for _, p := range pragmas {
|
||||
if _, err := db.Exec(p); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("db: pragma %q: %w", p, err)
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
44
internal/db/db_test.go
Normal file
44
internal/db/db_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOpenAndMigrate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.db")
|
||||
|
||||
database, err := Open(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Open: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
if err := Migrate(database); err != nil {
|
||||
t.Fatalf("Migrate: %v", err)
|
||||
}
|
||||
|
||||
// Verify tables exist.
|
||||
tables := []string{"seal_config", "barrier_entries", "schema_migrations"}
|
||||
for _, table := range tables {
|
||||
var name string
|
||||
err := database.QueryRow(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&name)
|
||||
if err != nil {
|
||||
t.Errorf("table %q not found: %v", table, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Migration should be idempotent.
|
||||
if err := Migrate(database); err != nil {
|
||||
t.Fatalf("second Migrate: %v", err)
|
||||
}
|
||||
|
||||
// Check migration version.
|
||||
var version int
|
||||
database.QueryRow("SELECT MAX(version) FROM schema_migrations").Scan(&version)
|
||||
if version != 1 {
|
||||
t.Errorf("migration version: got %d, want 1", version)
|
||||
}
|
||||
}
|
||||
70
internal/db/migrate.go
Normal file
70
internal/db/migrate.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// migrations is an ordered list of SQL DDL statements. Each index is the
|
||||
// migration version (1-based).
|
||||
var migrations = []string{
|
||||
// Version 1: initial schema
|
||||
`CREATE TABLE IF NOT EXISTS seal_config (
|
||||
id INTEGER PRIMARY KEY CHECK (id = 1),
|
||||
encrypted_mek BLOB NOT NULL,
|
||||
kdf_salt BLOB NOT NULL,
|
||||
argon2_time INTEGER NOT NULL,
|
||||
argon2_memory INTEGER NOT NULL,
|
||||
argon2_threads INTEGER NOT NULL,
|
||||
initialized_at DATETIME NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS barrier_entries (
|
||||
path TEXT PRIMARY KEY,
|
||||
value BLOB NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at DATETIME NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
applied_at DATETIME NOT NULL DEFAULT (datetime('now'))
|
||||
);`,
|
||||
}
|
||||
|
||||
// Migrate applies all pending migrations.
|
||||
func Migrate(db *sql.DB) error {
|
||||
// Ensure the migrations table exists (bootstrap).
|
||||
if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
applied_at DATETIME NOT NULL DEFAULT (datetime('now'))
|
||||
)`); err != nil {
|
||||
return fmt.Errorf("db: create migrations table: %w", err)
|
||||
}
|
||||
|
||||
var current int
|
||||
row := db.QueryRow("SELECT COALESCE(MAX(version), 0) FROM schema_migrations")
|
||||
if err := row.Scan(¤t); err != nil {
|
||||
return fmt.Errorf("db: get migration version: %w", err)
|
||||
}
|
||||
|
||||
for i := current; i < len(migrations); i++ {
|
||||
version := i + 1
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: begin migration %d: %w", version, err)
|
||||
}
|
||||
if _, err := tx.Exec(migrations[i]); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("db: migration %d: %w", version, err)
|
||||
}
|
||||
if _, err := tx.Exec("INSERT INTO schema_migrations (version) VALUES (?)", version); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("db: record migration %d: %w", version, err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("db: commit migration %d: %w", version, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
179
internal/engine/engine.go
Normal file
179
internal/engine/engine.go
Normal file
@@ -0,0 +1,179 @@
|
||||
// Package engine defines the Engine interface and mount registry.
|
||||
// Phase 1: interface and registry only, no concrete implementations.
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
||||
)
|
||||
|
||||
// EngineType identifies a cryptographic engine type.
|
||||
type EngineType string
|
||||
|
||||
const (
|
||||
EngineTypeCA EngineType = "ca"
|
||||
EngineTypeSSHCA EngineType = "sshca"
|
||||
EngineTypeTransit EngineType = "transit"
|
||||
EngineTypeUser EngineType = "user"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMountExists = errors.New("engine: mount already exists")
|
||||
ErrMountNotFound = errors.New("engine: mount not found")
|
||||
ErrUnknownType = errors.New("engine: unknown engine type")
|
||||
)
|
||||
|
||||
// Request is a request to an engine.
|
||||
type Request struct {
|
||||
Operation string
|
||||
Path string
|
||||
Data map[string]interface{}
|
||||
}
|
||||
|
||||
// Response is a response from an engine.
|
||||
type Response struct {
|
||||
Data map[string]interface{}
|
||||
}
|
||||
|
||||
// Engine is the interface that all cryptographic engines must implement.
|
||||
type Engine interface {
|
||||
// Type returns the engine type.
|
||||
Type() EngineType
|
||||
// Initialize sets up the engine for first use.
|
||||
Initialize(ctx context.Context, b barrier.Barrier, mountPath string) error
|
||||
// Unseal opens the engine using state from the barrier.
|
||||
Unseal(ctx context.Context, b barrier.Barrier, mountPath string) error
|
||||
// Seal closes the engine and zeroizes key material.
|
||||
Seal() error
|
||||
// HandleRequest processes a request.
|
||||
HandleRequest(ctx context.Context, req *Request) (*Response, error)
|
||||
}
|
||||
|
||||
// Factory creates a new engine instance of a given type.
|
||||
type Factory func() Engine
|
||||
|
||||
// Mount represents a mounted engine instance.
|
||||
type Mount struct {
|
||||
Name string `json:"name"`
|
||||
Type EngineType `json:"type"`
|
||||
MountPath string `json:"mount_path"`
|
||||
engine Engine
|
||||
}
|
||||
|
||||
// Registry manages mounted engine instances.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
mounts map[string]*Mount
|
||||
factories map[EngineType]Factory
|
||||
barrier barrier.Barrier
|
||||
}
|
||||
|
||||
// NewRegistry creates a new engine registry.
|
||||
func NewRegistry(b barrier.Barrier) *Registry {
|
||||
return &Registry{
|
||||
mounts: make(map[string]*Mount),
|
||||
factories: make(map[EngineType]Factory),
|
||||
barrier: b,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterFactory registers a factory for the given engine type.
|
||||
func (r *Registry) RegisterFactory(t EngineType, f Factory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.factories[t] = f
|
||||
}
|
||||
|
||||
// Mount creates and initializes a new engine mount.
|
||||
func (r *Registry) Mount(ctx context.Context, name string, engineType EngineType) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.mounts[name]; exists {
|
||||
return ErrMountExists
|
||||
}
|
||||
|
||||
factory, ok := r.factories[engineType]
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: %s", ErrUnknownType, engineType)
|
||||
}
|
||||
|
||||
eng := factory()
|
||||
mountPath := fmt.Sprintf("engine/%s/%s/", engineType, name)
|
||||
|
||||
if err := eng.Initialize(ctx, r.barrier, mountPath); err != nil {
|
||||
return fmt.Errorf("engine: initialize %q: %w", name, err)
|
||||
}
|
||||
|
||||
r.mounts[name] = &Mount{
|
||||
Name: name,
|
||||
Type: engineType,
|
||||
MountPath: mountPath,
|
||||
engine: eng,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unmount removes and seals an engine mount.
|
||||
func (r *Registry) Unmount(name string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
mount, exists := r.mounts[name]
|
||||
if !exists {
|
||||
return ErrMountNotFound
|
||||
}
|
||||
|
||||
if err := mount.engine.Seal(); err != nil {
|
||||
return fmt.Errorf("engine: seal %q: %w", name, err)
|
||||
}
|
||||
|
||||
delete(r.mounts, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListMounts returns all current mounts.
|
||||
func (r *Registry) ListMounts() []Mount {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
mounts := make([]Mount, 0, len(r.mounts))
|
||||
for _, m := range r.mounts {
|
||||
mounts = append(mounts, Mount{
|
||||
Name: m.Name,
|
||||
Type: m.Type,
|
||||
MountPath: m.MountPath,
|
||||
})
|
||||
}
|
||||
return mounts
|
||||
}
|
||||
|
||||
// HandleRequest routes a request to the appropriate engine.
|
||||
func (r *Registry) HandleRequest(ctx context.Context, mountName string, req *Request) (*Response, error) {
|
||||
r.mu.RLock()
|
||||
mount, exists := r.mounts[mountName]
|
||||
r.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, ErrMountNotFound
|
||||
}
|
||||
|
||||
return mount.engine.HandleRequest(ctx, req)
|
||||
}
|
||||
|
||||
// SealAll seals all mounted engines.
|
||||
func (r *Registry) SealAll() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for name, mount := range r.mounts {
|
||||
if err := mount.engine.Seal(); err != nil {
|
||||
return fmt.Errorf("engine: seal %q: %w", name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
120
internal/engine/engine_test.go
Normal file
120
internal/engine/engine_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
||||
)
|
||||
|
||||
// mockEngine implements Engine for testing.
|
||||
type mockEngine struct {
|
||||
engineType EngineType
|
||||
initialized bool
|
||||
unsealed bool
|
||||
}
|
||||
|
||||
func (m *mockEngine) Type() EngineType { return m.engineType }
|
||||
func (m *mockEngine) Initialize(_ context.Context, _ barrier.Barrier, _ string) error { m.initialized = true; return nil }
|
||||
func (m *mockEngine) Unseal(_ context.Context, _ barrier.Barrier, _ string) error { m.unsealed = true; return nil }
|
||||
func (m *mockEngine) Seal() error { m.unsealed = false; return nil }
|
||||
func (m *mockEngine) HandleRequest(_ context.Context, _ *Request) (*Response, error) {
|
||||
return &Response{Data: map[string]interface{}{"ok": true}}, nil
|
||||
}
|
||||
|
||||
type mockBarrier struct{}
|
||||
|
||||
func (m *mockBarrier) Unseal(_ []byte) error { return nil }
|
||||
func (m *mockBarrier) Seal() error { return nil }
|
||||
func (m *mockBarrier) IsSealed() bool { return false }
|
||||
func (m *mockBarrier) Get(_ context.Context, _ string) ([]byte, error) { return nil, barrier.ErrNotFound }
|
||||
func (m *mockBarrier) Put(_ context.Context, _ string, _ []byte) error { return nil }
|
||||
func (m *mockBarrier) Delete(_ context.Context, _ string) error { return nil }
|
||||
func (m *mockBarrier) List(_ context.Context, _ string) ([]string, error) { return nil, nil }
|
||||
|
||||
func TestRegistryMountUnmount(t *testing.T) {
|
||||
reg := NewRegistry(&mockBarrier{})
|
||||
reg.RegisterFactory(EngineTypeTransit, func() Engine {
|
||||
return &mockEngine{engineType: EngineTypeTransit}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
if err := reg.Mount(ctx, "default", EngineTypeTransit); err != nil {
|
||||
t.Fatalf("Mount: %v", err)
|
||||
}
|
||||
|
||||
mounts := reg.ListMounts()
|
||||
if len(mounts) != 1 {
|
||||
t.Fatalf("ListMounts: got %d, want 1", len(mounts))
|
||||
}
|
||||
if mounts[0].Name != "default" {
|
||||
t.Errorf("mount name: got %q, want %q", mounts[0].Name, "default")
|
||||
}
|
||||
|
||||
// Duplicate mount should fail.
|
||||
if err := reg.Mount(ctx, "default", EngineTypeTransit); err != ErrMountExists {
|
||||
t.Fatalf("expected ErrMountExists, got: %v", err)
|
||||
}
|
||||
|
||||
if err := reg.Unmount("default"); err != nil {
|
||||
t.Fatalf("Unmount: %v", err)
|
||||
}
|
||||
|
||||
mounts = reg.ListMounts()
|
||||
if len(mounts) != 0 {
|
||||
t.Fatalf("after unmount: got %d mounts", len(mounts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryUnmountNotFound(t *testing.T) {
|
||||
reg := NewRegistry(&mockBarrier{})
|
||||
if err := reg.Unmount("nonexistent"); err != ErrMountNotFound {
|
||||
t.Fatalf("expected ErrMountNotFound, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryUnknownType(t *testing.T) {
|
||||
reg := NewRegistry(&mockBarrier{})
|
||||
err := reg.Mount(context.Background(), "test", EngineTypeTransit)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown engine type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryHandleRequest(t *testing.T) {
|
||||
reg := NewRegistry(&mockBarrier{})
|
||||
reg.RegisterFactory(EngineTypeTransit, func() Engine {
|
||||
return &mockEngine{engineType: EngineTypeTransit}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
reg.Mount(ctx, "test", EngineTypeTransit)
|
||||
|
||||
resp, err := reg.HandleRequest(ctx, "test", &Request{Operation: "encrypt"})
|
||||
if err != nil {
|
||||
t.Fatalf("HandleRequest: %v", err)
|
||||
}
|
||||
if resp.Data["ok"] != true {
|
||||
t.Error("expected ok=true in response")
|
||||
}
|
||||
|
||||
_, err = reg.HandleRequest(ctx, "nonexistent", &Request{})
|
||||
if err != ErrMountNotFound {
|
||||
t.Fatalf("expected ErrMountNotFound, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrySealAll(t *testing.T) {
|
||||
reg := NewRegistry(&mockBarrier{})
|
||||
reg.RegisterFactory(EngineTypeTransit, func() Engine {
|
||||
return &mockEngine{engineType: EngineTypeTransit}
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
reg.Mount(ctx, "eng1", EngineTypeTransit)
|
||||
reg.Mount(ctx, "eng2", EngineTypeTransit)
|
||||
|
||||
if err := reg.SealAll(); err != nil {
|
||||
t.Fatalf("SealAll: %v", err)
|
||||
}
|
||||
}
|
||||
188
internal/policy/policy.go
Normal file
188
internal/policy/policy.go
Normal file
@@ -0,0 +1,188 @@
|
||||
// Package policy implements the Metacrypt policy engine with priority-based ACL rules.
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
||||
)
|
||||
|
||||
const rulesPrefix = "policy/rules/"
|
||||
|
||||
// Effect represents a policy decision.
|
||||
type Effect string
|
||||
|
||||
const (
|
||||
EffectAllow Effect = "allow"
|
||||
EffectDeny Effect = "deny"
|
||||
)
|
||||
|
||||
// Rule is a policy rule stored in the barrier.
|
||||
type Rule struct {
|
||||
ID string `json:"id"`
|
||||
Priority int `json:"priority"`
|
||||
Effect Effect `json:"effect"`
|
||||
Usernames []string `json:"usernames,omitempty"` // match specific users
|
||||
Roles []string `json:"roles,omitempty"` // match roles
|
||||
Resources []string `json:"resources,omitempty"` // glob patterns for engine mounts/paths
|
||||
Actions []string `json:"actions,omitempty"` // e.g., "read", "write", "admin"
|
||||
}
|
||||
|
||||
// Request represents an authorization request.
|
||||
type Request struct {
|
||||
Username string
|
||||
Roles []string
|
||||
Resource string // e.g., "engine/transit/default/encrypt"
|
||||
Action string // e.g., "write"
|
||||
}
|
||||
|
||||
// Engine evaluates policy rules from the barrier.
|
||||
type Engine struct {
|
||||
barrier barrier.Barrier
|
||||
}
|
||||
|
||||
// NewEngine creates a new policy engine.
|
||||
func NewEngine(b barrier.Barrier) *Engine {
|
||||
return &Engine{barrier: b}
|
||||
}
|
||||
|
||||
// Evaluate checks if the request is allowed. Admin role always allows.
|
||||
// Otherwise: collect matching rules, sort by priority (lower = higher priority),
|
||||
// first match wins, default deny.
|
||||
func (e *Engine) Evaluate(ctx context.Context, req *Request) (Effect, error) {
|
||||
// Admin bypass.
|
||||
for _, r := range req.Roles {
|
||||
if r == "admin" {
|
||||
return EffectAllow, nil
|
||||
}
|
||||
}
|
||||
|
||||
rules, err := e.listRules(ctx)
|
||||
if err != nil {
|
||||
return EffectDeny, err
|
||||
}
|
||||
|
||||
// Sort by priority ascending (lower number = higher priority).
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
return rules[i].Priority < rules[j].Priority
|
||||
})
|
||||
|
||||
for _, rule := range rules {
|
||||
if matchesRule(&rule, req) {
|
||||
return rule.Effect, nil
|
||||
}
|
||||
}
|
||||
|
||||
return EffectDeny, nil // default deny
|
||||
}
|
||||
|
||||
// CreateRule stores a new policy rule.
|
||||
func (e *Engine) CreateRule(ctx context.Context, rule *Rule) error {
|
||||
data, err := json.Marshal(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("policy: marshal rule: %w", err)
|
||||
}
|
||||
return e.barrier.Put(ctx, rulesPrefix+rule.ID, data)
|
||||
}
|
||||
|
||||
// GetRule retrieves a policy rule by ID.
|
||||
func (e *Engine) GetRule(ctx context.Context, id string) (*Rule, error) {
|
||||
data, err := e.barrier.Get(ctx, rulesPrefix+id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var rule Rule
|
||||
if err := json.Unmarshal(data, &rule); err != nil {
|
||||
return nil, fmt.Errorf("policy: unmarshal rule: %w", err)
|
||||
}
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
// DeleteRule removes a policy rule.
|
||||
func (e *Engine) DeleteRule(ctx context.Context, id string) error {
|
||||
return e.barrier.Delete(ctx, rulesPrefix+id)
|
||||
}
|
||||
|
||||
// ListRules returns all policy rules.
|
||||
func (e *Engine) ListRules(ctx context.Context) ([]Rule, error) {
|
||||
return e.listRules(ctx)
|
||||
}
|
||||
|
||||
func (e *Engine) listRules(ctx context.Context) ([]Rule, error) {
|
||||
paths, err := e.barrier.List(ctx, rulesPrefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("policy: list rules: %w", err)
|
||||
}
|
||||
|
||||
var rules []Rule
|
||||
for _, p := range paths {
|
||||
data, err := e.barrier.Get(ctx, rulesPrefix+p)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("policy: get rule %q: %w", p, err)
|
||||
}
|
||||
var rule Rule
|
||||
if err := json.Unmarshal(data, &rule); err != nil {
|
||||
return nil, fmt.Errorf("policy: unmarshal rule %q: %w", p, err)
|
||||
}
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func matchesRule(rule *Rule, req *Request) bool {
|
||||
// Check username match.
|
||||
if len(rule.Usernames) > 0 && !containsString(rule.Usernames, req.Username) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check role match.
|
||||
if len(rule.Roles) > 0 && !hasAnyRole(rule.Roles, req.Roles) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check resource match (glob patterns).
|
||||
if len(rule.Resources) > 0 && !matchesAnyGlob(rule.Resources, req.Resource) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check action match.
|
||||
if len(rule.Actions) > 0 && !containsString(rule.Actions, req.Action) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func containsString(haystack []string, needle string) bool {
|
||||
for _, s := range haystack {
|
||||
if strings.EqualFold(s, needle) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasAnyRole(required, actual []string) bool {
|
||||
for _, r := range required {
|
||||
for _, a := range actual {
|
||||
if strings.EqualFold(r, a) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchesAnyGlob(patterns []string, value string) bool {
|
||||
for _, p := range patterns {
|
||||
if matched, _ := filepath.Match(p, value); matched {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
177
internal/policy/policy_test.go
Normal file
177
internal/policy/policy_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/db"
|
||||
)
|
||||
|
||||
func setupPolicy(t *testing.T) (*Engine, func()) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
database, err := db.Open(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
if err := db.Migrate(database); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
b := barrier.NewAESGCMBarrier(database)
|
||||
mek, _ := crypto.GenerateKey()
|
||||
b.Unseal(mek)
|
||||
e := NewEngine(b)
|
||||
return e, func() { database.Close() }
|
||||
}
|
||||
|
||||
func TestAdminBypass(t *testing.T) {
|
||||
e, cleanup := setupPolicy(t)
|
||||
defer cleanup()
|
||||
|
||||
effect, err := e.Evaluate(context.Background(), &Request{
|
||||
Username: "admin-user",
|
||||
Roles: []string{"admin"},
|
||||
Resource: "engine/transit/default/encrypt",
|
||||
Action: "write",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Evaluate: %v", err)
|
||||
}
|
||||
if effect != EffectAllow {
|
||||
t.Fatalf("admin should always be allowed, got: %s", effect)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultDeny(t *testing.T) {
|
||||
e, cleanup := setupPolicy(t)
|
||||
defer cleanup()
|
||||
|
||||
effect, err := e.Evaluate(context.Background(), &Request{
|
||||
Username: "user1",
|
||||
Roles: []string{"viewer"},
|
||||
Resource: "engine/transit/default/encrypt",
|
||||
Action: "write",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Evaluate: %v", err)
|
||||
}
|
||||
if effect != EffectDeny {
|
||||
t.Fatalf("default should deny, got: %s", effect)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyRuleCRUD(t *testing.T) {
|
||||
e, cleanup := setupPolicy(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
rule := &Rule{
|
||||
ID: "test-rule",
|
||||
Priority: 100,
|
||||
Effect: EffectAllow,
|
||||
Roles: []string{"operator"},
|
||||
Resources: []string{"engine/transit/*"},
|
||||
Actions: []string{"read", "write"},
|
||||
}
|
||||
|
||||
if err := e.CreateRule(ctx, rule); err != nil {
|
||||
t.Fatalf("CreateRule: %v", err)
|
||||
}
|
||||
|
||||
got, err := e.GetRule(ctx, "test-rule")
|
||||
if err != nil {
|
||||
t.Fatalf("GetRule: %v", err)
|
||||
}
|
||||
if got.Priority != 100 {
|
||||
t.Errorf("priority: got %d, want 100", got.Priority)
|
||||
}
|
||||
|
||||
rules, err := e.ListRules(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListRules: %v", err)
|
||||
}
|
||||
if len(rules) != 1 {
|
||||
t.Fatalf("ListRules: got %d rules, want 1", len(rules))
|
||||
}
|
||||
|
||||
if err := e.DeleteRule(ctx, "test-rule"); err != nil {
|
||||
t.Fatalf("DeleteRule: %v", err)
|
||||
}
|
||||
|
||||
rules, _ = e.ListRules(ctx)
|
||||
if len(rules) != 0 {
|
||||
t.Fatalf("after delete: got %d rules, want 0", len(rules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyPriorityOrder(t *testing.T) {
|
||||
e, cleanup := setupPolicy(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
// Lower priority number = higher priority. Deny should win.
|
||||
e.CreateRule(ctx, &Rule{
|
||||
ID: "allow-rule",
|
||||
Priority: 200,
|
||||
Effect: EffectAllow,
|
||||
Roles: []string{"operator"},
|
||||
Resources: []string{"engine/transit/*"},
|
||||
Actions: []string{"write"},
|
||||
})
|
||||
e.CreateRule(ctx, &Rule{
|
||||
ID: "deny-rule",
|
||||
Priority: 100,
|
||||
Effect: EffectDeny,
|
||||
Roles: []string{"operator"},
|
||||
Resources: []string{"engine/transit/*"},
|
||||
Actions: []string{"write"},
|
||||
})
|
||||
|
||||
effect, _ := e.Evaluate(ctx, &Request{
|
||||
Username: "user1",
|
||||
Roles: []string{"operator"},
|
||||
Resource: "engine/transit/default",
|
||||
Action: "write",
|
||||
})
|
||||
if effect != EffectDeny {
|
||||
t.Fatalf("higher priority deny should win, got: %s", effect)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyUsernameMatch(t *testing.T) {
|
||||
e, cleanup := setupPolicy(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
e.CreateRule(ctx, &Rule{
|
||||
ID: "user-specific",
|
||||
Priority: 100,
|
||||
Effect: EffectAllow,
|
||||
Usernames: []string{"alice"},
|
||||
Resources: []string{"engine/*"},
|
||||
Actions: []string{"read"},
|
||||
})
|
||||
|
||||
effect, _ := e.Evaluate(ctx, &Request{
|
||||
Username: "alice",
|
||||
Roles: []string{"user"},
|
||||
Resource: "engine/ca",
|
||||
Action: "read",
|
||||
})
|
||||
if effect != EffectAllow {
|
||||
t.Fatalf("alice should be allowed, got: %s", effect)
|
||||
}
|
||||
|
||||
effect, _ = e.Evaluate(ctx, &Request{
|
||||
Username: "bob",
|
||||
Roles: []string{"user"},
|
||||
Resource: "engine/ca",
|
||||
Action: "read",
|
||||
})
|
||||
if effect != EffectDeny {
|
||||
t.Fatalf("bob should be denied, got: %s", effect)
|
||||
}
|
||||
}
|
||||
242
internal/seal/seal.go
Normal file
242
internal/seal/seal.go
Normal file
@@ -0,0 +1,242 @@
|
||||
// Package seal implements the seal/unseal state machine for Metacrypt.
|
||||
package seal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
|
||||
)
|
||||
|
||||
// ServiceState represents the current state of the Metacrypt service.
|
||||
type ServiceState int
|
||||
|
||||
const (
|
||||
StateUninitialized ServiceState = iota
|
||||
StateSealed
|
||||
StateInitializing
|
||||
StateUnsealed
|
||||
)
|
||||
|
||||
func (s ServiceState) String() string {
|
||||
switch s {
|
||||
case StateUninitialized:
|
||||
return "uninitialized"
|
||||
case StateSealed:
|
||||
return "sealed"
|
||||
case StateInitializing:
|
||||
return "initializing"
|
||||
case StateUnsealed:
|
||||
return "unsealed"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
ErrAlreadyInitialized = errors.New("seal: already initialized")
|
||||
ErrNotInitialized = errors.New("seal: not initialized")
|
||||
ErrInvalidPassword = errors.New("seal: invalid password")
|
||||
ErrSealed = errors.New("seal: service is sealed")
|
||||
ErrNotSealed = errors.New("seal: service is not sealed")
|
||||
ErrRateLimited = errors.New("seal: too many unseal attempts, try again later")
|
||||
)
|
||||
|
||||
// Manager manages the seal/unseal lifecycle.
|
||||
type Manager struct {
|
||||
db *sql.DB
|
||||
barrier *barrier.AESGCMBarrier
|
||||
|
||||
mu sync.RWMutex
|
||||
state ServiceState
|
||||
mek []byte // nil when sealed
|
||||
|
||||
// Rate limiting for unseal attempts.
|
||||
unsealAttempts int
|
||||
lastAttempt time.Time
|
||||
lockoutUntil time.Time
|
||||
}
|
||||
|
||||
// NewManager creates a new seal manager.
|
||||
func NewManager(db *sql.DB, b *barrier.AESGCMBarrier) *Manager {
|
||||
return &Manager{
|
||||
db: db,
|
||||
barrier: b,
|
||||
state: StateUninitialized,
|
||||
}
|
||||
}
|
||||
|
||||
// State returns the current service state.
|
||||
func (m *Manager) State() ServiceState {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.state
|
||||
}
|
||||
|
||||
// CheckInitialized checks the database for an existing seal config and
|
||||
// updates the state accordingly. Should be called on startup.
|
||||
func (m *Manager) CheckInitialized() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var count int
|
||||
err := m.db.QueryRow("SELECT COUNT(*) FROM seal_config").Scan(&count)
|
||||
if err != nil {
|
||||
return fmt.Errorf("seal: check initialized: %w", err)
|
||||
}
|
||||
if count > 0 {
|
||||
m.state = StateSealed
|
||||
} else {
|
||||
m.state = StateUninitialized
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize performs first-time setup: generates MEK, encrypts it with the
|
||||
// password-derived KWK, and stores everything in seal_config.
|
||||
func (m *Manager) Initialize(ctx context.Context, password []byte, params crypto.Argon2Params) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.state != StateUninitialized {
|
||||
return ErrAlreadyInitialized
|
||||
}
|
||||
|
||||
m.state = StateInitializing
|
||||
defer func() {
|
||||
if m.mek == nil {
|
||||
// If we failed, go back to uninitialized.
|
||||
m.state = StateUninitialized
|
||||
}
|
||||
}()
|
||||
|
||||
// Generate salt and MEK.
|
||||
salt, err := crypto.GenerateSalt()
|
||||
if err != nil {
|
||||
return fmt.Errorf("seal: generate salt: %w", err)
|
||||
}
|
||||
|
||||
mek, err := crypto.GenerateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("seal: generate mek: %w", err)
|
||||
}
|
||||
|
||||
// Derive KWK from password.
|
||||
kwk := crypto.DeriveKey(password, salt, params)
|
||||
defer crypto.Zeroize(kwk)
|
||||
|
||||
// Encrypt MEK with KWK.
|
||||
encryptedMEK, err := crypto.Encrypt(kwk, mek)
|
||||
if err != nil {
|
||||
crypto.Zeroize(mek)
|
||||
return fmt.Errorf("seal: encrypt mek: %w", err)
|
||||
}
|
||||
|
||||
// Store in database.
|
||||
_, err = m.db.ExecContext(ctx, `
|
||||
INSERT INTO seal_config (id, encrypted_mek, kdf_salt, argon2_time, argon2_memory, argon2_threads)
|
||||
VALUES (1, ?, ?, ?, ?, ?)`,
|
||||
encryptedMEK, salt, params.Time, params.Memory, params.Threads)
|
||||
if err != nil {
|
||||
crypto.Zeroize(mek)
|
||||
return fmt.Errorf("seal: store config: %w", err)
|
||||
}
|
||||
|
||||
// Unseal the barrier with the MEK.
|
||||
if err := m.barrier.Unseal(mek); err != nil {
|
||||
crypto.Zeroize(mek)
|
||||
return fmt.Errorf("seal: unseal barrier: %w", err)
|
||||
}
|
||||
|
||||
m.mek = mek
|
||||
m.state = StateUnsealed
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unseal decrypts the MEK using the provided password and unseals the barrier.
|
||||
func (m *Manager) Unseal(password []byte) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.state == StateUninitialized {
|
||||
return ErrNotInitialized
|
||||
}
|
||||
if m.state == StateUnsealed {
|
||||
return ErrNotSealed
|
||||
}
|
||||
|
||||
// Rate limiting.
|
||||
now := time.Now()
|
||||
if now.Before(m.lockoutUntil) {
|
||||
return ErrRateLimited
|
||||
}
|
||||
if now.Sub(m.lastAttempt) > time.Minute {
|
||||
m.unsealAttempts = 0
|
||||
}
|
||||
m.unsealAttempts++
|
||||
m.lastAttempt = now
|
||||
if m.unsealAttempts > 5 {
|
||||
m.lockoutUntil = now.Add(60 * time.Second)
|
||||
m.unsealAttempts = 0
|
||||
return ErrRateLimited
|
||||
}
|
||||
|
||||
// Read seal config.
|
||||
var (
|
||||
encryptedMEK []byte
|
||||
salt []byte
|
||||
argTime, argMem uint32
|
||||
argThreads uint8
|
||||
)
|
||||
err := m.db.QueryRow(`
|
||||
SELECT encrypted_mek, kdf_salt, argon2_time, argon2_memory, argon2_threads
|
||||
FROM seal_config WHERE id = 1`).Scan(&encryptedMEK, &salt, &argTime, &argMem, &argThreads)
|
||||
if err != nil {
|
||||
return fmt.Errorf("seal: read config: %w", err)
|
||||
}
|
||||
|
||||
params := crypto.Argon2Params{Time: argTime, Memory: argMem, Threads: argThreads}
|
||||
|
||||
// Derive KWK and decrypt MEK.
|
||||
kwk := crypto.DeriveKey(password, salt, params)
|
||||
defer crypto.Zeroize(kwk)
|
||||
|
||||
mek, err := crypto.Decrypt(kwk, encryptedMEK)
|
||||
if err != nil {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
// Unseal the barrier.
|
||||
if err := m.barrier.Unseal(mek); err != nil {
|
||||
crypto.Zeroize(mek)
|
||||
return fmt.Errorf("seal: unseal barrier: %w", err)
|
||||
}
|
||||
|
||||
m.mek = mek
|
||||
m.state = StateUnsealed
|
||||
m.unsealAttempts = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
// Seal seals the service: zeroizes MEK, seals the barrier.
|
||||
func (m *Manager) Seal() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.state != StateUnsealed {
|
||||
return ErrNotSealed
|
||||
}
|
||||
|
||||
if m.mek != nil {
|
||||
crypto.Zeroize(m.mek)
|
||||
m.mek = nil
|
||||
}
|
||||
m.barrier.Seal()
|
||||
m.state = StateSealed
|
||||
return nil
|
||||
}
|
||||
136
internal/seal/seal_test.go
Normal file
136
internal/seal/seal_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package seal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/db"
|
||||
)
|
||||
|
||||
func setupSeal(t *testing.T) (*Manager, func()) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
database, err := db.Open(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
if err := db.Migrate(database); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
b := barrier.NewAESGCMBarrier(database)
|
||||
mgr := NewManager(database, b)
|
||||
return mgr, func() { database.Close() }
|
||||
}
|
||||
|
||||
func TestSealInitializeAndUnseal(t *testing.T) {
|
||||
mgr, cleanup := setupSeal(t)
|
||||
defer cleanup()
|
||||
|
||||
if err := mgr.CheckInitialized(); err != nil {
|
||||
t.Fatalf("CheckInitialized: %v", err)
|
||||
}
|
||||
if mgr.State() != StateUninitialized {
|
||||
t.Fatalf("state: got %v, want Uninitialized", mgr.State())
|
||||
}
|
||||
|
||||
password := []byte("test-password-123")
|
||||
// Use fast params for testing.
|
||||
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
||||
|
||||
if err := mgr.Initialize(context.Background(), password, params); err != nil {
|
||||
t.Fatalf("Initialize: %v", err)
|
||||
}
|
||||
if mgr.State() != StateUnsealed {
|
||||
t.Fatalf("state after init: got %v, want Unsealed", mgr.State())
|
||||
}
|
||||
|
||||
// Seal.
|
||||
if err := mgr.Seal(); err != nil {
|
||||
t.Fatalf("Seal: %v", err)
|
||||
}
|
||||
if mgr.State() != StateSealed {
|
||||
t.Fatalf("state after seal: got %v, want Sealed", mgr.State())
|
||||
}
|
||||
|
||||
// Unseal with correct password.
|
||||
if err := mgr.Unseal(password); err != nil {
|
||||
t.Fatalf("Unseal: %v", err)
|
||||
}
|
||||
if mgr.State() != StateUnsealed {
|
||||
t.Fatalf("state after unseal: got %v, want Unsealed", mgr.State())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSealWrongPassword(t *testing.T) {
|
||||
mgr, cleanup := setupSeal(t)
|
||||
defer cleanup()
|
||||
mgr.CheckInitialized()
|
||||
|
||||
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
||||
mgr.Initialize(context.Background(), []byte("correct"), params)
|
||||
mgr.Seal()
|
||||
|
||||
err := mgr.Unseal([]byte("wrong"))
|
||||
if err != ErrInvalidPassword {
|
||||
t.Fatalf("expected ErrInvalidPassword, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSealDoubleInitialize(t *testing.T) {
|
||||
mgr, cleanup := setupSeal(t)
|
||||
defer cleanup()
|
||||
mgr.CheckInitialized()
|
||||
|
||||
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
||||
mgr.Initialize(context.Background(), []byte("password"), params)
|
||||
|
||||
err := mgr.Initialize(context.Background(), []byte("password"), params)
|
||||
if err != ErrAlreadyInitialized {
|
||||
t.Fatalf("expected ErrAlreadyInitialized, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSealCheckInitializedPersists(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dbPath := filepath.Join(dir, "test.db")
|
||||
|
||||
// First: initialize.
|
||||
database, _ := db.Open(dbPath)
|
||||
db.Migrate(database)
|
||||
b := barrier.NewAESGCMBarrier(database)
|
||||
mgr := NewManager(database, b)
|
||||
mgr.CheckInitialized()
|
||||
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
||||
mgr.Initialize(context.Background(), []byte("password"), params)
|
||||
database.Close()
|
||||
|
||||
// Second: reopen and check.
|
||||
database2, _ := db.Open(dbPath)
|
||||
defer database2.Close()
|
||||
b2 := barrier.NewAESGCMBarrier(database2)
|
||||
mgr2 := NewManager(database2, b2)
|
||||
mgr2.CheckInitialized()
|
||||
if mgr2.State() != StateSealed {
|
||||
t.Fatalf("state after reopen: got %v, want Sealed", mgr2.State())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSealStateString(t *testing.T) {
|
||||
tests := []struct {
|
||||
state ServiceState
|
||||
want string
|
||||
}{
|
||||
{StateUninitialized, "uninitialized"},
|
||||
{StateSealed, "sealed"},
|
||||
{StateInitializing, "initializing"},
|
||||
{StateUnsealed, "unsealed"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := tt.state.String(); got != tt.want {
|
||||
t.Errorf("State(%d).String() = %q, want %q", tt.state, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
109
internal/server/middleware.go
Normal file
109
internal/server/middleware.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/auth"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/seal"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const tokenInfoKey contextKey = "tokenInfo"
|
||||
|
||||
// TokenInfoFromContext extracts the validated token info from the request context.
|
||||
func TokenInfoFromContext(ctx context.Context) *auth.TokenInfo {
|
||||
info, _ := ctx.Value(tokenInfoKey).(*auth.TokenInfo)
|
||||
return info
|
||||
}
|
||||
|
||||
// loggingMiddleware logs HTTP requests, stripping sensitive headers.
|
||||
func (s *Server) loggingMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
sw := &statusWriter{ResponseWriter: w, status: 200}
|
||||
next.ServeHTTP(sw, r)
|
||||
s.logger.Info("http request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", sw.status,
|
||||
"duration", time.Since(start),
|
||||
"remote", r.RemoteAddr,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// requireUnseal rejects requests unless the service is unsealed.
|
||||
func (s *Server) requireUnseal(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state := s.seal.State()
|
||||
switch state {
|
||||
case seal.StateUninitialized:
|
||||
http.Error(w, `{"error":"not initialized"}`, http.StatusPreconditionFailed)
|
||||
return
|
||||
case seal.StateSealed, seal.StateInitializing:
|
||||
http.Error(w, `{"error":"sealed"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// requireAuth validates the bearer token and injects TokenInfo into context.
|
||||
func (s *Server) requireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return s.requireUnseal(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := extractToken(r)
|
||||
if token == "" {
|
||||
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := s.auth.ValidateToken(token)
|
||||
if err != nil {
|
||||
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), tokenInfoKey, info)
|
||||
next(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// requireAdmin requires the authenticated user to have admin role.
|
||||
func (s *Server) requireAdmin(next http.HandlerFunc) http.HandlerFunc {
|
||||
return s.requireAuth(func(w http.ResponseWriter, r *http.Request) {
|
||||
info := TokenInfoFromContext(r.Context())
|
||||
if info == nil || !info.IsAdmin {
|
||||
http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func extractToken(r *http.Request) string {
|
||||
// Check Authorization header first.
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
// Fall back to cookie.
|
||||
cookie, err := r.Cookie("metacrypt_token")
|
||||
if err == nil {
|
||||
return cookie.Value
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type statusWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *statusWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
532
internal/server/routes.go
Normal file
532
internal/server/routes.go
Normal file
@@ -0,0 +1,532 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
mcias "git.wntrmute.dev/kyle/mcias/clients/go"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/policy"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/seal"
|
||||
)
|
||||
|
||||
func (s *Server) registerRoutes(mux *http.ServeMux) {
|
||||
// Static files.
|
||||
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("web/static"))))
|
||||
|
||||
// Web UI routes.
|
||||
mux.HandleFunc("/", s.handleWebRoot)
|
||||
mux.HandleFunc("/init", s.handleWebInit)
|
||||
mux.HandleFunc("/unseal", s.handleWebUnseal)
|
||||
mux.HandleFunc("/login", s.handleWebLogin)
|
||||
mux.HandleFunc("/dashboard", s.requireAuthWeb(s.handleWebDashboard))
|
||||
|
||||
// API routes.
|
||||
mux.HandleFunc("/v1/status", s.handleStatus)
|
||||
mux.HandleFunc("/v1/init", s.handleInit)
|
||||
mux.HandleFunc("/v1/unseal", s.handleUnseal)
|
||||
mux.HandleFunc("/v1/seal", s.requireAdmin(s.handleSeal))
|
||||
|
||||
mux.HandleFunc("/v1/auth/login", s.handleLogin)
|
||||
mux.HandleFunc("/v1/auth/logout", s.requireAuth(s.handleLogout))
|
||||
mux.HandleFunc("/v1/auth/tokeninfo", s.requireAuth(s.handleTokenInfo))
|
||||
|
||||
mux.HandleFunc("/v1/engine/mounts", s.requireAuth(s.handleEngineMounts))
|
||||
mux.HandleFunc("/v1/engine/mount", s.requireAdmin(s.handleEngineMount))
|
||||
mux.HandleFunc("/v1/engine/unmount", s.requireAdmin(s.handleEngineUnmount))
|
||||
mux.HandleFunc("/v1/engine/request", s.requireAuth(s.handleEngineRequest))
|
||||
|
||||
mux.HandleFunc("/v1/policy/rules", s.requireAuth(s.handlePolicyRules))
|
||||
mux.HandleFunc("/v1/policy/rule", s.requireAuth(s.handlePolicyRule))
|
||||
}
|
||||
|
||||
// --- API Handlers ---
|
||||
|
||||
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"state": s.seal.State().String(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleInit(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := readJSON(r, &req); err != nil {
|
||||
http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.Password == "" {
|
||||
http.Error(w, `{"error":"password is required"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
params := crypto.Argon2Params{
|
||||
Time: s.cfg.Seal.Argon2Time,
|
||||
Memory: s.cfg.Seal.Argon2Memory,
|
||||
Threads: s.cfg.Seal.Argon2Threads,
|
||||
}
|
||||
if err := s.seal.Initialize(r.Context(), []byte(req.Password), params); err != nil {
|
||||
if err == seal.ErrAlreadyInitialized {
|
||||
http.Error(w, `{"error":"already initialized"}`, http.StatusConflict)
|
||||
return
|
||||
}
|
||||
s.logger.Error("init failed", "error", err)
|
||||
http.Error(w, `{"error":"initialization failed"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"state": s.seal.State().String(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleUnseal(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := readJSON(r, &req); err != nil {
|
||||
http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.seal.Unseal([]byte(req.Password)); err != nil {
|
||||
switch err {
|
||||
case seal.ErrNotInitialized:
|
||||
http.Error(w, `{"error":"not initialized"}`, http.StatusPreconditionFailed)
|
||||
case seal.ErrInvalidPassword:
|
||||
http.Error(w, `{"error":"invalid password"}`, http.StatusUnauthorized)
|
||||
case seal.ErrRateLimited:
|
||||
http.Error(w, `{"error":"too many attempts, try again later"}`, http.StatusTooManyRequests)
|
||||
case seal.ErrNotSealed:
|
||||
http.Error(w, `{"error":"already unsealed"}`, http.StatusConflict)
|
||||
default:
|
||||
s.logger.Error("unseal failed", "error", err)
|
||||
http.Error(w, `{"error":"unseal failed"}`, http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"state": s.seal.State().String(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleSeal(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.engines.SealAll(); err != nil {
|
||||
s.logger.Error("seal engines failed", "error", err)
|
||||
}
|
||||
|
||||
if err := s.seal.Seal(); err != nil {
|
||||
s.logger.Error("seal failed", "error", err)
|
||||
http.Error(w, `{"error":"seal failed"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.auth.ClearCache()
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"state": s.seal.State().String(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if s.seal.State() != seal.StateUnsealed {
|
||||
http.Error(w, `{"error":"sealed"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
TOTPCode string `json:"totp_code"`
|
||||
}
|
||||
if err := readJSON(r, &req); err != nil {
|
||||
http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
token, expiresAt, err := s.auth.Login(req.Username, req.Password, req.TOTPCode)
|
||||
if err != nil {
|
||||
http.Error(w, `{"error":"invalid credentials"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"token": token,
|
||||
"expires_at": expiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
token := extractToken(r)
|
||||
client, err := mcias.New(s.cfg.MCIAS.ServerURL, mcias.Options{
|
||||
CACertPath: s.cfg.MCIAS.CACert,
|
||||
Token: token,
|
||||
})
|
||||
if err == nil {
|
||||
s.auth.Logout(client)
|
||||
}
|
||||
|
||||
// Clear cookie.
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "metacrypt_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{"ok": true})
|
||||
}
|
||||
|
||||
func (s *Server) handleTokenInfo(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
info := TokenInfoFromContext(r.Context())
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"username": info.Username,
|
||||
"roles": info.Roles,
|
||||
"is_admin": info.IsAdmin,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleEngineMounts(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
mounts := s.engines.ListMounts()
|
||||
writeJSON(w, http.StatusOK, mounts)
|
||||
}
|
||||
|
||||
func (s *Server) handleEngineMount(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := readJSON(r, &req); err != nil {
|
||||
http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// Phase 1: no engine types registered yet.
|
||||
http.Error(w, `{"error":"no engine types available in phase 1"}`, http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
func (s *Server) handleEngineUnmount(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := readJSON(r, &req); err != nil {
|
||||
http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := s.engines.Unmount(req.Name); err != nil {
|
||||
http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{"ok": true})
|
||||
}
|
||||
|
||||
func (s *Server) handleEngineRequest(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
// Phase 1 stub.
|
||||
http.Error(w, `{"error":"no engine types available in phase 1"}`, http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
func (s *Server) handlePolicyRules(w http.ResponseWriter, r *http.Request) {
|
||||
info := TokenInfoFromContext(r.Context())
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
if !info.IsAdmin {
|
||||
http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
rules, err := s.policy.ListRules(r.Context())
|
||||
if err != nil {
|
||||
s.logger.Error("list policies", "error", err)
|
||||
http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if rules == nil {
|
||||
rules = []policy.Rule{}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, rules)
|
||||
case http.MethodPost:
|
||||
if !info.IsAdmin {
|
||||
http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
var rule policy.Rule
|
||||
if err := readJSON(r, &rule); err != nil {
|
||||
http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if rule.ID == "" {
|
||||
http.Error(w, `{"error":"id is required"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := s.policy.CreateRule(r.Context(), &rule); err != nil {
|
||||
s.logger.Error("create policy", "error", err)
|
||||
http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, rule)
|
||||
default:
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePolicyRule(w http.ResponseWriter, r *http.Request) {
|
||||
info := TokenInfoFromContext(r.Context())
|
||||
if !info.IsAdmin {
|
||||
http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
if id == "" {
|
||||
http.Error(w, `{"error":"id parameter required"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
rule, err := s.policy.GetRule(r.Context(), id)
|
||||
if err != nil {
|
||||
http.Error(w, `{"error":"not found"}`, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, rule)
|
||||
case http.MethodDelete:
|
||||
if err := s.policy.DeleteRule(r.Context(), id); err != nil {
|
||||
http.Error(w, `{"error":"not found"}`, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{"ok": true})
|
||||
default:
|
||||
http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Web Handlers ---
|
||||
|
||||
func (s *Server) handleWebRoot(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
state := s.seal.State()
|
||||
switch state {
|
||||
case seal.StateUninitialized:
|
||||
http.Redirect(w, r, "/init", http.StatusFound)
|
||||
case seal.StateSealed:
|
||||
http.Redirect(w, r, "/unseal", http.StatusFound)
|
||||
case seal.StateInitializing:
|
||||
http.Redirect(w, r, "/init", http.StatusFound)
|
||||
case seal.StateUnsealed:
|
||||
http.Redirect(w, r, "/dashboard", http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleWebInit(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
if s.seal.State() != seal.StateUninitialized {
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
s.renderTemplate(w, "init.html", nil)
|
||||
case http.MethodPost:
|
||||
r.ParseForm()
|
||||
password := r.FormValue("password")
|
||||
if password == "" {
|
||||
s.renderTemplate(w, "init.html", map[string]interface{}{"Error": "Password is required"})
|
||||
return
|
||||
}
|
||||
params := crypto.Argon2Params{
|
||||
Time: s.cfg.Seal.Argon2Time,
|
||||
Memory: s.cfg.Seal.Argon2Memory,
|
||||
Threads: s.cfg.Seal.Argon2Threads,
|
||||
}
|
||||
if err := s.seal.Initialize(r.Context(), []byte(password), params); err != nil {
|
||||
s.renderTemplate(w, "init.html", map[string]interface{}{"Error": err.Error()})
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, "/dashboard", http.StatusFound)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleWebUnseal(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
state := s.seal.State()
|
||||
if state == seal.StateUninitialized {
|
||||
http.Redirect(w, r, "/init", http.StatusFound)
|
||||
return
|
||||
}
|
||||
if state == seal.StateUnsealed {
|
||||
http.Redirect(w, r, "/dashboard", http.StatusFound)
|
||||
return
|
||||
}
|
||||
s.renderTemplate(w, "unseal.html", nil)
|
||||
case http.MethodPost:
|
||||
r.ParseForm()
|
||||
password := r.FormValue("password")
|
||||
if err := s.seal.Unseal([]byte(password)); err != nil {
|
||||
msg := "Invalid password"
|
||||
if err == seal.ErrRateLimited {
|
||||
msg = "Too many attempts. Please wait 60 seconds."
|
||||
}
|
||||
s.renderTemplate(w, "unseal.html", map[string]interface{}{"Error": msg})
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, "/dashboard", http.StatusFound)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleWebLogin(w http.ResponseWriter, r *http.Request) {
|
||||
if s.seal.State() != seal.StateUnsealed {
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
s.renderTemplate(w, "login.html", nil)
|
||||
case http.MethodPost:
|
||||
r.ParseForm()
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
totpCode := r.FormValue("totp_code")
|
||||
token, _, err := s.auth.Login(username, password, totpCode)
|
||||
if err != nil {
|
||||
s.renderTemplate(w, "login.html", map[string]interface{}{"Error": "Invalid credentials"})
|
||||
return
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "metacrypt_token",
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
http.Redirect(w, r, "/dashboard", http.StatusFound)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleWebDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
info := TokenInfoFromContext(r.Context())
|
||||
mounts := s.engines.ListMounts()
|
||||
s.renderTemplate(w, "dashboard.html", map[string]interface{}{
|
||||
"Username": info.Username,
|
||||
"IsAdmin": info.IsAdmin,
|
||||
"Roles": info.Roles,
|
||||
"Mounts": mounts,
|
||||
"State": s.seal.State().String(),
|
||||
})
|
||||
}
|
||||
|
||||
// requireAuthWeb redirects to login for web pages instead of returning 401.
|
||||
func (s *Server) requireAuthWeb(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if s.seal.State() != seal.StateUnsealed {
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
token := extractToken(r)
|
||||
if token == "" {
|
||||
http.Redirect(w, r, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
info, err := s.auth.ValidateToken(token)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, tokenInfoKey, info)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
|
||||
tmpl, err := template.ParseFiles(
|
||||
filepath.Join("web", "templates", "layout.html"),
|
||||
filepath.Join("web", "templates", name),
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("parse template", "name", name, "error", err)
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.ExecuteTemplate(w, "layout", data); err != nil {
|
||||
s.logger.Error("execute template", "name", name, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, v interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
|
||||
func readJSON(r *http.Request, v interface{}) error {
|
||||
defer r.Body.Close()
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB limit
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(body, v)
|
||||
}
|
||||
77
internal/server/server.go
Normal file
77
internal/server/server.go
Normal file
@@ -0,0 +1,77 @@
|
||||
// Package server implements the HTTP server for Metacrypt.
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/auth"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/config"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/policy"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/seal"
|
||||
)
|
||||
|
||||
// Server is the Metacrypt HTTP server.
|
||||
type Server struct {
|
||||
cfg *config.Config
|
||||
seal *seal.Manager
|
||||
auth *auth.Authenticator
|
||||
policy *policy.Engine
|
||||
engines *engine.Registry
|
||||
httpSrv *http.Server
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a new server.
|
||||
func New(cfg *config.Config, sealMgr *seal.Manager, authenticator *auth.Authenticator,
|
||||
policyEngine *policy.Engine, engineRegistry *engine.Registry, logger *slog.Logger) *Server {
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
seal: sealMgr,
|
||||
auth: authenticator,
|
||||
policy: policyEngine,
|
||||
engines: engineRegistry,
|
||||
logger: logger,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Start starts the HTTPS server.
|
||||
func (s *Server) Start() error {
|
||||
mux := http.NewServeMux()
|
||||
s.registerRoutes(mux)
|
||||
|
||||
tlsCfg := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
},
|
||||
}
|
||||
|
||||
s.httpSrv = &http.Server{
|
||||
Addr: s.cfg.Server.ListenAddr,
|
||||
Handler: s.loggingMiddleware(mux),
|
||||
TLSConfig: tlsCfg,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
s.logger.Info("starting server", "addr", s.cfg.Server.ListenAddr)
|
||||
err := s.httpSrv.ListenAndServeTLS(s.cfg.Server.TLSCert, s.cfg.Server.TLSKey)
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
return fmt.Errorf("server: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the server.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
return s.httpSrv.Shutdown(ctx)
|
||||
}
|
||||
179
internal/server/server_test.go
Normal file
179
internal/server/server_test.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/config"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/db"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/policy"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/seal"
|
||||
|
||||
// auth is used indirectly via the server
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/auth"
|
||||
)
|
||||
|
||||
func setupTestServer(t *testing.T) (*Server, *seal.Manager, *http.ServeMux) {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
database, err := db.Open(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { database.Close() })
|
||||
db.Migrate(database)
|
||||
|
||||
b := barrier.NewAESGCMBarrier(database)
|
||||
sealMgr := seal.NewManager(database, b)
|
||||
sealMgr.CheckInitialized()
|
||||
|
||||
// Auth requires MCIAS client which we can't create in tests easily,
|
||||
// so we pass nil and avoid auth-dependent routes in these tests.
|
||||
authenticator := auth.NewAuthenticator(nil)
|
||||
policyEngine := policy.NewEngine(b)
|
||||
engineRegistry := engine.NewRegistry(b)
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
ListenAddr: ":0",
|
||||
TLSCert: "cert.pem",
|
||||
TLSKey: "key.pem",
|
||||
},
|
||||
Database: config.DatabaseConfig{Path: filepath.Join(dir, "test.db")},
|
||||
MCIAS: config.MCIASConfig{ServerURL: "https://mcias.test"},
|
||||
Seal: config.SealConfig{
|
||||
Argon2Time: 1,
|
||||
Argon2Memory: 64 * 1024,
|
||||
Argon2Threads: 1,
|
||||
},
|
||||
}
|
||||
|
||||
logger := slog.Default()
|
||||
srv := New(cfg, sealMgr, authenticator, policyEngine, engineRegistry, logger)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
srv.registerRoutes(mux)
|
||||
return srv, sealMgr, mux
|
||||
}
|
||||
|
||||
func TestStatusEndpoint(t *testing.T) {
|
||||
_, _, mux := setupTestServer(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/status", nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status code: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["state"] != "uninitialized" {
|
||||
t.Errorf("state: got %q, want %q", resp["state"], "uninitialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitEndpoint(t *testing.T) {
|
||||
_, _, mux := setupTestServer(t)
|
||||
|
||||
body := `{"password":"test-password"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/init", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status code: got %d, want %d. Body: %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["state"] != "unsealed" {
|
||||
t.Errorf("state: got %q, want %q", resp["state"], "unsealed")
|
||||
}
|
||||
|
||||
// Second init should fail.
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/v1/init", strings.NewReader(body))
|
||||
w2 := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w2, req2)
|
||||
if w2.Code != http.StatusConflict {
|
||||
t.Errorf("double init: got %d, want %d", w2.Code, http.StatusConflict)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsealEndpoint(t *testing.T) {
|
||||
_, sealMgr, mux := setupTestServer(t)
|
||||
|
||||
// Initialize first.
|
||||
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
||||
sealMgr.Initialize(context.Background(), []byte("password"), params)
|
||||
sealMgr.Seal()
|
||||
|
||||
// Unseal with wrong password.
|
||||
body := `{"password":"wrong"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/unseal", strings.NewReader(body))
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("wrong password: got %d, want %d", w.Code, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// Unseal with correct password.
|
||||
body = `{"password":"password"}`
|
||||
req = httptest.NewRequest(http.MethodPost, "/v1/unseal", strings.NewReader(body))
|
||||
w = httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("correct password: got %d, want %d. Body: %s", w.Code, http.StatusOK, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMethodNotAllowed(t *testing.T) {
|
||||
_, _, mux := setupTestServer(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/status", nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("POST /v1/status: got %d, want %d", w.Code, http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRootRedirect(t *testing.T) {
|
||||
_, _, mux := setupTestServer(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusFound {
|
||||
t.Errorf("root redirect: got %d, want %d", w.Code, http.StatusFound)
|
||||
}
|
||||
loc := w.Header().Get("Location")
|
||||
if loc != "/init" {
|
||||
t.Errorf("redirect location: got %q, want /init", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenInfoFromContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
if info := TokenInfoFromContext(ctx); info != nil {
|
||||
t.Error("expected nil from empty context")
|
||||
}
|
||||
|
||||
info := &auth.TokenInfo{Username: "test", IsAdmin: true}
|
||||
ctx = context.WithValue(ctx, tokenInfoKey, info)
|
||||
got := TokenInfoFromContext(ctx)
|
||||
if got == nil || got.Username != "test" {
|
||||
t.Error("expected token info from context")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user