Files
mcias/internal/grpcserver/grpcserver.go
Kyle Isom f34e9a69a0 Fix all golangci-lint warnings
- errorlint: use errors.Is for db.ErrNotFound comparisons
  in accountservice.go, credentialservice.go, tokenservice.go
- gofmt/goimports: move mciasv1 alias into internal import group
  in auth.go, credentialservice.go, grpcserver.go, grpcserver_test.go
- gosec G115: add nolint annotation on int32 port conversions
  in mciasgrpcctl/main.go and credentialservice.go (port validated
  as [1,65535] on input; overflow not reachable)
- govet fieldalignment: reorder Server, grpcRateLimiter,
  grpcRateLimitEntry, testEnv structs to reduce GC bitmap size
  (96 -> 80 pointer bytes each)
- ineffassign: remove intermediate grpcSrv = GRPCServer() call
  in cmd/mciassrv/main.go (immediately overwritten by TLS build)
- staticcheck SA9003: replace empty if-body with _ = Serve(lis)
  in grpcserver_test.go
0 golangci-lint issues; 137 tests pass (go test -race ./...)
2026-03-11 15:24:07 -07:00

346 lines
10 KiB
Go

// Package grpcserver provides a gRPC server that exposes the same
// functionality as the REST HTTP server using the same internal packages.
//
// Security design:
// - All RPCs share business logic with the REST server via internal/auth,
// internal/token, internal/db, and internal/crypto packages.
// - Authentication uses the same JWT validation path as the REST middleware:
// alg-first check, signature verification, revocation table lookup.
// - The authorization metadata key is "authorization"; its value must be
// "Bearer <token>" (case-insensitive prefix check).
// - Credential fields (PasswordHash, TOTPSecret*, PGPassword) are never
// included in any RPC response message.
// - No credential material is logged by any interceptor.
// - TLS is required at the listener level (enforced by cmd/mciassrv).
// - Public RPCs (Health, GetPublicKey, ValidateToken, Login) bypass auth;
// every other RPC requires a valid, unrevoked bearer token.
package grpcserver
import (
"context"
"crypto/ed25519"
"log/slog"
"net"
"strings"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
"git.wntrmute.dev/kyle/mcias/internal/config"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/token"
)
// contextKey is an unexported named type for context keys owned by this
// package. Because the type is private, keys of this type cannot collide with
// context keys set by any other package, even when the underlying values match.
type contextKey int

// claimsCtxKey is the context key under which authInterceptor stores the
// validated *token.Claims for the current request.
const claimsCtxKey = contextKey(0)
// claimsFromContext returns the JWT claims that authInterceptor stored in ctx
// under claimsCtxKey. It returns nil when no *token.Claims value is present,
// as is the case for public RPCs that bypass authentication.
func claimsFromContext(ctx context.Context) *token.Claims {
	if claims, ok := ctx.Value(claimsCtxKey).(*token.Claims); ok {
		return claims
	}
	return nil
}
// Server holds the shared state for all gRPC service implementations.
// It is constructed by New and every registered service keeps a pointer to it
// (see buildServer).
type Server struct {
	db        *db.DB             // database handle; used for token revocation lookups in authInterceptor
	cfg       *config.Config     // server config; cfg.Tokens.Issuer is checked on every JWT
	logger    *slog.Logger       // destination for the per-RPC request log
	privKey   ed25519.PrivateKey // NOTE(review): not used in this file; presumably signs issued tokens — confirm in service files
	pubKey    ed25519.PublicKey  // verifies incoming JWTs in authInterceptor
	masterKey []byte             // NOTE(review): not used in this file; purpose defined by the service implementations
}
// New returns a Server wired to the given dependencies; it takes the same set
// of dependencies as the REST server.
//
// Arguments are stored as-is, without copying or validation:
//   - database: used (among other things) for token revocation lookups.
//   - cfg: configuration; cfg.Tokens.Issuer is enforced on every JWT.
//   - priv, pub: Ed25519 key pair; pub verifies incoming tokens.
//   - masterKey: kept for the service implementations.
//   - logger: receives one log line per RPC.
func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed25519.PublicKey, masterKey []byte, logger *slog.Logger) *Server {
	srv := new(Server)
	srv.db = database
	srv.cfg = cfg
	srv.logger = logger
	srv.privKey = priv
	srv.pubKey = pub
	srv.masterKey = masterKey
	return srv
}
// publicMethods lists the fully-qualified gRPC method paths
// ("/<package>.<Service>/<Method>") that authInterceptor passes through
// without requiring a bearer token. Every method not listed here must present
// a valid, unrevoked JWT.
//
// Security: keep this set minimal; each entry is reachable anonymously.
var publicMethods = map[string]bool{
	// AdminService
	"/mcias.v1.AdminService/GetPublicKey": true,
	"/mcias.v1.AdminService/Health":       true,

	// AuthService: presumably how callers obtain a token, so it cannot require one.
	"/mcias.v1.AuthService/Login": true,

	// TokenService
	"/mcias.v1.TokenService/ValidateToken": true,
}
// GRPCServer returns a *grpc.Server with every service registered and the
// interceptor chain installed, but WITHOUT transport credentials, so it
// serves plaintext. gRPC only accepts credentials at construction time, so a
// server built here cannot be upgraded to TLS afterwards. Listeners that need
// TLS must call GRPCServerWithCreds instead.
func (s *Server) GRPCServer() *grpc.Server {
	srv := s.buildServer()
	return srv
}
// GRPCServerWithCreds returns a *grpc.Server configured like GRPCServer but
// with creds installed as its transport credentials (via grpc.Creds).
// Use it whenever starting a TLS gRPC listener. gRPC requires credentials to
// be supplied when the server is constructed, which is why this is a separate
// constructor rather than an option applied afterwards.
func (s *Server) GRPCServerWithCreds(creds credentials.TransportCredentials) *grpc.Server {
	tlsOption := grpc.Creds(creds)
	return s.buildServer(tlsOption)
}
// buildServer constructs the grpc.Server with the package interceptor chain,
// any caller-supplied server options (e.g. grpc.Creds), and all service
// implementations registered.
//
// Interceptor order matters: grpc.ChainUnaryInterceptor treats the first
// interceptor as the outermost wrapper, so on the way in they run in the
// order listed:
//  1. loggingInterceptor is outermost, so every request is logged, including
//     ones rejected by the rate limiter or by authentication.
//  2. rateLimitInterceptor runs BEFORE authentication, so requests carrying
//     missing, forged or invalid tokens are throttled too. Running it after
//     auth would let an unauthenticated client force unlimited JWT
//     verifications and revocation-table lookups.
//  3. authInterceptor validates the bearer token for non-public methods.
func (s *Server) buildServer(extra ...grpc.ServerOption) *grpc.Server {
	opts := make([]grpc.ServerOption, 0, 1+len(extra))
	opts = append(opts, grpc.ChainUnaryInterceptor(
		s.loggingInterceptor,
		s.rateLimitInterceptor,
		s.authInterceptor,
	))
	opts = append(opts, extra...)

	srv := grpc.NewServer(opts...)

	// Register service implementations; each shares this Server's state.
	mciasv1.RegisterAdminServiceServer(srv, &adminServiceServer{s: s})
	mciasv1.RegisterAuthServiceServer(srv, &authServiceServer{s: s})
	mciasv1.RegisterTokenServiceServer(srv, &tokenServiceServer{s: s})
	mciasv1.RegisterAccountServiceServer(srv, &accountServiceServer{s: s})
	mciasv1.RegisterCredentialServiceServer(srv, &credentialServiceServer{s: s})
	return srv
}
// loggingInterceptor emits one structured "grpc request" log line per unary
// RPC, after the handler has returned. The line contains:
//   - method: the full gRPC method path;
//   - peer_ip: the host part of the peer address, or the raw address when it
//     has no port (empty if the context carries no peer);
//   - code: the resulting gRPC status code;
//   - duration_ms: wall-clock time spent in this interceptor and below.
//
// Security: only the fields above are logged; request messages and metadata
// (including the authorization value) are never written out.
func (s *Server) loggingInterceptor(
	ctx context.Context,
	req any,
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (any, error) {
	began := time.Now()

	var clientIP string
	if p, ok := peer.FromContext(ctx); ok {
		raw := p.Addr.String()
		clientIP = raw
		if host, _, splitErr := net.SplitHostPort(raw); splitErr == nil {
			clientIP = host
		}
	}

	resp, err := handler(ctx, req)

	// status.Code(nil) is codes.OK, so successful calls need no special case.
	s.logger.Info("grpc request",
		"method", info.FullMethod,
		"peer_ip", clientIP,
		"code", status.Code(err).String(),
		"duration_ms", time.Since(began).Milliseconds(),
	)
	return resp, err
}
// authInterceptor enforces bearer-token authentication on every unary RPC not
// listed in publicMethods. On success it stores the validated claims in the
// context under claimsCtxKey, where handlers read them via claimsFromContext.
//
// Security: follows the same validation path as the REST RequireAuth middleware:
//  1. Read the bearer token from "authorization" metadata (extractBearerFromMD).
//  2. Validate the JWT with token.ValidateToken (alg-first, then signature,
//     then expiry/issuer) against s.pubKey and the configured issuer.
//  3. Look up the token's JTI record; reject if the lookup fails or the record
//     is revoked. This fails closed: any lookup error is treated as revoked.
//  4. Inject the claims into the context and call the handler.
//
// Rejections use generic messages so callers cannot tell a missing header
// from a malformed one.
func (s *Server) authInterceptor(
	ctx context.Context,
	req any,
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (any, error) {
	if publicMethods[info.FullMethod] {
		return handler(ctx, req)
	}

	bearer, extractErr := extractBearerFromMD(ctx)
	if extractErr != nil {
		return nil, status.Error(codes.Unauthenticated, "missing or invalid authorization")
	}

	claims, validateErr := token.ValidateToken(s.pubKey, bearer, s.cfg.Tokens.Issuer)
	if validateErr != nil {
		return nil, status.Error(codes.Unauthenticated, "invalid or expired token")
	}

	// The revocation table is consulted only after the signature has verified.
	if rec, lookupErr := s.db.GetTokenRecord(claims.JTI); lookupErr != nil || rec.IsRevoked() {
		return nil, status.Error(codes.Unauthenticated, "token has been revoked")
	}

	return handler(context.WithValue(ctx, claimsCtxKey, claims), req)
}
// requireAdmin returns nil when the claims that authInterceptor stored in ctx
// carry the "admin" role. Otherwise, including when no claims are present at
// all, it returns a codes.PermissionDenied status error. Admin-only handlers
// call it at the top of their bodies.
//
// Security: mirrors the REST RequireRole("admin") middleware. Both failure
// cases share one generic message.
func (s *Server) requireAdmin(ctx context.Context) error {
	if c := claimsFromContext(ctx); c != nil && c.HasRole("admin") {
		return nil
	}
	return status.Error(codes.PermissionDenied, "insufficient privileges")
}
// --- Rate limiter ---

// grpcRateLimiter is a per-IP token bucket for gRPC, sharing the same
// algorithm as the REST RateLimit middleware.
//
// Each IP maps to its own grpcRateLimitEntry. Buckets are refilled lazily in
// allow, and the cleanup goroutine periodically evicts entries idle longer
// than ttl.
type grpcRateLimiter struct {
	ips   map[string]*grpcRateLimitEntry // per-IP buckets; map access guarded by mu
	rps   float64                        // refill rate, tokens per second
	burst float64                        // bucket capacity; unseen IPs start with a full bucket
	ttl   time.Duration                  // entries idle longer than this are evicted by cleanup
	mu    sync.Mutex                     // guards the ips map (not the entries' contents)
}
// grpcRateLimitEntry is the token-bucket state for a single client IP.
// All fields are read and written only while mu is held.
type grpcRateLimitEntry struct {
	lastSeen time.Time  // time of the most recent allow call; drives refill and eviction
	tokens   float64    // tokens currently available, capped at the limiter's burst
	mu       sync.Mutex // guards lastSeen and tokens
}
// newGRPCRateLimiter returns a limiter that refills each IP's bucket at rps
// tokens per second, up to a capacity of burst tokens. It also starts the
// background cleanup goroutine, which evicts IPs idle for more than ten minutes.
//
// NOTE(review): the cleanup goroutine has no stop mechanism and lives for the
// remainder of the process.
func newGRPCRateLimiter(rps float64, burst int) *grpcRateLimiter {
	limiter := new(grpcRateLimiter)
	limiter.ips = make(map[string]*grpcRateLimitEntry)
	limiter.rps = rps
	limiter.burst = float64(burst)
	limiter.ttl = 10 * time.Minute

	go limiter.cleanup()
	return limiter
}
// allow reports whether a request from ip may proceed. When it may, one token
// is consumed from that IP's bucket.
//
// Refill is lazy: each call credits rps tokens per second elapsed since the
// entry was last seen, capped at burst. An IP seen for the first time starts
// with a full bucket.
//
// Locking: l.mu is held only for the map lookup/insert. The per-entry mutex
// covers the bucket arithmetic, so different IPs do not contend on one lock.
func (l *grpcRateLimiter) allow(ip string) bool {
	l.mu.Lock()
	bucket, found := l.ips[ip]
	if !found {
		bucket = &grpcRateLimitEntry{lastSeen: time.Now(), tokens: l.burst}
		l.ips[ip] = bucket
	}
	l.mu.Unlock()

	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	now := time.Now()
	refill := now.Sub(bucket.lastSeen).Seconds() * l.rps
	bucket.tokens = minFloat64(l.burst, bucket.tokens+refill)
	bucket.lastSeen = now

	if bucket.tokens < 1 {
		return false
	}
	bucket.tokens--
	return true
}
// cleanup runs forever (it is started as a goroutine by newGRPCRateLimiter).
// Every five minutes it calls evictIdle so the per-IP map does not grow
// without bound.
func (l *grpcRateLimiter) cleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for {
		<-ticker.C
		l.evictIdle()
	}
}

// evictIdle deletes every entry whose lastSeen is older than l.ttl. It holds
// l.mu for the whole sweep and locks each entry while inspecting it.
func (l *grpcRateLimiter) evictIdle() {
	l.mu.Lock()
	defer l.mu.Unlock()

	cutoff := time.Now().Add(-l.ttl)
	for ip, e := range l.ips {
		e.mu.Lock()
		if e.lastSeen.Before(cutoff) {
			delete(l.ips, ip)
		}
		e.mu.Unlock()
	}
}
// defaultRateLimiter is the server-wide rate limiter instance.
// 10 req/s sustained, burst 10 — same parameters as the REST limiter.
// Because it is package-level state, every Server in the process shares it,
// and its cleanup goroutine starts when the package is initialized.
var defaultRateLimiter = newGRPCRateLimiter(10, 10)
// rateLimitInterceptor applies the shared defaultRateLimiter (10 req/s,
// burst 10) to every unary RPC, keyed by client IP. Requests whose bucket is
// empty are rejected with codes.ResourceExhausted before the next handler runs.
//
// The key is the host part of the peer address, or the raw address string if
// it has no port. If no peer is present, or the key comes out empty, the
// request is not rate limited.
func (s *Server) rateLimitInterceptor(
	ctx context.Context,
	req any,
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (any, error) {
	var key string
	if p, ok := peer.FromContext(ctx); ok {
		addr := p.Addr.String()
		key = addr
		if host, _, splitErr := net.SplitHostPort(addr); splitErr == nil {
			key = host
		}
	}

	if key == "" {
		return handler(ctx, req)
	}
	if !defaultRateLimiter.allow(key) {
		return nil, status.Error(codes.ResourceExhausted, "rate limit exceeded")
	}
	return handler(ctx, req)
}
// extractBearerFromMD extracts the Bearer token from incoming gRPC metadata.
//
// gRPC lowercases metadata keys, so the lookup uses "authorization". Only the
// first value is considered. The "Bearer " scheme prefix is matched
// case-insensitively. Surrounding whitespace is trimmed from the token, since
// RFC 6750 allows one or more spaces after the scheme. As a result, "Bearer   tok"
// yields "tok", while "Bearer " or "Bearer    " is rejected as an empty token.
//
// It returns a codes.Unauthenticated status error when:
//   - the context carries no metadata,
//   - the authorization key is missing,
//   - the scheme is not Bearer, or
//   - the token is empty.
//
// Security: The metadata value is never logged.
func extractBearerFromMD(ctx context.Context) (string, error) {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return "", status.Error(codes.Unauthenticated, "no metadata")
	}
	vals := md.Get("authorization")
	if len(vals) == 0 {
		return "", status.Error(codes.Unauthenticated, "missing authorization metadata")
	}
	auth := vals[0]
	const prefix = "bearer "
	// Accept a value exactly equal to the prefix here so that it is reported
	// by the empty-token check below rather than as malformed.
	if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
		return "", status.Error(codes.Unauthenticated, "malformed authorization metadata")
	}
	t := strings.TrimSpace(auth[len(prefix):])
	if t == "" {
		return "", status.Error(codes.Unauthenticated, "empty bearer token")
	}
	return t, nil
}
// minFloat64 returns a if a < b, and b otherwise.
//
// It is intentionally not replaced with the built-in min. For a NaN operand,
// or for -0 versus +0, this function returns b, whereas the built-in
// propagates NaN and treats -0 as the smaller value.
func minFloat64(a, b float64) float64 {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}