Files
mcias/internal/grpcserver/auth.go
Kyle Isom ec7c966ad2 trusted proxy, TOTP replay protection, new tests
- Trusted proxy config option for proxy-aware IP extraction
  used by rate limiting and audit logs; validates proxy IP
  before trusting X-Forwarded-For / X-Real-IP headers
- TOTP replay protection via counter-based validation to
  reject reused codes within the same time step (±30s)
- RateLimit middleware updated to extract client IP from
  proxy headers without IP spoofing risk
- New tests for ClientIP proxy logic (spoofed headers,
  fallback) and extended rate-limit proxy coverage
- HTMX error banner script integrated into web UI base
- .gitignore updated for mciasdb build artifact

Security: resolves CRIT-01 (TOTP replay attack) and
DEF-03 (proxy-unaware rate limiting); gRPC TOTP
enrollment aligned with REST via StorePendingTOTP

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-12 17:44:01 -07:00

306 lines
12 KiB
Go

// authServiceServer implements mciasv1.AuthServiceServer.
// All handlers delegate to the same internal packages as the REST server.
package grpcserver
import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/url"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
	"git.wntrmute.dev/kyle/mcias/internal/auth"
	"git.wntrmute.dev/kyle/mcias/internal/crypto"
	"git.wntrmute.dev/kyle/mcias/internal/model"
	"git.wntrmute.dev/kyle/mcias/internal/token"
)
// authServiceServer is the gRPC-facing implementation of
// mciasv1.AuthServiceServer. It embeds the generated
// UnimplementedAuthServiceServer (the grpc-go convention), so any RPC not
// defined in this file falls back to that generated stub.
type authServiceServer struct {
	mciasv1.UnimplementedAuthServiceServer

	// s is the owning server. Handlers reach the database, configuration,
	// token-issuing private key, TOTP master key and logger through it.
	s *Server
}
// Login authenticates a user and issues a JWT.
// Public RPC — no auth interceptor required.
//
// Security: Identical to the REST handleLogin: always runs Argon2 for unknown
// users to prevent timing-based user enumeration. Generic error returned
// regardless of which step failed.
//
// Flow:
//  1. Look up the account.
//  2. Reject inactive or locked-out accounts.
//  3. Verify the password.
//  4. For accounts with TOTPRequired, verify the TOTP code and reject replays.
//  5. On success, clear the failure counter and issue a tracked token. Its
//     lifetime is AdminExpiry for accounts holding the "admin" role and
//     DefaultExpiry otherwise.
//
// Status codes:
//   - InvalidArgument: missing username or password.
//   - Unauthenticated: any credential failure ("TOTP code required" when a
//     TOTP-enabled account omits the code).
//   - ResourceExhausted: the account is locked out.
//   - Internal: TOTP secret decryption, role lookup, or token issue/tracking
//     failed.
func (a *authServiceServer) Login(ctx context.Context, req *mciasv1.LoginRequest) (*mciasv1.LoginResponse, error) {
	if req.Username == "" || req.Password == "" {
		return nil, status.Error(codes.InvalidArgument, "username and password are required")
	}
	ip := peerIP(ctx)

	acct, err := a.s.db.GetAccountByUsername(req.Username)
	if err != nil {
		// Security: run dummy Argon2 to equalise timing for unknown users.
		_, _ = auth.VerifyPassword("dummy", auth.DummyHash())
		// Security: req.Username is attacker-controlled, so the audit detail
		// is built with encoding/json rather than %q. Go quoting emits escapes
		// such as \x01 and \a that are not valid JSON, so a username with
		// control characters or invalid UTF-8 would corrupt the record.
		// The anonymous struct keeps the original key order. Marshal cannot
		// fail for a struct of two string fields, so its error is ignored.
		detail, _ := json.Marshal(struct {
			Username string `json:"username"`
			Reason   string `json:"reason"`
		}{Username: req.Username, Reason: "unknown_user"})
		a.s.db.WriteAuditEvent(model.EventLoginFail, nil, nil, ip, string(detail)) //nolint:errcheck // audit failure is non-fatal
		return nil, status.Error(codes.Unauthenticated, "invalid credentials")
	}
	if acct.Status != model.AccountStatusActive {
		_, _ = auth.VerifyPassword("dummy", auth.DummyHash())
		a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"account_inactive"}`) //nolint:errcheck
		return nil, status.Error(codes.Unauthenticated, "invalid credentials")
	}

	// Security: check per-account lockout before running Argon2 (F-08).
	// NOTE(review): if the lookup itself fails, the error is only logged and
	// `locked` is used as returned. If IsLockedOut reports false on error,
	// this fails open — confirm that is intended.
	locked, lockErr := a.s.db.IsLockedOut(acct.ID)
	if lockErr != nil {
		a.s.logger.Error("lockout check", "error", lockErr)
	}
	if locked {
		_, _ = auth.VerifyPassword("dummy", auth.DummyHash())
		a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"account_locked"}`) //nolint:errcheck
		return nil, status.Error(codes.ResourceExhausted, "account temporarily locked")
	}

	ok, err := auth.VerifyPassword(req.Password, acct.PasswordHash)
	if err != nil || !ok {
		a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"wrong_password"}`) //nolint:errcheck
		// Best-effort: a failure to record the attempt must not change the response.
		_ = a.s.db.RecordLoginFailure(acct.ID)
		return nil, status.Error(codes.Unauthenticated, "invalid credentials")
	}

	if acct.TOTPRequired {
		if req.TotpCode == "" {
			// Security (DEF-08): password was already verified, so a missing
			// TOTP code means the gRPC client needs to re-prompt the user —
			// it is not a credential failure. Do NOT increment the lockout
			// counter here; doing so would lock out well-behaved clients that
			// call Login in two steps (password first, TOTP second) and would
			// also let an attacker trigger account lockout by omitting the
			// code after a successful password guess.
			a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"totp_missing"}`) //nolint:errcheck
			return nil, status.Error(codes.Unauthenticated, "TOTP code required")
		}
		secret, err := crypto.OpenAESGCM(a.s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
		if err != nil {
			a.s.logger.Error("decrypt TOTP secret", "error", err, "account_id", acct.ID)
			return nil, status.Error(codes.Internal, "internal error")
		}
		valid, counter, err := auth.ValidateTOTP(secret, req.TotpCode)
		if err != nil || !valid {
			a.s.db.WriteAuditEvent(model.EventLoginTOTPFail, &acct.ID, nil, ip, `{"reason":"wrong_totp"}`) //nolint:errcheck
			_ = a.s.db.RecordLoginFailure(acct.ID)
			return nil, status.Error(codes.Unauthenticated, "invalid credentials")
		}
		// Security (CRIT-01): reject replay of a code already used within
		// its ±30-second validity window.
		if err := a.s.db.CheckAndUpdateTOTPCounter(acct.ID, counter); err != nil {
			a.s.db.WriteAuditEvent(model.EventLoginTOTPFail, &acct.ID, nil, ip, `{"reason":"totp_replay"}`) //nolint:errcheck
			_ = a.s.db.RecordLoginFailure(acct.ID)
			return nil, status.Error(codes.Unauthenticated, "invalid credentials")
		}
	}

	// Login succeeded: clear any outstanding failure counter.
	_ = a.s.db.ClearLoginFailures(acct.ID)

	// Accounts holding the "admin" role get AdminExpiry instead of DefaultExpiry.
	expiry := a.s.cfg.DefaultExpiry()
	roles, err := a.s.db.GetRoles(acct.ID)
	if err != nil {
		a.s.logger.Error("get roles", "error", err, "account_id", acct.ID)
		return nil, status.Error(codes.Internal, "internal error")
	}
	for _, r := range roles {
		if r == "admin" {
			expiry = a.s.cfg.AdminExpiry()
			break
		}
	}

	tokenStr, claims, err := token.IssueToken(a.s.privKey, a.s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
	if err != nil {
		a.s.logger.Error("issue token", "error", err)
		return nil, status.Error(codes.Internal, "internal error")
	}
	if err := a.s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
		a.s.logger.Error("track token", "error", err)
		return nil, status.Error(codes.Internal, "internal error")
	}

	a.s.db.WriteAuditEvent(model.EventLoginOK, &acct.ID, nil, ip, "") //nolint:errcheck
	a.s.db.WriteAuditEvent(model.EventTokenIssued, &acct.ID, nil, ip, //nolint:errcheck
		fmt.Sprintf(`{"jti":%q}`, claims.JTI))
	return &mciasv1.LoginResponse{
		Token:     tokenStr,
		ExpiresAt: timestamppb.New(claims.ExpiresAt),
	}, nil
}
// Logout revokes the JWT the caller authenticated with and records a
// token-revoked audit event carrying its JTI.
//
// The JTI comes from the claims stored on ctx (presumably by the auth
// interceptor). A revocation failure is logged and reported as Internal.
// Audit-write failures are ignored.
func (a *authServiceServer) Logout(ctx context.Context, _ *mciasv1.LogoutRequest) (*mciasv1.LogoutResponse, error) {
	jti := claimsFromContext(ctx).JTI

	if revokeErr := a.s.db.RevokeToken(jti, "logout"); revokeErr != nil {
		a.s.logger.Error("revoke token on logout", "error", revokeErr)
		return nil, status.Error(codes.Internal, "internal error")
	}

	detail := fmt.Sprintf(`{"jti":%q,"reason":"logout"}`, jti)
	a.s.db.WriteAuditEvent(model.EventTokenRevoked, nil, nil, peerIP(ctx), detail) //nolint:errcheck // audit failure is non-fatal

	return &mciasv1.LogoutResponse{}, nil
}
// RenewToken exchanges the caller's token for a newly issued one, revoking
// the old token and tracking the new one in a single database call.
//
// The new token carries the account's current roles. Its lifetime is
// AdminExpiry if the account holds the "admin" role and DefaultExpiry
// otherwise, matching Login.
//
// Status codes:
//   - Unauthenticated: the caller's account cannot be loaded or is not active.
//   - Internal: role lookup, token issuance, or the revoke/track step failed.
//     The underlying error is logged; previously it was discarded, which left
//     no diagnostic trail.
func (a *authServiceServer) RenewToken(ctx context.Context, _ *mciasv1.RenewTokenRequest) (*mciasv1.RenewTokenResponse, error) {
	claims := claimsFromContext(ctx)
	acct, err := a.s.db.GetAccountByUUID(claims.Subject)
	if err != nil {
		return nil, status.Error(codes.Unauthenticated, "account not found")
	}
	if acct.Status != model.AccountStatusActive {
		return nil, status.Error(codes.Unauthenticated, "account inactive")
	}

	roles, err := a.s.db.GetRoles(acct.ID)
	if err != nil {
		a.s.logger.Error("get roles", "error", err, "account_id", acct.ID)
		return nil, status.Error(codes.Internal, "internal error")
	}
	expiry := a.s.cfg.DefaultExpiry()
	for _, r := range roles {
		if r == "admin" {
			expiry = a.s.cfg.AdminExpiry()
			break
		}
	}

	newTokenStr, newClaims, err := token.IssueToken(a.s.privKey, a.s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
	if err != nil {
		a.s.logger.Error("issue token", "error", err)
		return nil, status.Error(codes.Internal, "internal error")
	}
	// Security: revoke old + track new in a single transaction (F-03) so that a
	// failure between the two steps cannot leave the user with no valid token.
	if err := a.s.db.RenewToken(claims.JTI, "renewed", newClaims.JTI, acct.ID, newClaims.IssuedAt, newClaims.ExpiresAt); err != nil {
		a.s.logger.Error("renew token", "error", err, "account_id", acct.ID)
		return nil, status.Error(codes.Internal, "internal error")
	}

	a.s.db.WriteAuditEvent(model.EventTokenRenewed, &acct.ID, nil, peerIP(ctx), //nolint:errcheck
		fmt.Sprintf(`{"old_jti":%q,"new_jti":%q}`, claims.JTI, newClaims.JTI))
	return &mciasv1.RenewTokenResponse{
		Token:     newTokenStr,
		ExpiresAt: timestamppb.New(newClaims.ExpiresAt),
	}, nil
}
// EnrollTOTP begins TOTP enrollment for the calling account.
//
// A fresh TOTP secret is generated, sealed with the server master key, and
// stored as pending via StorePendingTOTP. TOTP is not required at login until
// the user confirms a code through ConfirmTOTP. The base32 secret and an
// otpauth:// provisioning URI are returned so the user can configure an
// authenticator app.
//
// Status codes:
//   - Unauthenticated: the caller's account cannot be loaded.
//   - Internal: secret generation, encryption, or storage failed. The cause
//     is logged.
func (a *authServiceServer) EnrollTOTP(ctx context.Context, _ *mciasv1.EnrollTOTPRequest) (*mciasv1.EnrollTOTPResponse, error) {
	claims := claimsFromContext(ctx)
	acct, err := a.s.db.GetAccountByUUID(claims.Subject)
	if err != nil {
		return nil, status.Error(codes.Unauthenticated, "account not found")
	}

	rawSecret, b32Secret, err := auth.GenerateTOTPSecret()
	if err != nil {
		a.s.logger.Error("generate TOTP secret", "error", err)
		return nil, status.Error(codes.Internal, "internal error")
	}
	secretEnc, secretNonce, err := crypto.SealAESGCM(a.s.masterKey, rawSecret)
	if err != nil {
		a.s.logger.Error("encrypt TOTP secret", "error", err, "account_id", acct.ID)
		return nil, status.Error(codes.Internal, "internal error")
	}
	// Security: use StorePendingTOTP (not SetTOTP) so that totp_required is
	// not set to 1 until the user confirms the code via ConfirmTOTP. Calling
	// SetTOTP here would immediately lock the account behind TOTP before the
	// user has had a chance to configure their authenticator app — matching the
	// behaviour of the REST EnrollTOTP handler at internal/server/server.go.
	if err := a.s.db.StorePendingTOTP(acct.ID, secretEnc, secretNonce); err != nil {
		a.s.logger.Error("store pending TOTP", "error", err, "account_id", acct.ID)
		return nil, status.Error(codes.Internal, "internal error")
	}

	// The account name in the URI label is percent-encoded. Otherwise a space,
	// '?', '#' or '/' in the username would corrupt the provisioning URI.
	// PathEscape leaves [A-Za-z0-9-._~] untouched, so URIs for ordinary
	// usernames are unchanged.
	otpURI := fmt.Sprintf("otpauth://totp/MCIAS:%s?secret=%s&issuer=MCIAS", url.PathEscape(acct.Username), b32Secret)
	// Security: secret is shown once here only; the stored form is encrypted.
	return &mciasv1.EnrollTOTPResponse{
		Secret:     b32Secret,
		OtpauthUri: otpURI,
	}, nil
}
// ConfirmTOTP completes TOTP enrollment for the calling account.
//
// The submitted code is checked against the account's stored TOTP secret
// (normally the pending one written by EnrollTOTP). Its time-step counter is
// also recorded so it cannot be replayed. On success the stored secret is
// re-saved with SetTOTP, which sets totp_required and so enforces TOTP at
// login, and a TOTP-enrolled audit event is written.
//
// Status codes:
//   - InvalidArgument: empty code.
//   - Unauthenticated: the account cannot be loaded, or the code is invalid
//     or replayed.
//   - FailedPrecondition: no TOTP secret is stored.
//   - Internal: decryption or storage failed. The cause is logged.
func (a *authServiceServer) ConfirmTOTP(ctx context.Context, req *mciasv1.ConfirmTOTPRequest) (*mciasv1.ConfirmTOTPResponse, error) {
	if req.Code == "" {
		return nil, status.Error(codes.InvalidArgument, "code is required")
	}
	claims := claimsFromContext(ctx)
	acct, err := a.s.db.GetAccountByUUID(claims.Subject)
	if err != nil {
		return nil, status.Error(codes.Unauthenticated, "account not found")
	}
	// len() rather than == nil treats an empty, non-nil ciphertext the same as
	// a missing one. Otherwise it would reach OpenAESGCM and surface as an
	// internal error instead of "enrollment not started".
	if len(acct.TOTPSecretEnc) == 0 {
		return nil, status.Error(codes.FailedPrecondition, "TOTP enrollment not started")
	}

	secret, err := crypto.OpenAESGCM(a.s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
	if err != nil {
		a.s.logger.Error("decrypt TOTP secret", "error", err, "account_id", acct.ID)
		return nil, status.Error(codes.Internal, "internal error")
	}
	valid, counter, err := auth.ValidateTOTP(secret, req.Code)
	if err != nil || !valid {
		return nil, status.Error(codes.Unauthenticated, "invalid TOTP code")
	}
	// Security (CRIT-01): record the counter even during enrollment confirmation
	// so the same code cannot be replayed immediately after confirming.
	if err := a.s.db.CheckAndUpdateTOTPCounter(acct.ID, counter); err != nil {
		return nil, status.Error(codes.Unauthenticated, "invalid TOTP code")
	}
	// SetTOTP with existing enc/nonce sets totp_required=1, confirming enrollment.
	if err := a.s.db.SetTOTP(acct.ID, acct.TOTPSecretEnc, acct.TOTPSecretNonce); err != nil {
		a.s.logger.Error("confirm TOTP", "error", err, "account_id", acct.ID)
		return nil, status.Error(codes.Internal, "internal error")
	}

	a.s.db.WriteAuditEvent(model.EventTOTPEnrolled, &acct.ID, nil, peerIP(ctx), "") //nolint:errcheck
	return &mciasv1.ConfirmTOTPResponse{}, nil
}
// RemoveTOTP clears the TOTP configuration of the account whose UUID is
// req.AccountId. Only administrators may call it.
//
// Status codes:
//   - Whatever requireAdmin returns, for non-admin callers.
//   - InvalidArgument: empty account_id.
//   - NotFound: the account lookup fails.
//   - Internal: clearing TOTP fails.
func (a *authServiceServer) RemoveTOTP(ctx context.Context, req *mciasv1.RemoveTOTPRequest) (*mciasv1.RemoveTOTPResponse, error) {
	if adminErr := a.s.requireAdmin(ctx); adminErr != nil {
		return nil, adminErr
	}

	targetUUID := req.AccountId
	if targetUUID == "" {
		return nil, status.Error(codes.InvalidArgument, "account_id is required")
	}

	target, lookupErr := a.s.db.GetAccountByUUID(targetUUID)
	if lookupErr != nil {
		return nil, status.Error(codes.NotFound, "account not found")
	}
	if clearErr := a.s.db.ClearTOTP(target.ID); clearErr != nil {
		return nil, status.Error(codes.Internal, "internal error")
	}

	// The affected account goes in the third argument and the second is nil.
	// The argument order is presumably (actor, target) — confirm against
	// WriteAuditEvent.
	a.s.db.WriteAuditEvent(model.EventTOTPRemoved, nil, &target.ID, peerIP(ctx), "") //nolint:errcheck // audit failure is non-fatal
	return &mciasv1.RemoveTOTPResponse{}, nil
}
// peerIP returns the host part of the gRPC peer address carried in ctx, for
// use in audit records. This is the directly connected peer; no proxy
// headers are consulted here.
//
// It returns "" when ctx carries no peer, or a peer without an address. If
// the address has no host:port form (for example a Unix socket path), the
// address string is returned unchanged.
func peerIP(ctx context.Context) string {
	p, ok := peer.FromContext(ctx)
	// Guard against a nil *Peer or nil Addr: calling String() on a nil
	// net.Addr interface would panic.
	if !ok || p == nil || p.Addr == nil {
		return ""
	}
	addr := p.Addr.String()
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return addr
	}
	return host
}