The package-level defaultRateLimiter drained its token bucket across all test cases, causing later tests to hit ResourceExhausted. Move rateLimiter from a package-level var to a *grpcRateLimiter field on Server; New() allocates a fresh instance (10 req/s, burst 10) per server. Each test's newTestEnv() constructs its own Server, so tests no longer share limiter state. Production behaviour is unchanged: a single Server is constructed at startup and lives for the process lifetime.
346 lines · 10 KiB · Go
// Package grpcserver provides a gRPC server that exposes the same
|
|
// functionality as the REST HTTP server using the same internal packages.
|
|
//
|
|
// Security design:
|
|
// - All RPCs share business logic with the REST server via internal/auth,
|
|
// internal/token, internal/db, and internal/crypto packages.
|
|
// - Authentication uses the same JWT validation path as the REST middleware:
|
|
// alg-first check, signature verification, revocation table lookup.
|
|
// - The authorization metadata key is "authorization"; its value must be
|
|
// "Bearer <token>" (case-insensitive prefix check).
|
|
// - Credential fields (PasswordHash, TOTPSecret*, PGPassword) are never
|
|
// included in any RPC response message.
|
|
// - No credential material is logged by any interceptor.
|
|
// - TLS is required at the listener level (enforced by cmd/mciassrv).
|
|
// - Public RPCs (Health, GetPublicKey, ValidateToken, Login) bypass auth.
|
|
package grpcserver
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"log/slog"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/peer"
|
|
"google.golang.org/grpc/status"
|
|
|
|
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
|
|
"git.wntrmute.dev/kyle/mcias/internal/config"
|
|
"git.wntrmute.dev/kyle/mcias/internal/db"
|
|
"git.wntrmute.dev/kyle/mcias/internal/token"
|
|
)
|
|
|
|
// contextKey is the unexported context key type for this package.
// A private key type prevents collisions with context values set by
// other packages, per the context package documentation.
type contextKey int

const (
	// claimsCtxKey indexes the *token.Claims value stored by authInterceptor.
	claimsCtxKey contextKey = iota
)
|
|
|
|
// claimsFromContext retrieves JWT claims injected by the auth interceptor.
|
|
// Returns nil for unauthenticated (public) RPCs.
|
|
func claimsFromContext(ctx context.Context) *token.Claims {
|
|
c, _ := ctx.Value(claimsCtxKey).(*token.Claims)
|
|
return c
|
|
}
|
|
|
|
// Server holds the shared state for all gRPC service implementations.
// In production a single Server is constructed at startup and lives for
// the process lifetime; tests construct one Server per test environment.
type Server struct {
	// db is the shared persistence layer, also used for the token
	// revocation lookup in authInterceptor.
	db *db.DB
	// cfg supplies runtime configuration (e.g. cfg.Tokens.Issuer for
	// JWT validation).
	cfg *config.Config
	// logger records per-request logs; credential material and the
	// authorization metadata value are never logged.
	logger *slog.Logger
	// rateLimiter is the per-IP token bucket. It is allocated per Server
	// instance (see New) so tests do not share limiter state.
	rateLimiter *grpcRateLimiter
	// privKey is the Ed25519 signing key (presumably used when issuing
	// tokens in the auth service — not visible in this file).
	privKey ed25519.PrivateKey
	// pubKey verifies JWT signatures in authInterceptor.
	pubKey ed25519.PublicKey
	// masterKey is passed through to the crypto layer; its exact
	// semantics are defined in internal/crypto (not visible here).
	masterKey []byte
}
|
|
|
|
// New creates a Server with the given dependencies (same as the REST Server).
|
|
// A fresh per-IP rate limiter (10 req/s, burst 10) is allocated per Server
|
|
// instance so that tests do not share state across test cases.
|
|
func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed25519.PublicKey, masterKey []byte, logger *slog.Logger) *Server {
|
|
return &Server{
|
|
db: database,
|
|
cfg: cfg,
|
|
privKey: priv,
|
|
pubKey: pub,
|
|
masterKey: masterKey,
|
|
logger: logger,
|
|
rateLimiter: newGRPCRateLimiter(10, 10),
|
|
}
|
|
}
|
|
|
|
// publicMethods is the set of fully-qualified method names that bypass auth.
// These match the gRPC full method path: /<package>.<Service>/<Method>.
//
// Login is necessarily public: a client cannot present a bearer token
// before it has authenticated.
var publicMethods = map[string]bool{
	"/mcias.v1.AdminService/Health":        true,
	"/mcias.v1.AdminService/GetPublicKey":  true,
	"/mcias.v1.TokenService/ValidateToken": true,
	"/mcias.v1.AuthService/Login":          true,
}
|
|
|
|
// GRPCServer builds and returns a configured *grpc.Server with all services
|
|
// registered and the interceptor chain installed. The returned server uses no
|
|
// transport credentials; callers are expected to wrap with TLS via GRPCServerWithCreds.
|
|
func (s *Server) GRPCServer() *grpc.Server {
|
|
return s.buildServer()
|
|
}
|
|
|
|
// GRPCServerWithCreds builds a *grpc.Server with TLS transport credentials.
|
|
// This is the method to use when starting a TLS gRPC listener; TLS credentials
|
|
// must be passed at server-construction time per the gRPC idiom.
|
|
func (s *Server) GRPCServerWithCreds(creds credentials.TransportCredentials) *grpc.Server {
|
|
return s.buildServer(grpc.Creds(creds))
|
|
}
|
|
|
|
// buildServer constructs the grpc.Server with optional additional server options.
|
|
func (s *Server) buildServer(extra ...grpc.ServerOption) *grpc.Server {
|
|
opts := append(
|
|
[]grpc.ServerOption{
|
|
grpc.ChainUnaryInterceptor(
|
|
s.loggingInterceptor,
|
|
s.authInterceptor,
|
|
s.rateLimitInterceptor,
|
|
),
|
|
},
|
|
extra...,
|
|
)
|
|
srv := grpc.NewServer(opts...)
|
|
|
|
// Register service implementations.
|
|
mciasv1.RegisterAdminServiceServer(srv, &adminServiceServer{s: s})
|
|
mciasv1.RegisterAuthServiceServer(srv, &authServiceServer{s: s})
|
|
mciasv1.RegisterTokenServiceServer(srv, &tokenServiceServer{s: s})
|
|
mciasv1.RegisterAccountServiceServer(srv, &accountServiceServer{s: s})
|
|
mciasv1.RegisterCredentialServiceServer(srv, &credentialServiceServer{s: s})
|
|
|
|
return srv
|
|
}
|
|
|
|
// loggingInterceptor logs each unary RPC call with method, peer IP, status,
|
|
// and duration. The authorization metadata value is never logged.
|
|
func (s *Server) loggingInterceptor(
|
|
ctx context.Context,
|
|
req interface{},
|
|
info *grpc.UnaryServerInfo,
|
|
handler grpc.UnaryHandler,
|
|
) (interface{}, error) {
|
|
start := time.Now()
|
|
|
|
peerIP := ""
|
|
if p, ok := peer.FromContext(ctx); ok {
|
|
host, _, err := net.SplitHostPort(p.Addr.String())
|
|
if err == nil {
|
|
peerIP = host
|
|
} else {
|
|
peerIP = p.Addr.String()
|
|
}
|
|
}
|
|
|
|
resp, err := handler(ctx, req)
|
|
|
|
code := codes.OK
|
|
if err != nil {
|
|
code = status.Code(err)
|
|
}
|
|
|
|
// Security: authorization metadata is never logged.
|
|
s.logger.Info("grpc request",
|
|
"method", info.FullMethod,
|
|
"peer_ip", peerIP,
|
|
"code", code.String(),
|
|
"duration_ms", time.Since(start).Milliseconds(),
|
|
)
|
|
return resp, err
|
|
}
|
|
|
|
// authInterceptor validates the Bearer JWT from gRPC metadata and injects
|
|
// claims into the context. Public methods bypass this check.
|
|
//
|
|
// Security: Same validation path as the REST RequireAuth middleware:
|
|
// 1. Extract "authorization" metadata value (case-insensitive key lookup).
|
|
// 2. Validate JWT (alg-first, then signature, then expiry/issuer).
|
|
// 3. Check JTI against revocation table.
|
|
// 4. Inject claims into context.
|
|
func (s *Server) authInterceptor(
|
|
ctx context.Context,
|
|
req interface{},
|
|
info *grpc.UnaryServerInfo,
|
|
handler grpc.UnaryHandler,
|
|
) (interface{}, error) {
|
|
if publicMethods[info.FullMethod] {
|
|
return handler(ctx, req)
|
|
}
|
|
|
|
tokenStr, err := extractBearerFromMD(ctx)
|
|
if err != nil {
|
|
// Security: do not reveal whether the header was missing vs. malformed.
|
|
return nil, status.Error(codes.Unauthenticated, "missing or invalid authorization")
|
|
}
|
|
|
|
claims, err := token.ValidateToken(s.pubKey, tokenStr, s.cfg.Tokens.Issuer)
|
|
if err != nil {
|
|
return nil, status.Error(codes.Unauthenticated, "invalid or expired token")
|
|
}
|
|
|
|
// Security: check revocation table after signature validation.
|
|
rec, err := s.db.GetTokenRecord(claims.JTI)
|
|
if err != nil || rec.IsRevoked() {
|
|
return nil, status.Error(codes.Unauthenticated, "token has been revoked")
|
|
}
|
|
|
|
ctx = context.WithValue(ctx, claimsCtxKey, claims)
|
|
return handler(ctx, req)
|
|
}
|
|
|
|
// requireAdmin checks that the claims in context contain the "admin" role.
|
|
// Called by admin-only RPC handlers after the authInterceptor has run.
|
|
//
|
|
// Security: Mirrors the REST RequireRole("admin") middleware check; checked
|
|
// after auth so claims are always populated when this function is reached.
|
|
func (s *Server) requireAdmin(ctx context.Context) error {
|
|
claims := claimsFromContext(ctx)
|
|
if claims == nil {
|
|
return status.Error(codes.PermissionDenied, "insufficient privileges")
|
|
}
|
|
if !claims.HasRole("admin") {
|
|
return status.Error(codes.PermissionDenied, "insufficient privileges")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// --- Rate limiter ---
|
|
|
|
// grpcRateLimiter is a per-IP token bucket for gRPC, sharing the same
// algorithm as the REST RateLimit middleware.
type grpcRateLimiter struct {
	// ips maps peer IP -> bucket entry; guarded by mu.
	ips map[string]*grpcRateLimitEntry
	// rps is the refill rate in tokens per second; burst caps the bucket.
	rps   float64
	burst float64
	// ttl is how long an idle entry survives before cleanup evicts it.
	ttl time.Duration
	// mu guards only the ips map; each entry carries its own mutex so
	// refills for one IP do not block other IPs.
	mu sync.Mutex
}
|
|
|
|
// grpcRateLimitEntry tracks the token-bucket state for a single peer IP.
type grpcRateLimitEntry struct {
	// lastSeen is the time of the most recent refill; cleanup also uses
	// it for TTL-based eviction.
	lastSeen time.Time
	// tokens is the current bucket level, kept within [0, burst].
	tokens float64
	// mu serializes refill/consume operations for this entry.
	mu sync.Mutex
}
|
|
|
|
func newGRPCRateLimiter(rps float64, burst int) *grpcRateLimiter {
|
|
l := &grpcRateLimiter{
|
|
rps: rps,
|
|
burst: float64(burst),
|
|
ttl: 10 * time.Minute,
|
|
ips: make(map[string]*grpcRateLimitEntry),
|
|
}
|
|
go l.cleanup()
|
|
return l
|
|
}
|
|
|
|
func (l *grpcRateLimiter) allow(ip string) bool {
|
|
l.mu.Lock()
|
|
entry, ok := l.ips[ip]
|
|
if !ok {
|
|
entry = &grpcRateLimitEntry{tokens: l.burst, lastSeen: time.Now()}
|
|
l.ips[ip] = entry
|
|
}
|
|
l.mu.Unlock()
|
|
|
|
entry.mu.Lock()
|
|
defer entry.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
elapsed := now.Sub(entry.lastSeen).Seconds()
|
|
entry.tokens = minFloat64(l.burst, entry.tokens+elapsed*l.rps)
|
|
entry.lastSeen = now
|
|
|
|
if entry.tokens < 1 {
|
|
return false
|
|
}
|
|
entry.tokens--
|
|
return true
|
|
}
|
|
|
|
// cleanup periodically evicts per-IP entries that have been idle for longer
// than l.ttl, bounding the map's memory use. It runs forever in its own
// goroutine (started by newGRPCRateLimiter).
//
// NOTE(review): there is no way to stop this goroutine; it and its ticker
// live until the process exits. Acceptable for the single production Server,
// but each limiter created in tests leaks one goroutine — consider adding a
// done channel if limiters ever become short-lived.
func (l *grpcRateLimiter) cleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		l.mu.Lock()
		cutoff := time.Now().Add(-l.ttl)
		for ip, entry := range l.ips {
			// Take the entry lock so we never read lastSeen while
			// allow() is writing it.
			entry.mu.Lock()
			if entry.lastSeen.Before(cutoff) {
				delete(l.ips, ip)
			}
			entry.mu.Unlock()
		}
		l.mu.Unlock()
	}
}
|
|
|
|
// rateLimitInterceptor applies per-IP rate limiting using the same token-bucket
|
|
// parameters as the REST rate limiter (10 req/s, burst 10).
|
|
func (s *Server) rateLimitInterceptor(
|
|
ctx context.Context,
|
|
req interface{},
|
|
info *grpc.UnaryServerInfo,
|
|
handler grpc.UnaryHandler,
|
|
) (interface{}, error) {
|
|
ip := ""
|
|
if p, ok := peer.FromContext(ctx); ok {
|
|
host, _, err := net.SplitHostPort(p.Addr.String())
|
|
if err == nil {
|
|
ip = host
|
|
} else {
|
|
ip = p.Addr.String()
|
|
}
|
|
}
|
|
|
|
if ip != "" && !s.rateLimiter.allow(ip) {
|
|
return nil, status.Error(codes.ResourceExhausted, "rate limit exceeded")
|
|
}
|
|
return handler(ctx, req)
|
|
}
|
|
|
|
// extractBearerFromMD extracts the Bearer token from gRPC metadata.
|
|
// The key lookup is case-insensitive per gRPC metadata convention (all keys
|
|
// are lowercased by the framework; we match on "authorization").
|
|
//
|
|
// Security: The metadata value is never logged.
|
|
func extractBearerFromMD(ctx context.Context) (string, error) {
|
|
md, ok := metadata.FromIncomingContext(ctx)
|
|
if !ok {
|
|
return "", status.Error(codes.Unauthenticated, "no metadata")
|
|
}
|
|
vals := md.Get("authorization")
|
|
if len(vals) == 0 {
|
|
return "", status.Error(codes.Unauthenticated, "missing authorization metadata")
|
|
}
|
|
auth := vals[0]
|
|
const prefix = "bearer "
|
|
if len(auth) <= len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
|
|
return "", status.Error(codes.Unauthenticated, "malformed authorization metadata")
|
|
}
|
|
t := auth[len(prefix):]
|
|
if t == "" {
|
|
return "", status.Error(codes.Unauthenticated, "empty bearer token")
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
// minFloat64 returns the smaller of two float64 values. When a is NaN the
// comparison is false and b is returned, matching the original behaviour.
func minFloat64(a, b float64) float64 {
	switch {
	case a < b:
		return a
	default:
		return b
	}
}
|