Checkpoint: auth, engine, seal, server, grpc updates

Co-authored-by: Junie <junie@jetbrains.com>
This commit is contained in:
2026-03-15 09:54:04 -07:00
parent 33beb33a13
commit 44e5e6e174
21 changed files with 185 additions and 31 deletions

View File

@@ -5,6 +5,7 @@ import (
"crypto/sha256"
"encoding/hex"
"errors"
"log/slog"
"sync"
"time"
@@ -34,29 +35,35 @@ type cachedClaims struct {
// Authenticator provides MCIAS-backed authentication.
type Authenticator struct {
client *mcias.Client
logger *slog.Logger
mu sync.RWMutex
cache map[string]*cachedClaims // keyed by SHA-256(token)
}
// NewAuthenticator creates a new authenticator with the given MCIAS client.
func NewAuthenticator(client *mcias.Client) *Authenticator {
func NewAuthenticator(client *mcias.Client, logger *slog.Logger) *Authenticator {
return &Authenticator{
client: client,
logger: logger,
cache: make(map[string]*cachedClaims),
}
}
// Login authenticates a user via MCIAS and returns the token.
func (a *Authenticator) Login(username, password, totpCode string) (token string, expiresAt string, err error) {
a.logger.Debug("login attempt", "username", username)
tok, exp, err := a.client.Login(username, password, totpCode)
if err != nil {
var authErr *mcias.MciasAuthError
if errors.As(err, &authErr) {
a.logger.Debug("login failed: invalid credentials", "username", username)
return "", "", ErrInvalidCredentials
}
a.logger.Debug("login failed", "username", username, "error", err)
return "", "", err
}
a.logger.Debug("login succeeded", "username", username)
return tok, exp, nil
}
@@ -69,15 +76,19 @@ func (a *Authenticator) ValidateToken(token string) (*TokenInfo, error) {
cached, ok := a.cache[key]
a.mu.RUnlock()
if ok && time.Now().Before(cached.expiresAt) {
a.logger.Debug("token validated from cache")
return cached.info, nil
}
a.logger.Debug("validating token with MCIAS")
// Validate with MCIAS.
claims, err := a.client.ValidateToken(token)
if err != nil {
a.logger.Debug("token validation failed", "error", err)
return nil, err
}
if !claims.Valid {
a.logger.Debug("token invalid per MCIAS")
return nil, ErrInvalidToken
}
@@ -94,6 +105,7 @@ func (a *Authenticator) ValidateToken(token string) (*TokenInfo, error) {
expiresAt: time.Now().Add(tokenCacheTTL),
}
a.mu.Unlock()
a.logger.Debug("token validated and cached", "username", info.Username, "is_admin", info.IsAdmin)
return info, nil
}
@@ -105,6 +117,7 @@ func (a *Authenticator) Logout(client *mcias.Client) error {
// ClearCache removes all cached token validations.
func (a *Authenticator) ClearCache() {
a.logger.Debug("clearing token cache")
a.mu.Lock()
a.cache = make(map[string]*cachedClaims)
a.mu.Unlock()

View File

@@ -1,6 +1,7 @@
package auth
import (
"log/slog"
"testing"
)
@@ -33,7 +34,7 @@ func TestHasAdminRole(t *testing.T) {
}
func TestNewAuthenticator(t *testing.T) {
a := NewAuthenticator(nil)
a := NewAuthenticator(nil, slog.Default())
if a == nil {
t.Fatal("NewAuthenticator returned nil")
}
@@ -43,7 +44,7 @@ func TestNewAuthenticator(t *testing.T) {
}
func TestClearCache(t *testing.T) {
a := NewAuthenticator(nil)
a := NewAuthenticator(nil, slog.Default())
a.cache["test"] = &cachedClaims{info: &TokenInfo{Username: "test"}}
a.ClearCache()
if len(a.cache) != 0 {

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"strings"
"sync"
@@ -76,18 +77,20 @@ type Mount struct {
// Registry manages mounted engine instances.
type Registry struct {
mu sync.RWMutex
mounts map[string]*Mount
mu sync.RWMutex
mounts map[string]*Mount
factories map[EngineType]Factory
barrier barrier.Barrier
barrier barrier.Barrier
logger *slog.Logger
}
// NewRegistry creates a new engine registry.
func NewRegistry(b barrier.Barrier) *Registry {
func NewRegistry(b barrier.Barrier, logger *slog.Logger) *Registry {
return &Registry{
mounts: make(map[string]*Mount),
factories: make(map[EngineType]Factory),
barrier: b,
logger: logger,
}
}
@@ -95,6 +98,7 @@ func NewRegistry(b barrier.Barrier) *Registry {
func (r *Registry) RegisterFactory(t EngineType, f Factory) {
r.mu.Lock()
defer r.mu.Unlock()
r.logger.Debug("registering engine factory", "type", t)
r.factories[t] = f
}
@@ -120,6 +124,7 @@ func (r *Registry) Mount(ctx context.Context, name string, engineType EngineType
return fmt.Errorf("%w: %s", ErrUnknownType, engineType)
}
r.logger.Debug("mounting engine", "name", name, "type", engineType)
eng := factory()
mountPath := fmt.Sprintf("engine/%s/%s/", engineType, name)
@@ -142,6 +147,7 @@ func (r *Registry) Mount(ctx context.Context, name string, engineType EngineType
MountPath: mountPath,
Engine: eng,
}
r.logger.Debug("engine mounted", "name", name, "type", engineType, "mount_path", mountPath)
return nil
}
@@ -179,6 +185,7 @@ func (r *Registry) Unmount(ctx context.Context, name string) error {
return ErrMountNotFound
}
r.logger.Debug("unmounting engine", "name", name, "type", mount.Type)
if err := mount.Engine.Seal(); err != nil {
return fmt.Errorf("engine: seal %q: %w", name, err)
}
@@ -189,6 +196,7 @@ func (r *Registry) Unmount(ctx context.Context, name string) error {
}
delete(r.mounts, name)
r.logger.Debug("engine unmounted", "name", name)
return nil
}
@@ -231,6 +239,7 @@ func (r *Registry) UnsealAll(ctx context.Context) error {
continue // already loaded
}
r.logger.Debug("discovered pre-migration engine mount", "name", name, "type", engineType)
eng := factory()
mountPath := fmt.Sprintf("engine/%s/%s/", engineType, name)
if err := eng.Unseal(ctx, r.barrier, mountPath); err != nil {
@@ -280,6 +289,7 @@ func (r *Registry) loadFromMetadata(ctx context.Context) error {
return fmt.Errorf("%w: %s (mount %q)", ErrUnknownType, meta.Type, meta.Name)
}
r.logger.Debug("unsealing engine from metadata", "name", meta.Name, "type", meta.Type)
eng := factory()
mountPath := fmt.Sprintf("engine/%s/%s/", meta.Type, meta.Name)
if err := eng.Unseal(ctx, r.barrier, mountPath); err != nil {
@@ -323,6 +333,7 @@ func (r *Registry) HandleRequest(ctx context.Context, mountName string, req *Req
return nil, ErrMountNotFound
}
r.logger.Debug("routing engine request", "mount", mountName, "operation", req.Operation, "path", req.Path)
return mount.Engine.HandleRequest(ctx, req)
}
@@ -331,6 +342,7 @@ func (r *Registry) SealAll() error {
r.mu.Lock()
defer r.mu.Unlock()
r.logger.Debug("sealing all engines", "count", len(r.mounts))
for name, mount := range r.mounts {
if err := mount.Engine.Seal(); err != nil {
return fmt.Errorf("engine: seal %q: %w", name, err)

View File

@@ -2,6 +2,7 @@ package engine
import (
"context"
"log/slog"
"testing"
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
@@ -39,7 +40,7 @@ func (m *mockBarrier) Delete(_ context.Context, _ string) error { retu
func (m *mockBarrier) List(_ context.Context, _ string) ([]string, error) { return nil, nil }
func TestRegistryMountUnmount(t *testing.T) {
reg := NewRegistry(&mockBarrier{})
reg := NewRegistry(&mockBarrier{}, slog.Default())
reg.RegisterFactory(EngineTypeTransit, func() Engine {
return &mockEngine{engineType: EngineTypeTransit}
})
@@ -73,14 +74,14 @@ func TestRegistryMountUnmount(t *testing.T) {
}
func TestRegistryUnmountNotFound(t *testing.T) {
reg := NewRegistry(&mockBarrier{})
reg := NewRegistry(&mockBarrier{}, slog.Default())
if err := reg.Unmount(context.Background(), "nonexistent"); err != ErrMountNotFound {
t.Fatalf("expected ErrMountNotFound, got: %v", err)
}
}
func TestRegistryUnknownType(t *testing.T) {
reg := NewRegistry(&mockBarrier{})
reg := NewRegistry(&mockBarrier{}, slog.Default())
err := reg.Mount(context.Background(), "test", EngineTypeTransit, nil)
if err == nil {
t.Fatal("expected error for unknown engine type")
@@ -88,7 +89,7 @@ func TestRegistryUnknownType(t *testing.T) {
}
func TestRegistryHandleRequest(t *testing.T) {
reg := NewRegistry(&mockBarrier{})
reg := NewRegistry(&mockBarrier{}, slog.Default())
reg.RegisterFactory(EngineTypeTransit, func() Engine {
return &mockEngine{engineType: EngineTypeTransit}
})
@@ -111,7 +112,7 @@ func TestRegistryHandleRequest(t *testing.T) {
}
func TestRegistrySealAll(t *testing.T) {
reg := NewRegistry(&mockBarrier{})
reg := NewRegistry(&mockBarrier{}, slog.Default())
reg.RegisterFactory(EngineTypeTransit, func() Engine {
return &mockEngine{engineType: EngineTypeTransit}
})

View File

@@ -2,6 +2,7 @@ package grpcserver
import (
"context"
"log/slog"
"strings"
"google.golang.org/grpc"
@@ -25,7 +26,7 @@ func tokenInfoFromContext(ctx context.Context) *auth.TokenInfo {
// authInterceptor validates the Bearer token from gRPC metadata and injects
// *auth.TokenInfo into the context. The set of method full names that require
// auth is passed in; all others pass through without validation.
func authInterceptor(authenticator *auth.Authenticator, methods map[string]bool) grpc.UnaryServerInterceptor {
func authInterceptor(authenticator *auth.Authenticator, logger *slog.Logger, methods map[string]bool) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if !methods[info.FullMethod] {
return handler(ctx, req)
@@ -33,14 +34,17 @@ func authInterceptor(authenticator *auth.Authenticator, methods map[string]bool)
token := extractToken(ctx)
if token == "" {
logger.Debug("grpc request rejected: missing token", "method", info.FullMethod)
return nil, status.Error(codes.Unauthenticated, "missing authorization token")
}
tokenInfo, err := authenticator.ValidateToken(token)
if err != nil {
logger.Debug("grpc request rejected: invalid token", "method", info.FullMethod, "error", err)
return nil, status.Error(codes.Unauthenticated, "invalid token")
}
logger.Debug("grpc request authenticated", "method", info.FullMethod, "username", tokenInfo.Username)
ctx = context.WithValue(ctx, tokenInfoKey, tokenInfo)
return handler(ctx, req)
}
@@ -48,27 +52,30 @@ func authInterceptor(authenticator *auth.Authenticator, methods map[string]bool)
// adminInterceptor requires IsAdmin on the token info for the listed methods.
// Must run after authInterceptor.
func adminInterceptor(methods map[string]bool) grpc.UnaryServerInterceptor {
func adminInterceptor(logger *slog.Logger, methods map[string]bool) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if !methods[info.FullMethod] {
return handler(ctx, req)
}
ti := tokenInfoFromContext(ctx)
if ti == nil || !ti.IsAdmin {
logger.Debug("grpc request rejected: admin required", "method", info.FullMethod)
return nil, status.Error(codes.PermissionDenied, "admin required")
}
logger.Debug("grpc admin request authorized", "method", info.FullMethod, "username", ti.Username)
return handler(ctx, req)
}
}
// sealInterceptor rejects calls with FailedPrecondition when the vault is
// sealed, for the listed methods.
func sealInterceptor(sealMgr *seal.Manager, methods map[string]bool) grpc.UnaryServerInterceptor {
func sealInterceptor(sealMgr *seal.Manager, logger *slog.Logger, methods map[string]bool) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if !methods[info.FullMethod] {
return handler(ctx, req)
}
if sealMgr.State() != seal.StateUnsealed {
logger.Debug("grpc request rejected: vault sealed", "method", info.FullMethod)
return nil, status.Error(codes.FailedPrecondition, "vault is sealed")
}
return handler(ctx, req)

View File

@@ -66,9 +66,9 @@ func (s *GRPCServer) Start() error {
creds := credentials.NewTLS(tlsCfg)
interceptor := chainInterceptors(
sealInterceptor(s.sealMgr, sealRequiredMethods()),
authInterceptor(s.auth, authRequiredMethods()),
adminInterceptor(adminRequiredMethods()),
sealInterceptor(s.sealMgr, s.logger, sealRequiredMethods()),
authInterceptor(s.auth, s.logger, authRequiredMethods()),
adminInterceptor(s.logger, adminRequiredMethods()),
)
s.srv = grpc.NewServer(

View File

@@ -6,6 +6,7 @@ import (
"database/sql"
"errors"
"fmt"
"log/slog"
"sync"
"time"
@@ -51,6 +52,7 @@ var (
type Manager struct {
db *sql.DB
barrier *barrier.AESGCMBarrier
logger *slog.Logger
mu sync.RWMutex
state ServiceState
@@ -63,10 +65,11 @@ type Manager struct {
}
// NewManager creates a new seal manager.
func NewManager(db *sql.DB, b *barrier.AESGCMBarrier) *Manager {
func NewManager(db *sql.DB, b *barrier.AESGCMBarrier, logger *slog.Logger) *Manager {
return &Manager{
db: db,
barrier: b,
logger: logger,
state: StateUninitialized,
}
}
@@ -98,8 +101,10 @@ func (m *Manager) CheckInitialized() error {
}
if count > 0 {
m.state = StateSealed
m.logger.Debug("seal config found, state set to sealed")
} else {
m.state = StateUninitialized
m.logger.Debug("no seal config found, state set to uninitialized")
}
return nil
}
@@ -114,6 +119,7 @@ func (m *Manager) Initialize(ctx context.Context, password []byte, params crypto
return ErrAlreadyInitialized
}
m.logger.Debug("initializing seal manager")
m.state = StateInitializing
defer func() {
if m.mek == nil {
@@ -162,6 +168,7 @@ func (m *Manager) Initialize(ctx context.Context, password []byte, params crypto
m.mek = mek
m.state = StateUnsealed
m.logger.Debug("seal initialization complete, barrier unsealed")
return nil
}
@@ -177,9 +184,11 @@ func (m *Manager) Unseal(password []byte) error {
return ErrNotSealed
}
m.logger.Debug("unseal attempt")
// Rate limiting.
now := time.Now()
if now.Before(m.lockoutUntil) {
m.logger.Debug("unseal attempt rate limited")
return ErrRateLimited
}
if now.Sub(m.lastAttempt) > time.Minute {
@@ -190,6 +199,7 @@ func (m *Manager) Unseal(password []byte) error {
if m.unsealAttempts > 5 {
m.lockoutUntil = now.Add(60 * time.Second)
m.unsealAttempts = 0
m.logger.Debug("unseal attempts exceeded, locking out")
return ErrRateLimited
}
@@ -215,6 +225,7 @@ func (m *Manager) Unseal(password []byte) error {
mek, err := crypto.Decrypt(kwk, encryptedMEK)
if err != nil {
m.logger.Debug("unseal failed: invalid password")
return ErrInvalidPassword
}
@@ -227,6 +238,7 @@ func (m *Manager) Unseal(password []byte) error {
m.mek = mek
m.state = StateUnsealed
m.unsealAttempts = 0
m.logger.Debug("unseal succeeded, barrier unsealed")
return nil
}
@@ -239,11 +251,13 @@ func (m *Manager) Seal() error {
return ErrNotSealed
}
m.logger.Debug("sealing service")
if m.mek != nil {
crypto.Zeroize(m.mek)
m.mek = nil
}
m.barrier.Seal()
m.state = StateSealed
m.logger.Debug("service sealed")
return nil
}

View File

@@ -2,6 +2,7 @@ package seal
import (
"context"
"log/slog"
"path/filepath"
"testing"
@@ -21,7 +22,7 @@ func setupSeal(t *testing.T) (*Manager, func()) {
t.Fatalf("migrate: %v", err)
}
b := barrier.NewAESGCMBarrier(database)
mgr := NewManager(database, b)
mgr := NewManager(database, b, slog.Default())
return mgr, func() { database.Close() }
}
@@ -101,7 +102,7 @@ func TestSealCheckInitializedPersists(t *testing.T) {
database, _ := db.Open(dbPath)
db.Migrate(database)
b := barrier.NewAESGCMBarrier(database)
mgr := NewManager(database, b)
mgr := NewManager(database, b, slog.Default())
mgr.CheckInitialized()
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
mgr.Initialize(context.Background(), []byte("password"), params)
@@ -111,7 +112,7 @@ func TestSealCheckInitializedPersists(t *testing.T) {
database2, _ := db.Open(dbPath)
defer database2.Close()
b2 := barrier.NewAESGCMBarrier(database2)
mgr2 := NewManager(database2, b2)
mgr2 := NewManager(database2, b2, slog.Default())
mgr2.CheckInitialized()
if mgr2.State() != StateSealed {
t.Fatalf("state after reopen: got %v, want Sealed", mgr2.State())

View File

@@ -42,9 +42,11 @@ func (s *Server) requireUnseal(next http.HandlerFunc) http.HandlerFunc {
state := s.seal.State()
switch state {
case seal.StateUninitialized:
s.logger.Debug("request rejected: service uninitialized", "path", r.URL.Path)
http.Error(w, `{"error":"not initialized"}`, http.StatusPreconditionFailed)
return
case seal.StateSealed, seal.StateInitializing:
s.logger.Debug("request rejected: service sealed", "path", r.URL.Path)
http.Error(w, `{"error":"sealed"}`, http.StatusServiceUnavailable)
return
}
@@ -57,16 +59,19 @@ func (s *Server) requireAuth(next http.HandlerFunc) http.HandlerFunc {
return s.requireUnseal(func(w http.ResponseWriter, r *http.Request) {
token := extractToken(r)
if token == "" {
s.logger.Debug("request rejected: missing token", "path", r.URL.Path)
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
return
}
info, err := s.auth.ValidateToken(token)
if err != nil {
s.logger.Debug("request rejected: invalid token", "path", r.URL.Path, "error", err)
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
return
}
s.logger.Debug("request authenticated", "path", r.URL.Path, "username", info.Username)
ctx := context.WithValue(r.Context(), tokenInfoKey, info)
next(w, r.WithContext(ctx))
})
@@ -77,9 +82,11 @@ func (s *Server) requireAdmin(next http.HandlerFunc) http.HandlerFunc {
return s.requireAuth(func(w http.ResponseWriter, r *http.Request) {
info := TokenInfoFromContext(r.Context())
if info == nil || !info.IsAdmin {
s.logger.Debug("request rejected: admin required", "path", r.URL.Path)
http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
return
}
s.logger.Debug("admin request authorized", "path", r.URL.Path, "username", info.Username)
next(w, r)
})
}

View File

@@ -36,14 +36,14 @@ func setupTestServer(t *testing.T) (*Server, *seal.Manager, chi.Router) {
db.Migrate(database)
b := barrier.NewAESGCMBarrier(database)
sealMgr := seal.NewManager(database, b)
sealMgr := seal.NewManager(database, b, slog.Default())
sealMgr.CheckInitialized()
// Auth requires MCIAS client which we can't create in tests easily,
// so we pass nil and avoid auth-dependent routes in these tests.
authenticator := auth.NewAuthenticator(nil)
authenticator := auth.NewAuthenticator(nil, slog.Default())
policyEngine := policy.NewEngine(b)
engineRegistry := engine.NewRegistry(b)
engineRegistry := engine.NewRegistry(b, slog.Default())
cfg := &config.Config{
Server: config.ServerConfig{