Files
mcias/internal/grpcserver/grpcserver.go
Kyle Isom d87b4b4042 Add vault seal/unseal lifecycle
- New internal/vault package: thread-safe Vault struct with
  seal/unseal state, key material zeroing, and key derivation
- REST: POST /v1/vault/unseal, POST /v1/vault/seal,
  GET /v1/vault/status; health returns sealed status
- UI: /unseal page with passphrase form, redirect when sealed
- gRPC: sealedInterceptor rejects RPCs when sealed
- Middleware: RequireUnsealed blocks all routes except exempt
  paths; RequireAuth reads pubkey from vault at request time
- Startup: server starts sealed when passphrase unavailable
- All servers share single *vault.Vault by pointer
- CSRF manager derives key lazily from vault

Security: Key material is zeroed on seal. Sealed middleware
runs before auth. Handlers fail closed if vault becomes sealed
mid-request. Unseal endpoint is rate-limited (3/s burst 5).
No CSRF on unseal page (no session to protect; chicken-and-egg
with master key). Passphrase never logged.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-14 23:55:37 -07:00

419 lines
13 KiB
Go

// Package grpcserver provides a gRPC server that exposes the same
// functionality as the REST HTTP server using the same internal packages.
//
// Security design:
// - All RPCs share business logic with the REST server via internal/auth,
// internal/token, internal/db, and internal/crypto packages.
// - Authentication uses the same JWT validation path as the REST middleware:
// alg-first check, signature verification, revocation table lookup.
// - The authorization metadata key is "authorization"; its value must be
// "Bearer <token>" (case-insensitive prefix check).
// - Credential fields (PasswordHash, TOTPSecret*, PGPassword) are never
// included in any RPC response message.
// - No credential material is logged by any interceptor.
// - TLS is required at the listener level (enforced by cmd/mciassrv).
// - Public RPCs (Health, GetPublicKey, ValidateToken, Login) bypass auth.
package grpcserver
import (
"context"
"log/slog"
"net"
"strings"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
"git.wntrmute.dev/kyle/mcias/internal/config"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/token"
"git.wntrmute.dev/kyle/mcias/internal/vault"
)
// contextKey is a private type for the keys this package stores in a
// context.Context. Because the type is unexported, values of other types
// (including plain ints) never compare equal to these keys.
type contextKey int

// claimsCtxKey is the context key under which the auth interceptor stores the
// validated JWT claims.
const claimsCtxKey contextKey = 0
// claimsFromContext returns the JWT claims that the auth interceptor stored in
// ctx under claimsCtxKey. It returns nil when no claims are present, which is
// the case for unauthenticated (public) RPCs.
func claimsFromContext(ctx context.Context) *token.Claims {
	if claims, ok := ctx.Value(claimsCtxKey).(*token.Claims); ok {
		return claims
	}
	return nil
}
// Server holds the shared state for all gRPC service implementations.
// The interceptors defined in this file are methods on Server and read these
// fields on every request.
type Server struct {
// db is the database handle; the auth interceptor uses it to look up token
// records for the revocation check.
db *db.DB
// cfg is the server configuration (token issuer, trusted proxy address).
cfg *config.Config
// logger receives one structured entry per unary RPC.
logger *slog.Logger
// rateLimiter is the per-IP token bucket consulted by rateLimitInterceptor.
rateLimiter *grpcRateLimiter
// vault supplies the sealed/unsealed state and the public key used for
// JWT validation.
vault *vault.Vault
}
// New builds a Server from the given dependencies (the same set the REST
// server takes).
//
// Every Server gets its own per-IP rate limiter (10 req/s, burst 10), so
// separate instances, for example in tests, never share limiter state.
func New(database *db.DB, cfg *config.Config, v *vault.Vault, logger *slog.Logger) *Server {
	srv := &Server{rateLimiter: newGRPCRateLimiter(10, 10)}
	srv.db = database
	srv.cfg = cfg
	srv.vault = v
	srv.logger = logger
	return srv
}
// publicMethods lists the RPCs that the auth interceptor lets through without
// a bearer token. Keys are gRPC full method paths of the form
// /<package>.<Service>/<Method>; a method absent from the map requires auth.
var publicMethods = map[string]bool{
	// AuthService
	"/mcias.v1.AuthService/Login": true,

	// TokenService
	"/mcias.v1.TokenService/ValidateToken": true,

	// AdminService
	"/mcias.v1.AdminService/GetPublicKey": true,
	"/mcias.v1.AdminService/Health":       true,
}
// GRPCServer returns a fully configured *grpc.Server: every service is
// registered and the interceptor chain is installed. No transport credentials
// are attached; use GRPCServerWithCreds when the server must terminate TLS.
func (s *Server) GRPCServer() *grpc.Server {
	srv := s.buildServer()
	return srv
}
// GRPCServerWithCreds returns the same server as GRPCServer but with the given
// transport credentials attached. gRPC requires credentials to be supplied
// when the server is constructed, so this is the entry point for a TLS
// listener.
func (s *Server) GRPCServerWithCreds(creds credentials.TransportCredentials) *grpc.Server {
	tlsOpt := grpc.Creds(creds)
	return s.buildServer(tlsOpt)
}
// buildServer assembles the grpc.Server: it installs the unary interceptor
// chain, appends any caller-supplied options after it, and registers every
// service implementation against the new server.
func (s *Server) buildServer(extra ...grpc.ServerOption) *grpc.Server {
	// The first interceptor listed is the outermost wrapper: logging, then the
	// sealed check, then auth, then rate limiting.
	chain := grpc.ChainUnaryInterceptor(
		s.loggingInterceptor,
		s.sealedInterceptor,
		s.authInterceptor,
		s.rateLimitInterceptor,
	)

	opts := make([]grpc.ServerOption, 0, 1+len(extra))
	opts = append(opts, chain)
	opts = append(opts, extra...)

	srv := grpc.NewServer(opts...)

	// Each service implementation holds a pointer back to the shared Server.
	mciasv1.RegisterAdminServiceServer(srv, &adminServiceServer{s: s})
	mciasv1.RegisterAuthServiceServer(srv, &authServiceServer{s: s})
	mciasv1.RegisterTokenServiceServer(srv, &tokenServiceServer{s: s})
	mciasv1.RegisterAccountServiceServer(srv, &accountServiceServer{s: s})
	mciasv1.RegisterCredentialServiceServer(srv, &credentialServiceServer{s: s})
	mciasv1.RegisterPolicyServiceServer(srv, &policyServiceServer{s: s})

	return srv
}
// loggingInterceptor writes one structured log entry per unary RPC containing
// the full method name, the peer IP, the resulting status code, and the
// elapsed time in milliseconds.
//
// Security: only the fields above are logged; the authorization metadata
// value is never written to the log.
func (s *Server) loggingInterceptor(
	ctx context.Context,
	req interface{},
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (interface{}, error) {
	began := time.Now()

	// Prefer the bare host; fall back to the raw address string when it does
	// not have host:port form.
	var remote string
	if p, ok := peer.FromContext(ctx); ok {
		if host, _, splitErr := net.SplitHostPort(p.Addr.String()); splitErr != nil {
			remote = p.Addr.String()
		} else {
			remote = host
		}
	}

	resp, err := handler(ctx, req)

	rpcCode := codes.OK
	if err != nil {
		rpcCode = status.Code(err)
	}

	s.logger.Info("grpc request",
		slog.String("method", info.FullMethod),
		slog.String("peer_ip", remote),
		slog.String("code", rpcCode.String()),
		slog.Int64("duration_ms", time.Since(began).Milliseconds()),
	)
	return resp, err
}
// sealedInterceptor short-circuits every RPC except Health with
// codes.Unavailable while the vault is sealed. Health is let through so that
// it can report the sealed status.
//
// Security: in the chain built by buildServer this runs directly after
// logging and before auth, so no authenticated or data-serving handler runs
// while key material is unavailable.
func (s *Server) sealedInterceptor(
	ctx context.Context,
	req interface{},
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (interface{}, error) {
	const healthMethod = "/mcias.v1.AdminService/Health"

	if s.vault.IsSealed() && info.FullMethod != healthMethod {
		return nil, status.Error(codes.Unavailable, "vault sealed")
	}
	return handler(ctx, req)
}
// authInterceptor authenticates non-public RPCs. On success the validated
// claims are attached to the context passed to the handler; methods listed in
// publicMethods skip the check entirely.
//
// Security: the steps follow the REST RequireAuth middleware:
//  1. Pull the bearer token from the "authorization" metadata entry.
//  2. Fetch the public key from the vault at request time (fail closed with
//     Unavailable if the vault cannot provide it).
//  3. Validate the JWT against that key and the configured issuer.
//  4. Look up the token record by JTI and reject missing or revoked tokens.
//  5. Store the claims in the context under claimsCtxKey.
func (s *Server) authInterceptor(
	ctx context.Context,
	req interface{},
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (interface{}, error) {
	if publicMethods[info.FullMethod] {
		return handler(ctx, req)
	}

	bearer, err := extractBearerFromMD(ctx)
	if err != nil {
		// Security: one message for every extraction failure, so the caller
		// cannot tell a missing header from a malformed one.
		return nil, status.Error(codes.Unauthenticated, "missing or invalid authorization")
	}

	// Security: the key is read per request rather than cached on the Server.
	verifyKey, err := s.vault.PubKey()
	if err != nil {
		return nil, status.Error(codes.Unavailable, "vault sealed")
	}

	claims, err := token.ValidateToken(verifyKey, bearer, s.cfg.Tokens.Issuer)
	if err != nil {
		return nil, status.Error(codes.Unauthenticated, "invalid or expired token")
	}

	// Security: the revocation lookup happens only after the signature has
	// been verified. A lookup error is treated the same as a revoked token.
	record, err := s.db.GetTokenRecord(claims.JTI)
	if err != nil {
		return nil, status.Error(codes.Unauthenticated, "token has been revoked")
	}
	if record.IsRevoked() {
		return nil, status.Error(codes.Unauthenticated, "token has been revoked")
	}

	return handler(context.WithValue(ctx, claimsCtxKey, claims), req)
}
// requireAdmin returns nil when the claims stored in ctx carry the "admin"
// role, and a PermissionDenied status error otherwise (including when ctx
// holds no claims at all). Admin-only RPC handlers call it after the
// authInterceptor has run.
//
// Security: mirrors the REST RequireRole("admin") middleware check.
func (s *Server) requireAdmin(ctx context.Context) error {
	if claims := claimsFromContext(ctx); claims != nil && claims.HasRole("admin") {
		return nil
	}
	return status.Error(codes.PermissionDenied, "insufficient privileges")
}
// --- Rate limiter ---
// grpcRateLimiter is a per-IP token bucket for gRPC, sharing the same
// algorithm as the REST RateLimit middleware.
//
// Each client IP gets its own bucket (grpcRateLimitEntry). allow refills and
// debits a bucket; cleanup periodically evicts buckets that have been idle
// for longer than ttl.
type grpcRateLimiter struct {
// ips maps a client IP string to its bucket. Guarded by mu.
ips map[string]*grpcRateLimitEntry
// rps is the refill rate in tokens per second.
rps float64
// burst is the bucket capacity and the token count a new bucket starts with.
burst float64
// ttl is the idle time after which cleanup removes a bucket.
ttl time.Duration
// mu guards the ips map. Individual entries carry their own lock.
mu sync.Mutex
}
// grpcRateLimitEntry is the token bucket for a single client IP.
type grpcRateLimitEntry struct {
// lastSeen is the time of the most recent allow call for this IP; it drives
// both the refill calculation and idle eviction.
lastSeen time.Time
// tokens is the number of tokens currently available; one is spent per
// allowed request.
tokens float64
// mu guards lastSeen and tokens.
mu sync.Mutex
}
// newGRPCRateLimiter returns a limiter that refills rps tokens per second up
// to a capacity of burst, with a 10-minute idle TTL for per-IP buckets.
//
// It also starts the background cleanup goroutine. NOTE(review): that
// goroutine has no stop path in this file, so it runs for the life of the
// process.
func newGRPCRateLimiter(rps float64, burst int) *grpcRateLimiter {
	limiter := new(grpcRateLimiter)
	limiter.ips = make(map[string]*grpcRateLimitEntry)
	limiter.rps = rps
	limiter.burst = float64(burst)
	limiter.ttl = 10 * time.Minute

	go limiter.cleanup()
	return limiter
}
// allow reports whether a request from ip may proceed, spending one token
// from that IP's bucket when it may. A bucket is created at full capacity the
// first time an IP is seen.
func (l *grpcRateLimiter) allow(ip string) bool {
	// Look up (or create) the bucket under the map lock, then release it so
	// that other IPs are not blocked while this bucket is updated.
	l.mu.Lock()
	bucket, found := l.ips[ip]
	if !found {
		bucket = &grpcRateLimitEntry{lastSeen: time.Now(), tokens: l.burst}
		l.ips[ip] = bucket
	}
	l.mu.Unlock()

	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	// Refill in proportion to the time since the last call, capped at burst.
	now := time.Now()
	elapsed := now.Sub(bucket.lastSeen).Seconds()
	bucket.tokens = minFloat64(l.burst, bucket.tokens+elapsed*l.rps)
	bucket.lastSeen = now

	if bucket.tokens < 1 {
		return false
	}
	bucket.tokens--
	return true
}
// cleanup wakes every five minutes and removes the buckets whose last request
// is older than the limiter's ttl. It loops on the ticker channel and is
// meant to run in its own goroutine.
func (l *grpcRateLimiter) cleanup() {
	tick := time.NewTicker(5 * time.Minute)
	defer tick.Stop()

	for range tick.C {
		func() {
			l.mu.Lock()
			defer l.mu.Unlock()

			cutoff := time.Now().Add(-l.ttl)
			for ip, bucket := range l.ips {
				// Hold the bucket lock while reading lastSeen and deleting.
				bucket.mu.Lock()
				if bucket.lastSeen.Before(cutoff) {
					delete(l.ips, ip)
				}
				bucket.mu.Unlock()
			}
		}()
	}
}
// rateLimitInterceptor enforces the per-IP token bucket held by the Server
// (created in New with 10 req/s, burst 10, matching the REST rate limiter).
// Requests over the limit fail with codes.ResourceExhausted. When no client
// IP can be determined the request is passed through without being counted.
//
// Security (SEC-06): the client IP comes from grpcClientIP, which honours
// proxy metadata only when the peer is the configured trusted proxy, matching
// the REST middleware behaviour.
func (s *Server) rateLimitInterceptor(
	ctx context.Context,
	req interface{},
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (interface{}, error) {
	var proxy net.IP
	if raw := s.cfg.Server.TrustedProxy; raw != "" {
		proxy = net.ParseIP(raw)
	}

	clientIP := grpcClientIP(ctx, proxy)
	if clientIP == "" || s.rateLimiter.allow(clientIP) {
		return handler(ctx, req)
	}
	return nil, status.Error(codes.ResourceExhausted, "rate limit exceeded")
}
// grpcClientIP returns the client IP for the RPC in ctx. By default that is
// the transport peer's host; when the peer is exactly trustedProxy, the IP is
// taken from proxy-supplied metadata instead.
//
// Security (SEC-06): mirrors middleware.ClientIP for the REST server.
// "x-real-ip" and "x-forwarded-for" are consulted only when the immediate
// peer equals trustedProxy, so an arbitrary client cannot spoof its address.
// "x-real-ip" is preferred; for "x-forwarded-for" only the leftmost entry
// (the original client) is used. gRPC lowercases metadata keys, hence the
// lowercase lookups. A value that does not parse as an IP is ignored, and the
// peer IP is returned when neither entry yields one.
func grpcClientIP(ctx context.Context, trustedProxy net.IP) string {
	var peerIP string
	if p, ok := peer.FromContext(ctx); ok {
		addr := p.Addr.String()
		host, _, err := net.SplitHostPort(addr)
		if err != nil {
			// Not in host:port form; use the address as-is.
			host = addr
		}
		peerIP = host
	}

	// Without a configured proxy or a known peer there is nothing to unwrap.
	if trustedProxy == nil || peerIP == "" {
		return peerIP
	}

	// Proxy metadata is believed only when it comes from the trusted proxy.
	remote := net.ParseIP(peerIP)
	if remote == nil || !remote.Equal(trustedProxy) {
		return peerIP
	}

	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return peerIP
	}

	// x-real-ip carries a single address, so it is tried first.
	if realIP := md.Get("x-real-ip"); len(realIP) > 0 {
		if parsed := net.ParseIP(strings.TrimSpace(realIP[0])); parsed != nil {
			return parsed.String()
		}
	}

	// x-forwarded-for may be a comma-separated chain; take the leftmost hop.
	if forwarded := md.Get("x-forwarded-for"); len(forwarded) > 0 {
		leftmost, _, _ := strings.Cut(forwarded[0], ",")
		if parsed := net.ParseIP(strings.TrimSpace(leftmost)); parsed != nil {
			return parsed.String()
		}
	}

	return peerIP
}
// extractBearerFromMD returns the bearer token carried in the incoming gRPC
// metadata of ctx.
//
// The token is read from the first "authorization" entry (gRPC lowercases
// metadata keys), whose value must be "Bearer <token>"; the "Bearer " prefix
// is matched case-insensitively. Everything after the prefix is returned
// unmodified.
//
// It returns a codes.Unauthenticated status error when ctx carries no
// metadata, when there is no "authorization" entry, or when the value is not
// a bearer prefix followed by at least one character.
//
// Security: The metadata value is never logged.
func extractBearerFromMD(ctx context.Context) (string, error) {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return "", status.Error(codes.Unauthenticated, "no metadata")
	}

	vals := md.Get("authorization")
	if len(vals) == 0 {
		return "", status.Error(codes.Unauthenticated, "missing authorization metadata")
	}

	const prefix = "bearer "
	auth := vals[0]

	// The value must be strictly longer than the prefix, so a successful match
	// always leaves a non-empty token and no separate empty-token check is
	// needed afterwards.
	if len(auth) <= len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
		return "", status.Error(codes.Unauthenticated, "malformed authorization metadata")
	}
	return auth[len(prefix):], nil
}
// minFloat64 returns a when a is strictly less than b, and b otherwise.
func minFloat64(a, b float64) float64 {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}