Checkpoint: auth, engine, seal, server, grpc updates
Co-authored-by: Junie <junie@jetbrains.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -34,29 +35,35 @@ type cachedClaims struct {
|
||||
// Authenticator provides MCIAS-backed authentication.
|
||||
type Authenticator struct {
|
||||
client *mcias.Client
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
cache map[string]*cachedClaims // keyed by SHA-256(token)
|
||||
}
|
||||
|
||||
// NewAuthenticator creates a new authenticator with the given MCIAS client.
|
||||
func NewAuthenticator(client *mcias.Client) *Authenticator {
|
||||
func NewAuthenticator(client *mcias.Client, logger *slog.Logger) *Authenticator {
|
||||
return &Authenticator{
|
||||
client: client,
|
||||
logger: logger,
|
||||
cache: make(map[string]*cachedClaims),
|
||||
}
|
||||
}
|
||||
|
||||
// Login authenticates a user via MCIAS and returns the token.
|
||||
func (a *Authenticator) Login(username, password, totpCode string) (token string, expiresAt string, err error) {
|
||||
a.logger.Debug("login attempt", "username", username)
|
||||
tok, exp, err := a.client.Login(username, password, totpCode)
|
||||
if err != nil {
|
||||
var authErr *mcias.MciasAuthError
|
||||
if errors.As(err, &authErr) {
|
||||
a.logger.Debug("login failed: invalid credentials", "username", username)
|
||||
return "", "", ErrInvalidCredentials
|
||||
}
|
||||
a.logger.Debug("login failed", "username", username, "error", err)
|
||||
return "", "", err
|
||||
}
|
||||
a.logger.Debug("login succeeded", "username", username)
|
||||
return tok, exp, nil
|
||||
}
|
||||
|
||||
@@ -69,15 +76,19 @@ func (a *Authenticator) ValidateToken(token string) (*TokenInfo, error) {
|
||||
cached, ok := a.cache[key]
|
||||
a.mu.RUnlock()
|
||||
if ok && time.Now().Before(cached.expiresAt) {
|
||||
a.logger.Debug("token validated from cache")
|
||||
return cached.info, nil
|
||||
}
|
||||
|
||||
a.logger.Debug("validating token with MCIAS")
|
||||
// Validate with MCIAS.
|
||||
claims, err := a.client.ValidateToken(token)
|
||||
if err != nil {
|
||||
a.logger.Debug("token validation failed", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
if !claims.Valid {
|
||||
a.logger.Debug("token invalid per MCIAS")
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
@@ -94,6 +105,7 @@ func (a *Authenticator) ValidateToken(token string) (*TokenInfo, error) {
|
||||
expiresAt: time.Now().Add(tokenCacheTTL),
|
||||
}
|
||||
a.mu.Unlock()
|
||||
a.logger.Debug("token validated and cached", "username", info.Username, "is_admin", info.IsAdmin)
|
||||
|
||||
return info, nil
|
||||
}
|
||||
@@ -105,6 +117,7 @@ func (a *Authenticator) Logout(client *mcias.Client) error {
|
||||
|
||||
// ClearCache removes all cached token validations.
|
||||
func (a *Authenticator) ClearCache() {
|
||||
a.logger.Debug("clearing token cache")
|
||||
a.mu.Lock()
|
||||
a.cache = make(map[string]*cachedClaims)
|
||||
a.mu.Unlock()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -33,7 +34,7 @@ func TestHasAdminRole(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewAuthenticator(t *testing.T) {
|
||||
a := NewAuthenticator(nil)
|
||||
a := NewAuthenticator(nil, slog.Default())
|
||||
if a == nil {
|
||||
t.Fatal("NewAuthenticator returned nil")
|
||||
}
|
||||
@@ -43,7 +44,7 @@ func TestNewAuthenticator(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClearCache(t *testing.T) {
|
||||
a := NewAuthenticator(nil)
|
||||
a := NewAuthenticator(nil, slog.Default())
|
||||
a.cache["test"] = &cachedClaims{info: &TokenInfo{Username: "test"}}
|
||||
a.ClearCache()
|
||||
if len(a.cache) != 0 {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -76,18 +77,20 @@ type Mount struct {
|
||||
|
||||
// Registry manages mounted engine instances.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
mounts map[string]*Mount
|
||||
mu sync.RWMutex
|
||||
mounts map[string]*Mount
|
||||
factories map[EngineType]Factory
|
||||
barrier barrier.Barrier
|
||||
barrier barrier.Barrier
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewRegistry creates a new engine registry.
|
||||
func NewRegistry(b barrier.Barrier) *Registry {
|
||||
func NewRegistry(b barrier.Barrier, logger *slog.Logger) *Registry {
|
||||
return &Registry{
|
||||
mounts: make(map[string]*Mount),
|
||||
factories: make(map[EngineType]Factory),
|
||||
barrier: b,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,6 +98,7 @@ func NewRegistry(b barrier.Barrier) *Registry {
|
||||
func (r *Registry) RegisterFactory(t EngineType, f Factory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.logger.Debug("registering engine factory", "type", t)
|
||||
r.factories[t] = f
|
||||
}
|
||||
|
||||
@@ -120,6 +124,7 @@ func (r *Registry) Mount(ctx context.Context, name string, engineType EngineType
|
||||
return fmt.Errorf("%w: %s", ErrUnknownType, engineType)
|
||||
}
|
||||
|
||||
r.logger.Debug("mounting engine", "name", name, "type", engineType)
|
||||
eng := factory()
|
||||
mountPath := fmt.Sprintf("engine/%s/%s/", engineType, name)
|
||||
|
||||
@@ -142,6 +147,7 @@ func (r *Registry) Mount(ctx context.Context, name string, engineType EngineType
|
||||
MountPath: mountPath,
|
||||
Engine: eng,
|
||||
}
|
||||
r.logger.Debug("engine mounted", "name", name, "type", engineType, "mount_path", mountPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -179,6 +185,7 @@ func (r *Registry) Unmount(ctx context.Context, name string) error {
|
||||
return ErrMountNotFound
|
||||
}
|
||||
|
||||
r.logger.Debug("unmounting engine", "name", name, "type", mount.Type)
|
||||
if err := mount.Engine.Seal(); err != nil {
|
||||
return fmt.Errorf("engine: seal %q: %w", name, err)
|
||||
}
|
||||
@@ -189,6 +196,7 @@ func (r *Registry) Unmount(ctx context.Context, name string) error {
|
||||
}
|
||||
|
||||
delete(r.mounts, name)
|
||||
r.logger.Debug("engine unmounted", "name", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -231,6 +239,7 @@ func (r *Registry) UnsealAll(ctx context.Context) error {
|
||||
continue // already loaded
|
||||
}
|
||||
|
||||
r.logger.Debug("discovered pre-migration engine mount", "name", name, "type", engineType)
|
||||
eng := factory()
|
||||
mountPath := fmt.Sprintf("engine/%s/%s/", engineType, name)
|
||||
if err := eng.Unseal(ctx, r.barrier, mountPath); err != nil {
|
||||
@@ -280,6 +289,7 @@ func (r *Registry) loadFromMetadata(ctx context.Context) error {
|
||||
return fmt.Errorf("%w: %s (mount %q)", ErrUnknownType, meta.Type, meta.Name)
|
||||
}
|
||||
|
||||
r.logger.Debug("unsealing engine from metadata", "name", meta.Name, "type", meta.Type)
|
||||
eng := factory()
|
||||
mountPath := fmt.Sprintf("engine/%s/%s/", meta.Type, meta.Name)
|
||||
if err := eng.Unseal(ctx, r.barrier, mountPath); err != nil {
|
||||
@@ -323,6 +333,7 @@ func (r *Registry) HandleRequest(ctx context.Context, mountName string, req *Req
|
||||
return nil, ErrMountNotFound
|
||||
}
|
||||
|
||||
r.logger.Debug("routing engine request", "mount", mountName, "operation", req.Operation, "path", req.Path)
|
||||
return mount.Engine.HandleRequest(ctx, req)
|
||||
}
|
||||
|
||||
@@ -331,6 +342,7 @@ func (r *Registry) SealAll() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.logger.Debug("sealing all engines", "count", len(r.mounts))
|
||||
for name, mount := range r.mounts {
|
||||
if err := mount.Engine.Seal(); err != nil {
|
||||
return fmt.Errorf("engine: seal %q: %w", name, err)
|
||||
|
||||
@@ -2,6 +2,7 @@ package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
||||
@@ -39,7 +40,7 @@ func (m *mockBarrier) Delete(_ context.Context, _ string) error { retu
|
||||
func (m *mockBarrier) List(_ context.Context, _ string) ([]string, error) { return nil, nil }
|
||||
|
||||
func TestRegistryMountUnmount(t *testing.T) {
|
||||
reg := NewRegistry(&mockBarrier{})
|
||||
reg := NewRegistry(&mockBarrier{}, slog.Default())
|
||||
reg.RegisterFactory(EngineTypeTransit, func() Engine {
|
||||
return &mockEngine{engineType: EngineTypeTransit}
|
||||
})
|
||||
@@ -73,14 +74,14 @@ func TestRegistryMountUnmount(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegistryUnmountNotFound(t *testing.T) {
|
||||
reg := NewRegistry(&mockBarrier{})
|
||||
reg := NewRegistry(&mockBarrier{}, slog.Default())
|
||||
if err := reg.Unmount(context.Background(), "nonexistent"); err != ErrMountNotFound {
|
||||
t.Fatalf("expected ErrMountNotFound, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryUnknownType(t *testing.T) {
|
||||
reg := NewRegistry(&mockBarrier{})
|
||||
reg := NewRegistry(&mockBarrier{}, slog.Default())
|
||||
err := reg.Mount(context.Background(), "test", EngineTypeTransit, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown engine type")
|
||||
@@ -88,7 +89,7 @@ func TestRegistryUnknownType(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegistryHandleRequest(t *testing.T) {
|
||||
reg := NewRegistry(&mockBarrier{})
|
||||
reg := NewRegistry(&mockBarrier{}, slog.Default())
|
||||
reg.RegisterFactory(EngineTypeTransit, func() Engine {
|
||||
return &mockEngine{engineType: EngineTypeTransit}
|
||||
})
|
||||
@@ -111,7 +112,7 @@ func TestRegistryHandleRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegistrySealAll(t *testing.T) {
|
||||
reg := NewRegistry(&mockBarrier{})
|
||||
reg := NewRegistry(&mockBarrier{}, slog.Default())
|
||||
reg.RegisterFactory(EngineTypeTransit, func() Engine {
|
||||
return &mockEngine{engineType: EngineTypeTransit}
|
||||
})
|
||||
|
||||
@@ -2,6 +2,7 @@ package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
@@ -25,7 +26,7 @@ func tokenInfoFromContext(ctx context.Context) *auth.TokenInfo {
|
||||
// authInterceptor validates the Bearer token from gRPC metadata and injects
|
||||
// *auth.TokenInfo into the context. The set of method full names that require
|
||||
// auth is passed in; all others pass through without validation.
|
||||
func authInterceptor(authenticator *auth.Authenticator, methods map[string]bool) grpc.UnaryServerInterceptor {
|
||||
func authInterceptor(authenticator *auth.Authenticator, logger *slog.Logger, methods map[string]bool) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if !methods[info.FullMethod] {
|
||||
return handler(ctx, req)
|
||||
@@ -33,14 +34,17 @@ func authInterceptor(authenticator *auth.Authenticator, methods map[string]bool)
|
||||
|
||||
token := extractToken(ctx)
|
||||
if token == "" {
|
||||
logger.Debug("grpc request rejected: missing token", "method", info.FullMethod)
|
||||
return nil, status.Error(codes.Unauthenticated, "missing authorization token")
|
||||
}
|
||||
|
||||
tokenInfo, err := authenticator.ValidateToken(token)
|
||||
if err != nil {
|
||||
logger.Debug("grpc request rejected: invalid token", "method", info.FullMethod, "error", err)
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid token")
|
||||
}
|
||||
|
||||
logger.Debug("grpc request authenticated", "method", info.FullMethod, "username", tokenInfo.Username)
|
||||
ctx = context.WithValue(ctx, tokenInfoKey, tokenInfo)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
@@ -48,27 +52,30 @@ func authInterceptor(authenticator *auth.Authenticator, methods map[string]bool)
|
||||
|
||||
// adminInterceptor requires IsAdmin on the token info for the listed methods.
|
||||
// Must run after authInterceptor.
|
||||
func adminInterceptor(methods map[string]bool) grpc.UnaryServerInterceptor {
|
||||
func adminInterceptor(logger *slog.Logger, methods map[string]bool) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if !methods[info.FullMethod] {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
ti := tokenInfoFromContext(ctx)
|
||||
if ti == nil || !ti.IsAdmin {
|
||||
logger.Debug("grpc request rejected: admin required", "method", info.FullMethod)
|
||||
return nil, status.Error(codes.PermissionDenied, "admin required")
|
||||
}
|
||||
logger.Debug("grpc admin request authorized", "method", info.FullMethod, "username", ti.Username)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// sealInterceptor rejects calls with FailedPrecondition when the vault is
|
||||
// sealed, for the listed methods.
|
||||
func sealInterceptor(sealMgr *seal.Manager, methods map[string]bool) grpc.UnaryServerInterceptor {
|
||||
func sealInterceptor(sealMgr *seal.Manager, logger *slog.Logger, methods map[string]bool) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if !methods[info.FullMethod] {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
if sealMgr.State() != seal.StateUnsealed {
|
||||
logger.Debug("grpc request rejected: vault sealed", "method", info.FullMethod)
|
||||
return nil, status.Error(codes.FailedPrecondition, "vault is sealed")
|
||||
}
|
||||
return handler(ctx, req)
|
||||
|
||||
@@ -66,9 +66,9 @@ func (s *GRPCServer) Start() error {
|
||||
creds := credentials.NewTLS(tlsCfg)
|
||||
|
||||
interceptor := chainInterceptors(
|
||||
sealInterceptor(s.sealMgr, sealRequiredMethods()),
|
||||
authInterceptor(s.auth, authRequiredMethods()),
|
||||
adminInterceptor(adminRequiredMethods()),
|
||||
sealInterceptor(s.sealMgr, s.logger, sealRequiredMethods()),
|
||||
authInterceptor(s.auth, s.logger, authRequiredMethods()),
|
||||
adminInterceptor(s.logger, adminRequiredMethods()),
|
||||
)
|
||||
|
||||
s.srv = grpc.NewServer(
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -51,6 +52,7 @@ var (
|
||||
type Manager struct {
|
||||
db *sql.DB
|
||||
barrier *barrier.AESGCMBarrier
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
state ServiceState
|
||||
@@ -63,10 +65,11 @@ type Manager struct {
|
||||
}
|
||||
|
||||
// NewManager creates a new seal manager.
|
||||
func NewManager(db *sql.DB, b *barrier.AESGCMBarrier) *Manager {
|
||||
func NewManager(db *sql.DB, b *barrier.AESGCMBarrier, logger *slog.Logger) *Manager {
|
||||
return &Manager{
|
||||
db: db,
|
||||
barrier: b,
|
||||
logger: logger,
|
||||
state: StateUninitialized,
|
||||
}
|
||||
}
|
||||
@@ -98,8 +101,10 @@ func (m *Manager) CheckInitialized() error {
|
||||
}
|
||||
if count > 0 {
|
||||
m.state = StateSealed
|
||||
m.logger.Debug("seal config found, state set to sealed")
|
||||
} else {
|
||||
m.state = StateUninitialized
|
||||
m.logger.Debug("no seal config found, state set to uninitialized")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -114,6 +119,7 @@ func (m *Manager) Initialize(ctx context.Context, password []byte, params crypto
|
||||
return ErrAlreadyInitialized
|
||||
}
|
||||
|
||||
m.logger.Debug("initializing seal manager")
|
||||
m.state = StateInitializing
|
||||
defer func() {
|
||||
if m.mek == nil {
|
||||
@@ -162,6 +168,7 @@ func (m *Manager) Initialize(ctx context.Context, password []byte, params crypto
|
||||
|
||||
m.mek = mek
|
||||
m.state = StateUnsealed
|
||||
m.logger.Debug("seal initialization complete, barrier unsealed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -177,9 +184,11 @@ func (m *Manager) Unseal(password []byte) error {
|
||||
return ErrNotSealed
|
||||
}
|
||||
|
||||
m.logger.Debug("unseal attempt")
|
||||
// Rate limiting.
|
||||
now := time.Now()
|
||||
if now.Before(m.lockoutUntil) {
|
||||
m.logger.Debug("unseal attempt rate limited")
|
||||
return ErrRateLimited
|
||||
}
|
||||
if now.Sub(m.lastAttempt) > time.Minute {
|
||||
@@ -190,6 +199,7 @@ func (m *Manager) Unseal(password []byte) error {
|
||||
if m.unsealAttempts > 5 {
|
||||
m.lockoutUntil = now.Add(60 * time.Second)
|
||||
m.unsealAttempts = 0
|
||||
m.logger.Debug("unseal attempts exceeded, locking out")
|
||||
return ErrRateLimited
|
||||
}
|
||||
|
||||
@@ -215,6 +225,7 @@ func (m *Manager) Unseal(password []byte) error {
|
||||
|
||||
mek, err := crypto.Decrypt(kwk, encryptedMEK)
|
||||
if err != nil {
|
||||
m.logger.Debug("unseal failed: invalid password")
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
@@ -227,6 +238,7 @@ func (m *Manager) Unseal(password []byte) error {
|
||||
m.mek = mek
|
||||
m.state = StateUnsealed
|
||||
m.unsealAttempts = 0
|
||||
m.logger.Debug("unseal succeeded, barrier unsealed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -239,11 +251,13 @@ func (m *Manager) Seal() error {
|
||||
return ErrNotSealed
|
||||
}
|
||||
|
||||
m.logger.Debug("sealing service")
|
||||
if m.mek != nil {
|
||||
crypto.Zeroize(m.mek)
|
||||
m.mek = nil
|
||||
}
|
||||
m.barrier.Seal()
|
||||
m.state = StateSealed
|
||||
m.logger.Debug("service sealed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package seal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@@ -21,7 +22,7 @@ func setupSeal(t *testing.T) (*Manager, func()) {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
b := barrier.NewAESGCMBarrier(database)
|
||||
mgr := NewManager(database, b)
|
||||
mgr := NewManager(database, b, slog.Default())
|
||||
return mgr, func() { database.Close() }
|
||||
}
|
||||
|
||||
@@ -101,7 +102,7 @@ func TestSealCheckInitializedPersists(t *testing.T) {
|
||||
database, _ := db.Open(dbPath)
|
||||
db.Migrate(database)
|
||||
b := barrier.NewAESGCMBarrier(database)
|
||||
mgr := NewManager(database, b)
|
||||
mgr := NewManager(database, b, slog.Default())
|
||||
mgr.CheckInitialized()
|
||||
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
||||
mgr.Initialize(context.Background(), []byte("password"), params)
|
||||
@@ -111,7 +112,7 @@ func TestSealCheckInitializedPersists(t *testing.T) {
|
||||
database2, _ := db.Open(dbPath)
|
||||
defer database2.Close()
|
||||
b2 := barrier.NewAESGCMBarrier(database2)
|
||||
mgr2 := NewManager(database2, b2)
|
||||
mgr2 := NewManager(database2, b2, slog.Default())
|
||||
mgr2.CheckInitialized()
|
||||
if mgr2.State() != StateSealed {
|
||||
t.Fatalf("state after reopen: got %v, want Sealed", mgr2.State())
|
||||
|
||||
@@ -42,9 +42,11 @@ func (s *Server) requireUnseal(next http.HandlerFunc) http.HandlerFunc {
|
||||
state := s.seal.State()
|
||||
switch state {
|
||||
case seal.StateUninitialized:
|
||||
s.logger.Debug("request rejected: service uninitialized", "path", r.URL.Path)
|
||||
http.Error(w, `{"error":"not initialized"}`, http.StatusPreconditionFailed)
|
||||
return
|
||||
case seal.StateSealed, seal.StateInitializing:
|
||||
s.logger.Debug("request rejected: service sealed", "path", r.URL.Path)
|
||||
http.Error(w, `{"error":"sealed"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
@@ -57,16 +59,19 @@ func (s *Server) requireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return s.requireUnseal(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := extractToken(r)
|
||||
if token == "" {
|
||||
s.logger.Debug("request rejected: missing token", "path", r.URL.Path)
|
||||
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := s.auth.ValidateToken(token)
|
||||
if err != nil {
|
||||
s.logger.Debug("request rejected: invalid token", "path", r.URL.Path, "error", err)
|
||||
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Debug("request authenticated", "path", r.URL.Path, "username", info.Username)
|
||||
ctx := context.WithValue(r.Context(), tokenInfoKey, info)
|
||||
next(w, r.WithContext(ctx))
|
||||
})
|
||||
@@ -77,9 +82,11 @@ func (s *Server) requireAdmin(next http.HandlerFunc) http.HandlerFunc {
|
||||
return s.requireAuth(func(w http.ResponseWriter, r *http.Request) {
|
||||
info := TokenInfoFromContext(r.Context())
|
||||
if info == nil || !info.IsAdmin {
|
||||
s.logger.Debug("request rejected: admin required", "path", r.URL.Path)
|
||||
http.Error(w, `{"error":"forbidden"}`, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
s.logger.Debug("admin request authorized", "path", r.URL.Path, "username", info.Username)
|
||||
next(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -36,14 +36,14 @@ func setupTestServer(t *testing.T) (*Server, *seal.Manager, chi.Router) {
|
||||
db.Migrate(database)
|
||||
|
||||
b := barrier.NewAESGCMBarrier(database)
|
||||
sealMgr := seal.NewManager(database, b)
|
||||
sealMgr := seal.NewManager(database, b, slog.Default())
|
||||
sealMgr.CheckInitialized()
|
||||
|
||||
// Auth requires MCIAS client which we can't create in tests easily,
|
||||
// so we pass nil and avoid auth-dependent routes in these tests.
|
||||
authenticator := auth.NewAuthenticator(nil)
|
||||
authenticator := auth.NewAuthenticator(nil, slog.Default())
|
||||
policyEngine := policy.NewEngine(b)
|
||||
engineRegistry := engine.NewRegistry(b)
|
||||
engineRegistry := engine.NewRegistry(b, slog.Default())
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
|
||||
Reference in New Issue
Block a user