- New internal/vault package: thread-safe Vault struct with seal/unseal state, key material zeroing, and key derivation - REST: POST /v1/vault/unseal, POST /v1/vault/seal, GET /v1/vault/status; health returns sealed status - UI: /unseal page with passphrase form, redirect when sealed - gRPC: sealedInterceptor rejects RPCs when sealed - Middleware: RequireUnsealed blocks all routes except exempt paths; RequireAuth reads pubkey from vault at request time - Startup: server starts sealed when passphrase unavailable - All servers share single *vault.Vault by pointer - CSRF manager derives key lazily from vault Security: Key material is zeroed on seal. Sealed middleware runs before auth. Handlers fail closed if vault becomes sealed mid-request. Unseal endpoint is rate-limited (3/s burst 5). No CSRF on unseal page (no session to protect; chicken-and-egg with master key). Passphrase never logged. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
419 lines
13 KiB
Go
419 lines
13 KiB
Go
// Package grpcserver provides a gRPC server that exposes the same
|
|
// functionality as the REST HTTP server using the same internal packages.
|
|
//
|
|
// Security design:
|
|
// - All RPCs share business logic with the REST server via internal/auth,
|
|
// internal/token, internal/db, and internal/crypto packages.
|
|
// - Authentication uses the same JWT validation path as the REST middleware:
|
|
// alg-first check, signature verification, revocation table lookup.
|
|
// - The authorization metadata key is "authorization"; its value must be
|
|
// "Bearer <token>" (case-insensitive prefix check).
|
|
// - Credential fields (PasswordHash, TOTPSecret*, PGPassword) are never
|
|
// included in any RPC response message.
|
|
// - No credential material is logged by any interceptor.
|
|
// - TLS is required at the listener level (enforced by cmd/mciassrv).
|
|
// - Public RPCs (Health, GetPublicKey, ValidateToken, Login) bypass auth.
|
|
package grpcserver
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/peer"
|
|
"google.golang.org/grpc/status"
|
|
|
|
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
|
|
"git.wntrmute.dev/kyle/mcias/internal/config"
|
|
"git.wntrmute.dev/kyle/mcias/internal/db"
|
|
"git.wntrmute.dev/kyle/mcias/internal/token"
|
|
"git.wntrmute.dev/kyle/mcias/internal/vault"
|
|
)
|
|
|
|
// contextKey is the unexported context key type for this package.
// A dedicated unexported type guarantees that keys defined here can
// never collide with context keys from other packages.
type contextKey int

// claimsCtxKey is the key under which the auth interceptor stores
// validated JWT claims in the request context.
const claimsCtxKey contextKey = iota
|
|
|
|
// claimsFromContext retrieves JWT claims injected by the auth interceptor.
|
|
// Returns nil for unauthenticated (public) RPCs.
|
|
func claimsFromContext(ctx context.Context) *token.Claims {
|
|
c, _ := ctx.Value(claimsCtxKey).(*token.Claims)
|
|
return c
|
|
}
|
|
|
|
// Server holds the shared state for all gRPC service implementations.
|
|
type Server struct {
|
|
db *db.DB
|
|
cfg *config.Config
|
|
logger *slog.Logger
|
|
rateLimiter *grpcRateLimiter
|
|
vault *vault.Vault
|
|
}
|
|
|
|
// New creates a Server with the given dependencies (same as the REST Server).
|
|
// A fresh per-IP rate limiter (10 req/s, burst 10) is allocated per Server
|
|
// instance so that tests do not share state across test cases.
|
|
func New(database *db.DB, cfg *config.Config, v *vault.Vault, logger *slog.Logger) *Server {
|
|
return &Server{
|
|
db: database,
|
|
cfg: cfg,
|
|
vault: v,
|
|
logger: logger,
|
|
rateLimiter: newGRPCRateLimiter(10, 10),
|
|
}
|
|
}
|
|
|
|
// publicMethods is the set of fully-qualified method names that bypass auth.
// These match the gRPC full method path: /<package>.<Service>/<Method>.
//
// Login is public because it is the RPC that issues tokens in the first
// place; Health, GetPublicKey, and ValidateToken expose no sensitive data.
// Note that the sealed check in sealedInterceptor still applies to all of
// these except Health.
var publicMethods = map[string]bool{
	"/mcias.v1.AdminService/Health":        true,
	"/mcias.v1.AdminService/GetPublicKey":  true,
	"/mcias.v1.TokenService/ValidateToken": true,
	"/mcias.v1.AuthService/Login":          true,
}
|
|
|
|
// GRPCServer builds and returns a configured *grpc.Server with all services
|
|
// registered and the interceptor chain installed. The returned server uses no
|
|
// transport credentials; callers are expected to wrap with TLS via GRPCServerWithCreds.
|
|
func (s *Server) GRPCServer() *grpc.Server {
|
|
return s.buildServer()
|
|
}
|
|
|
|
// GRPCServerWithCreds builds a *grpc.Server with TLS transport credentials.
|
|
// This is the method to use when starting a TLS gRPC listener; TLS credentials
|
|
// must be passed at server-construction time per the gRPC idiom.
|
|
func (s *Server) GRPCServerWithCreds(creds credentials.TransportCredentials) *grpc.Server {
|
|
return s.buildServer(grpc.Creds(creds))
|
|
}
|
|
|
|
// buildServer constructs the grpc.Server with optional additional server options.
|
|
func (s *Server) buildServer(extra ...grpc.ServerOption) *grpc.Server {
|
|
opts := append(
|
|
[]grpc.ServerOption{
|
|
grpc.ChainUnaryInterceptor(
|
|
s.loggingInterceptor,
|
|
s.sealedInterceptor,
|
|
s.authInterceptor,
|
|
s.rateLimitInterceptor,
|
|
),
|
|
},
|
|
extra...,
|
|
)
|
|
srv := grpc.NewServer(opts...)
|
|
|
|
// Register service implementations.
|
|
mciasv1.RegisterAdminServiceServer(srv, &adminServiceServer{s: s})
|
|
mciasv1.RegisterAuthServiceServer(srv, &authServiceServer{s: s})
|
|
mciasv1.RegisterTokenServiceServer(srv, &tokenServiceServer{s: s})
|
|
mciasv1.RegisterAccountServiceServer(srv, &accountServiceServer{s: s})
|
|
mciasv1.RegisterCredentialServiceServer(srv, &credentialServiceServer{s: s})
|
|
mciasv1.RegisterPolicyServiceServer(srv, &policyServiceServer{s: s})
|
|
|
|
return srv
|
|
}
|
|
|
|
// loggingInterceptor logs each unary RPC call with method, peer IP, status,
|
|
// and duration. The authorization metadata value is never logged.
|
|
func (s *Server) loggingInterceptor(
|
|
ctx context.Context,
|
|
req interface{},
|
|
info *grpc.UnaryServerInfo,
|
|
handler grpc.UnaryHandler,
|
|
) (interface{}, error) {
|
|
start := time.Now()
|
|
|
|
peerIP := ""
|
|
if p, ok := peer.FromContext(ctx); ok {
|
|
host, _, err := net.SplitHostPort(p.Addr.String())
|
|
if err == nil {
|
|
peerIP = host
|
|
} else {
|
|
peerIP = p.Addr.String()
|
|
}
|
|
}
|
|
|
|
resp, err := handler(ctx, req)
|
|
|
|
code := codes.OK
|
|
if err != nil {
|
|
code = status.Code(err)
|
|
}
|
|
|
|
// Security: authorization metadata is never logged.
|
|
s.logger.Info("grpc request",
|
|
"method", info.FullMethod,
|
|
"peer_ip", peerIP,
|
|
"code", code.String(),
|
|
"duration_ms", time.Since(start).Milliseconds(),
|
|
)
|
|
return resp, err
|
|
}
|
|
|
|
// sealedInterceptor rejects all RPCs (except Health) when the vault is sealed.
|
|
//
|
|
// Security: This is the first interceptor in the chain (after logging). It
|
|
// prevents any authenticated or data-serving handler from running while the
|
|
// vault is sealed and key material is unavailable.
|
|
func (s *Server) sealedInterceptor(
|
|
ctx context.Context,
|
|
req interface{},
|
|
info *grpc.UnaryServerInfo,
|
|
handler grpc.UnaryHandler,
|
|
) (interface{}, error) {
|
|
if !s.vault.IsSealed() {
|
|
return handler(ctx, req)
|
|
}
|
|
// Health is always allowed — returns sealed status.
|
|
if info.FullMethod == "/mcias.v1.AdminService/Health" {
|
|
return handler(ctx, req)
|
|
}
|
|
return nil, status.Error(codes.Unavailable, "vault sealed")
|
|
}
|
|
|
|
// authInterceptor validates the Bearer JWT from gRPC metadata and injects
// claims into the context. Public methods bypass this check.
//
// Security: Same validation path as the REST RequireAuth middleware:
//  1. Extract "authorization" metadata value (case-insensitive key lookup).
//  2. Read public key from vault (fail closed if sealed).
//  3. Validate JWT (alg-first, then signature, then expiry/issuer).
//  4. Check JTI against revocation table.
//  5. Inject claims into context.
//
// All failure responses use generic messages so a caller cannot distinguish
// which step rejected the request.
func (s *Server) authInterceptor(
	ctx context.Context,
	req interface{},
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (interface{}, error) {
	// Public RPCs skip auth entirely; see publicMethods.
	if publicMethods[info.FullMethod] {
		return handler(ctx, req)
	}

	tokenStr, err := extractBearerFromMD(ctx)
	if err != nil {
		// Security: do not reveal whether the header was missing vs. malformed.
		return nil, status.Error(codes.Unauthenticated, "missing or invalid authorization")
	}

	// Security: read the public key from vault at request time, so that
	// sealing the vault immediately disables token validation (fail closed).
	pubKey, err := s.vault.PubKey()
	if err != nil {
		return nil, status.Error(codes.Unavailable, "vault sealed")
	}

	claims, err := token.ValidateToken(pubKey, tokenStr, s.cfg.Tokens.Issuer)
	if err != nil {
		return nil, status.Error(codes.Unauthenticated, "invalid or expired token")
	}

	// Security: check revocation table after signature validation.
	// A DB lookup error also fails closed as "revoked".
	rec, err := s.db.GetTokenRecord(claims.JTI)
	if err != nil || rec.IsRevoked() {
		return nil, status.Error(codes.Unauthenticated, "token has been revoked")
	}

	ctx = context.WithValue(ctx, claimsCtxKey, claims)
	return handler(ctx, req)
}
|
|
|
|
// requireAdmin checks that the claims in context contain the "admin" role.
|
|
// Called by admin-only RPC handlers after the authInterceptor has run.
|
|
//
|
|
// Security: Mirrors the REST RequireRole("admin") middleware check; checked
|
|
// after auth so claims are always populated when this function is reached.
|
|
func (s *Server) requireAdmin(ctx context.Context) error {
|
|
claims := claimsFromContext(ctx)
|
|
if claims == nil {
|
|
return status.Error(codes.PermissionDenied, "insufficient privileges")
|
|
}
|
|
if !claims.HasRole("admin") {
|
|
return status.Error(codes.PermissionDenied, "insufficient privileges")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// --- Rate limiter ---

// grpcRateLimiter is a per-IP token bucket for gRPC, sharing the same
// algorithm as the REST RateLimit middleware.
//
// Buckets refill continuously at rps tokens/second up to burst, and
// entries idle for longer than ttl are evicted by a background goroutine.
type grpcRateLimiter struct {
	ips   map[string]*grpcRateLimitEntry // per-IP buckets; guarded by mu
	rps   float64                        // steady-state refill rate (tokens per second)
	burst float64                        // bucket capacity
	ttl   time.Duration                  // idle time before an entry is evicted
	mu    sync.Mutex                     // guards the ips map (not the entries)
}

// grpcRateLimitEntry is one IP's bucket. It carries its own mutex so that
// concurrent requests from different IPs do not contend on the map lock.
type grpcRateLimitEntry struct {
	lastSeen time.Time  // last refill timestamp; also drives TTL eviction
	tokens   float64    // currently available tokens
	mu       sync.Mutex // guards lastSeen and tokens
}

// newGRPCRateLimiter returns a limiter allowing rps requests per second with
// the given burst, and starts the background eviction goroutine.
//
// NOTE(review): the cleanup goroutine runs for the life of the process.
// Limiters are created once per Server, so this is acceptable today, but a
// Stop method would be required if limiters were ever created more often.
func newGRPCRateLimiter(rps float64, burst int) *grpcRateLimiter {
	l := &grpcRateLimiter{
		rps:   rps,
		burst: float64(burst),
		ttl:   10 * time.Minute,
		ips:   make(map[string]*grpcRateLimitEntry),
	}
	go l.cleanup()
	return l
}

// allow reports whether a request from ip may proceed, consuming one token
// when it does. Previously-unseen IPs start with a full bucket.
func (l *grpcRateLimiter) allow(ip string) bool {
	l.mu.Lock()
	entry, ok := l.ips[ip]
	if !ok {
		entry = &grpcRateLimitEntry{tokens: l.burst, lastSeen: time.Now()}
		l.ips[ip] = entry
	}
	l.mu.Unlock()

	entry.mu.Lock()
	defer entry.mu.Unlock()

	// Refill proportionally to the elapsed time since the last request,
	// capped at the burst size. min is the Go 1.21+ builtin.
	now := time.Now()
	elapsed := now.Sub(entry.lastSeen).Seconds()
	entry.tokens = min(l.burst, entry.tokens+elapsed*l.rps)
	entry.lastSeen = now

	if entry.tokens < 1 {
		return false
	}
	entry.tokens--
	return true
}

// cleanup evicts entries idle for longer than ttl, checking every 5 minutes.
// Lock order here is l.mu then entry.mu; allow never holds both locks at
// once, so no deadlock is possible.
func (l *grpcRateLimiter) cleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		l.mu.Lock()
		cutoff := time.Now().Add(-l.ttl)
		for ip, entry := range l.ips {
			entry.mu.Lock()
			if entry.lastSeen.Before(cutoff) {
				delete(l.ips, ip)
			}
			entry.mu.Unlock()
		}
		l.mu.Unlock()
	}
}
|
|
|
|
// rateLimitInterceptor applies per-IP rate limiting using the same token-bucket
|
|
// parameters as the REST rate limiter (10 req/s, burst 10).
|
|
//
|
|
// Security (SEC-06): uses grpcClientIP to extract the real client IP when
|
|
// behind a trusted reverse proxy, matching the REST middleware behaviour.
|
|
func (s *Server) rateLimitInterceptor(
|
|
ctx context.Context,
|
|
req interface{},
|
|
info *grpc.UnaryServerInfo,
|
|
handler grpc.UnaryHandler,
|
|
) (interface{}, error) {
|
|
var trustedProxy net.IP
|
|
if s.cfg.Server.TrustedProxy != "" {
|
|
trustedProxy = net.ParseIP(s.cfg.Server.TrustedProxy)
|
|
}
|
|
|
|
ip := grpcClientIP(ctx, trustedProxy)
|
|
|
|
if ip != "" && !s.rateLimiter.allow(ip) {
|
|
return nil, status.Error(codes.ResourceExhausted, "rate limit exceeded")
|
|
}
|
|
return handler(ctx, req)
|
|
}
|
|
|
|
// grpcClientIP extracts the real client IP from gRPC context, optionally
|
|
// honouring proxy headers when the peer matches the trusted proxy.
|
|
//
|
|
// Security (SEC-06): mirrors middleware.ClientIP for the REST server.
|
|
// X-Forwarded-For and X-Real-IP metadata are only trusted when the immediate
|
|
// peer address matches trustedProxy exactly, preventing IP-spoofing attacks.
|
|
// Only the first (leftmost) value in x-forwarded-for is used (original client).
|
|
// gRPC lowercases all metadata keys, so we look up "x-forwarded-for" and
|
|
// "x-real-ip".
|
|
func grpcClientIP(ctx context.Context, trustedProxy net.IP) string {
|
|
peerIP := ""
|
|
if p, ok := peer.FromContext(ctx); ok {
|
|
host, _, err := net.SplitHostPort(p.Addr.String())
|
|
if err == nil {
|
|
peerIP = host
|
|
} else {
|
|
peerIP = p.Addr.String()
|
|
}
|
|
}
|
|
|
|
if trustedProxy != nil && peerIP != "" {
|
|
remoteIP := net.ParseIP(peerIP)
|
|
if remoteIP != nil && remoteIP.Equal(trustedProxy) {
|
|
// Peer is the trusted proxy — extract real client IP from metadata.
|
|
// Prefer x-real-ip (single value) over x-forwarded-for (may be a
|
|
// comma-separated list when multiple proxies are chained).
|
|
md, ok := metadata.FromIncomingContext(ctx)
|
|
if ok {
|
|
if vals := md.Get("x-real-ip"); len(vals) > 0 {
|
|
if ip := net.ParseIP(strings.TrimSpace(vals[0])); ip != nil {
|
|
return ip.String()
|
|
}
|
|
}
|
|
if vals := md.Get("x-forwarded-for"); len(vals) > 0 {
|
|
// Take the first (leftmost) address — the original client.
|
|
first, _, _ := strings.Cut(vals[0], ",")
|
|
if ip := net.ParseIP(strings.TrimSpace(first)); ip != nil {
|
|
return ip.String()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return peerIP
|
|
}
|
|
|
|
// extractBearerFromMD extracts the Bearer token from gRPC metadata.
|
|
// The key lookup is case-insensitive per gRPC metadata convention (all keys
|
|
// are lowercased by the framework; we match on "authorization").
|
|
//
|
|
// Security: The metadata value is never logged.
|
|
func extractBearerFromMD(ctx context.Context) (string, error) {
|
|
md, ok := metadata.FromIncomingContext(ctx)
|
|
if !ok {
|
|
return "", status.Error(codes.Unauthenticated, "no metadata")
|
|
}
|
|
vals := md.Get("authorization")
|
|
if len(vals) == 0 {
|
|
return "", status.Error(codes.Unauthenticated, "missing authorization metadata")
|
|
}
|
|
auth := vals[0]
|
|
const prefix = "bearer "
|
|
if len(auth) <= len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
|
|
return "", status.Error(codes.Unauthenticated, "malformed authorization metadata")
|
|
}
|
|
t := auth[len(prefix):]
|
|
if t == "" {
|
|
return "", status.Error(codes.Unauthenticated, "empty bearer token")
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
// minFloat64 returns the smaller of two float64 values (b when they are
// equal, matching the original comparison direction).
func minFloat64(a, b float64) float64 {
	switch {
	case a < b:
		return a
	default:
		return b
	}
}
|