Files
mcias/internal/grpcserver/auth.go
Kyle Isom bf9002a31c Fix F-03: make token renewal atomic
- db/accounts.go: add RenewToken(oldJTI, reason, newJTI,
  accountID, issuedAt, expiresAt) which wraps RevokeToken +
  TrackToken in a single BEGIN/COMMIT transaction; if either
  step fails the whole tx rolls back, so the user is never
  left with neither old nor new token valid
- server.go (handleRenewToken): replace separate RevokeToken +
  TrackToken calls with single RenewToken call; failure now
  returns 500 instead of silently losing revocation
- grpcserver/auth.go (RenewToken): same replacement
- db/db_test.go: TestRenewTokenAtomic verifies old token is
  revoked with correct reason, new token is tracked and not
  revoked, and a second renewal on the already-revoked old
  token returns an error
- AUDIT.md: mark F-03 as fixed
Security: without atomicity a crash/error between revoke and
  track could leave the old token active alongside the new one
  (two live tokens) or revoke the old token without tracking
  the new one (user locked out). The transaction ensures
  exactly one of the two tokens is valid at all times.
2026-03-11 20:24:32 -07:00

266 lines
9.5 KiB
Go

// authServiceServer implements mciasv1.AuthServiceServer.
// All handlers delegate to the same internal packages as the REST server.
package grpcserver
import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/url"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
	"git.wntrmute.dev/kyle/mcias/internal/auth"
	"git.wntrmute.dev/kyle/mcias/internal/crypto"
	"git.wntrmute.dev/kyle/mcias/internal/model"
	"git.wntrmute.dev/kyle/mcias/internal/token"
)
// authServiceServer serves the AuthService gRPC RPCs (Login, Logout,
// RenewToken and the TOTP enrollment/removal calls).
//
// It embeds the generated UnimplementedAuthServiceServer (the usual grpc-go
// pattern for forward compatibility with RPCs not implemented here). It also
// holds a back-reference to the owning Server, whose db, cfg, privKey,
// masterKey and logger the handlers use.
type authServiceServer struct {
mciasv1.UnimplementedAuthServiceServer
// s is the owning server; all handler state is reached through it.
s *Server
}
// Login authenticates a user and issues a JWT.
// Public RPC — no auth interceptor required.
//
// Security: Identical to the REST handleLogin: always runs Argon2 for unknown
// users to prevent timing-based user enumeration. Generic error returned
// regardless of which step failed.
// Login authenticates a user by username and password, plus a TOTP code when
// the account requires one. On success it issues a signed JWT, records it via
// TrackToken, and returns it with its expiry.
// Public RPC — no auth interceptor required.
//
// Errors:
//   - InvalidArgument if username or password is empty.
//   - Unauthenticated ("invalid credentials") for unknown users, inactive
//     accounts, wrong passwords and wrong TOTP codes; Unauthenticated
//     ("TOTP code required") when the account requires TOTP and none was sent.
//   - Internal if the TOTP secret cannot be decrypted, roles cannot be loaded,
//     or the token cannot be issued or tracked.
//
// Security: Identical to the REST handleLogin: always runs Argon2 for unknown
// users to prevent timing-based user enumeration. Generic error returned
// regardless of which step failed.
func (a *authServiceServer) Login(ctx context.Context, req *mciasv1.LoginRequest) (*mciasv1.LoginResponse, error) {
	if req.Username == "" || req.Password == "" {
		return nil, status.Error(codes.InvalidArgument, "username and password are required")
	}
	ip := peerIP(ctx)
	acct, err := a.s.db.GetAccountByUsername(req.Username)
	if err != nil {
		// Security: run dummy Argon2 to equalise timing for unknown users.
		_, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g")
		// Security: the username is attacker-controlled and unvalidated here
		// (no account was loaded for it). Go's %q verb emits Go escapes such as
		// \a, \v and \xNN that are not valid JSON, so encode with encoding/json
		// instead. Marshalling a struct of plain strings cannot fail; invalid
		// UTF-8 is replaced with U+FFFD.
		detail, _ := json.Marshal(struct {
			Username string `json:"username"`
			Reason   string `json:"reason"`
		}{Username: req.Username, Reason: "unknown_user"})
		a.s.db.WriteAuditEvent(model.EventLoginFail, nil, nil, ip, string(detail)) //nolint:errcheck // audit failure is non-fatal
		return nil, status.Error(codes.Unauthenticated, "invalid credentials")
	}
	if acct.Status != model.AccountStatusActive {
		// Security: still pay the Argon2 cost so an inactive account is not
		// distinguishable from an active one by response time.
		_, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g")
		a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"account_inactive"}`) //nolint:errcheck
		return nil, status.Error(codes.Unauthenticated, "invalid credentials")
	}
	ok, err := auth.VerifyPassword(req.Password, acct.PasswordHash)
	if err != nil || !ok {
		a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"wrong_password"}`) //nolint:errcheck
		return nil, status.Error(codes.Unauthenticated, "invalid credentials")
	}
	if acct.TOTPRequired {
		if req.TotpCode == "" {
			// The password has already been verified, so a distinct message lets
			// the client know to prompt for a TOTP code.
			a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"totp_missing"}`) //nolint:errcheck
			return nil, status.Error(codes.Unauthenticated, "TOTP code required")
		}
		secret, err := crypto.OpenAESGCM(a.s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
		if err != nil {
			a.s.logger.Error("decrypt TOTP secret", "error", err, "account_id", acct.ID)
			return nil, status.Error(codes.Internal, "internal error")
		}
		valid, err := auth.ValidateTOTP(secret, req.TotpCode)
		if err != nil || !valid {
			a.s.db.WriteAuditEvent(model.EventLoginTOTPFail, &acct.ID, nil, ip, `{"reason":"wrong_totp"}`) //nolint:errcheck
			return nil, status.Error(codes.Unauthenticated, "invalid credentials")
		}
	}
	// Token lifetime: accounts holding the "admin" role get AdminExpiry,
	// everyone else DefaultExpiry.
	expiry := a.s.cfg.DefaultExpiry()
	roles, err := a.s.db.GetRoles(acct.ID)
	if err != nil {
		a.s.logger.Error("get roles", "error", err, "account_id", acct.ID)
		return nil, status.Error(codes.Internal, "internal error")
	}
	for _, r := range roles {
		if r == "admin" {
			expiry = a.s.cfg.AdminExpiry()
			break
		}
	}
	tokenStr, claims, err := token.IssueToken(a.s.privKey, a.s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
	if err != nil {
		a.s.logger.Error("issue token", "error", err)
		return nil, status.Error(codes.Internal, "internal error")
	}
	// The token is only handed out once it has been recorded by JTI.
	if err := a.s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
		a.s.logger.Error("track token", "error", err)
		return nil, status.Error(codes.Internal, "internal error")
	}
	a.s.db.WriteAuditEvent(model.EventLoginOK, &acct.ID, nil, ip, "") //nolint:errcheck
	a.s.db.WriteAuditEvent(model.EventTokenIssued, &acct.ID, nil, ip, //nolint:errcheck
		fmt.Sprintf(`{"jti":%q}`, claims.JTI))
	return &mciasv1.LoginResponse{
		Token:     tokenStr,
		ExpiresAt: timestamppb.New(claims.ExpiresAt),
	}, nil
}
// Logout revokes the caller's current JWT.
// Logout revokes the JWT that authenticated this call, using the revocation
// reason "logout", and then writes a token-revoked audit event for it.
//
// Returns Internal if the revocation cannot be persisted; audit-write failures
// are ignored.
func (a *authServiceServer) Logout(ctx context.Context, _ *mciasv1.LogoutRequest) (*mciasv1.LogoutResponse, error) {
	jti := claimsFromContext(ctx).JTI
	const reason = "logout"

	revokeErr := a.s.db.RevokeToken(jti, reason)
	if revokeErr != nil {
		a.s.logger.Error("revoke token on logout", "error", revokeErr)
		return nil, status.Error(codes.Internal, "internal error")
	}

	detail := fmt.Sprintf(`{"jti":%q,"reason":"logout"}`, jti)
	a.s.db.WriteAuditEvent(model.EventTokenRevoked, nil, nil, peerIP(ctx), detail) //nolint:errcheck // audit failure is non-fatal
	return &mciasv1.LogoutResponse{}, nil
}
// RenewToken exchanges the caller's token for a new one.
// RenewToken exchanges the caller's token for a freshly issued one.
//
// The account must still load and be active. The new token's lifetime is
// re-derived from the account's current roles ("admin" gets AdminExpiry,
// otherwise DefaultExpiry), so role changes since the original token was
// issued take effect on renewal.
//
// Errors:
//   - Unauthenticated if the account cannot be loaded or is not active.
//   - Internal if roles cannot be loaded, the token cannot be issued, or the
//     revoke-old/track-new transaction fails.
func (a *authServiceServer) RenewToken(ctx context.Context, _ *mciasv1.RenewTokenRequest) (*mciasv1.RenewTokenResponse, error) {
	claims := claimsFromContext(ctx)
	acct, err := a.s.db.GetAccountByUUID(claims.Subject)
	if err != nil {
		return nil, status.Error(codes.Unauthenticated, "account not found")
	}
	if acct.Status != model.AccountStatusActive {
		return nil, status.Error(codes.Unauthenticated, "account inactive")
	}
	roles, err := a.s.db.GetRoles(acct.ID)
	if err != nil {
		a.s.logger.Error("get roles", "error", err, "account_id", acct.ID)
		return nil, status.Error(codes.Internal, "internal error")
	}
	expiry := a.s.cfg.DefaultExpiry()
	for _, r := range roles {
		if r == "admin" {
			expiry = a.s.cfg.AdminExpiry()
			break
		}
	}
	newTokenStr, newClaims, err := token.IssueToken(a.s.privKey, a.s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
	if err != nil {
		a.s.logger.Error("issue token", "error", err)
		return nil, status.Error(codes.Internal, "internal error")
	}
	// Security: revoke old + track new in a single transaction (F-03) so that a
	// failure between the two steps cannot leave the user with no valid token.
	if err := a.s.db.RenewToken(claims.JTI, "renewed", newClaims.JTI, acct.ID, newClaims.IssuedAt, newClaims.ExpiresAt); err != nil {
		// Per the F-03 change, this also fails when the old token has already
		// been revoked (e.g. a concurrent renewal of the same token won).
		// The freshly issued token is discarded untracked. Log the cause so
		// rollbacks are visible to operators rather than silently mapped to a
		// generic Internal error.
		a.s.logger.Error("renew token", "error", err, "account_id", acct.ID, "old_jti", claims.JTI)
		return nil, status.Error(codes.Internal, "internal error")
	}
	a.s.db.WriteAuditEvent(model.EventTokenRenewed, &acct.ID, nil, peerIP(ctx), //nolint:errcheck
		fmt.Sprintf(`{"old_jti":%q,"new_jti":%q}`, claims.JTI, newClaims.JTI))
	return &mciasv1.RenewTokenResponse{
		Token:     newTokenStr,
		ExpiresAt: timestamppb.New(newClaims.ExpiresAt),
	}, nil
}
// EnrollTOTP begins TOTP enrollment for the calling account.
// EnrollTOTP begins TOTP enrollment for the calling account.
//
// It does three things:
//   - generates a fresh TOTP secret;
//   - stores it encrypted under the server master key via SetTOTP;
//   - returns the base32 secret and an otpauth:// URI for authenticator apps.
//
// Enrollment is completed by ConfirmTOTP.
//
// The account name in the URI label is percent-encoded with url.PathEscape
// because the Key URI Format requires a URL-encoded label. Unescaped, a
// username containing spaces, '/', '?', '#' or '%' would corrupt the URI.
//
// NOTE(review): SetTOTP's handling of totp_required is not visible here, and
// ConfirmTOTP relies on it setting totp_required=1. If SetTOTP does so
// unconditionally, two problems follow. First, this call makes TOTP
// mandatory before the user has confirmed a code. Second, a caller holding
// only a bearer token can overwrite an existing TOTP secret. Confirm SetTOTP's
// semantics.
//
// Errors: Unauthenticated if the account cannot be loaded; Internal if secret
// generation, encryption or storage fails.
func (a *authServiceServer) EnrollTOTP(ctx context.Context, _ *mciasv1.EnrollTOTPRequest) (*mciasv1.EnrollTOTPResponse, error) {
	claims := claimsFromContext(ctx)
	acct, err := a.s.db.GetAccountByUUID(claims.Subject)
	if err != nil {
		return nil, status.Error(codes.Unauthenticated, "account not found")
	}
	rawSecret, b32Secret, err := auth.GenerateTOTPSecret()
	if err != nil {
		return nil, status.Error(codes.Internal, "internal error")
	}
	secretEnc, secretNonce, err := crypto.SealAESGCM(a.s.masterKey, rawSecret)
	if err != nil {
		return nil, status.Error(codes.Internal, "internal error")
	}
	if err := a.s.db.SetTOTP(acct.ID, secretEnc, secretNonce); err != nil {
		return nil, status.Error(codes.Internal, "internal error")
	}
	otpURI := fmt.Sprintf("otpauth://totp/MCIAS:%s?secret=%s&issuer=MCIAS", url.PathEscape(acct.Username), b32Secret)
	// Security: secret is shown once here only; the stored form is encrypted.
	return &mciasv1.EnrollTOTPResponse{
		Secret:     b32Secret,
		OtpauthUri: otpURI,
	}, nil
}
// ConfirmTOTP confirms TOTP enrollment.
// ConfirmTOTP completes TOTP enrollment for the calling account.
//
// The supplied code must validate against the pending encrypted secret
// (stored by EnrollTOTP). On success that same ciphertext/nonce pair is
// written back through SetTOTP, which, per the original design, sets
// totp_required=1. A TOTP-enrolled audit event is then recorded.
//
// Errors:
//   - InvalidArgument if no code is supplied.
//   - Unauthenticated if the account cannot be loaded or the code is invalid.
//   - FailedPrecondition if enrollment has not been started.
//   - Internal on decryption or storage failures.
func (a *authServiceServer) ConfirmTOTP(ctx context.Context, req *mciasv1.ConfirmTOTPRequest) (*mciasv1.ConfirmTOTPResponse, error) {
	code := req.Code
	if code == "" {
		return nil, status.Error(codes.InvalidArgument, "code is required")
	}

	acct, err := a.s.db.GetAccountByUUID(claimsFromContext(ctx).Subject)
	switch {
	case err != nil:
		return nil, status.Error(codes.Unauthenticated, "account not found")
	case acct.TOTPSecretEnc == nil:
		return nil, status.Error(codes.FailedPrecondition, "TOTP enrollment not started")
	}

	enc, nonce := acct.TOTPSecretEnc, acct.TOTPSecretNonce
	plain, err := crypto.OpenAESGCM(a.s.masterKey, nonce, enc)
	if err != nil {
		return nil, status.Error(codes.Internal, "internal error")
	}
	if ok, verr := auth.ValidateTOTP(plain, code); verr != nil || !ok {
		return nil, status.Error(codes.Unauthenticated, "invalid TOTP code")
	}

	// Re-storing the existing ciphertext is what confirms enrollment.
	if err := a.s.db.SetTOTP(acct.ID, enc, nonce); err != nil {
		return nil, status.Error(codes.Internal, "internal error")
	}
	a.s.db.WriteAuditEvent(model.EventTOTPEnrolled, &acct.ID, nil, peerIP(ctx), "") //nolint:errcheck // audit failure is non-fatal
	return &mciasv1.ConfirmTOTPResponse{}, nil
}
// RemoveTOTP removes TOTP from an account. Admin only.
// RemoveTOTP clears TOTP for the account whose UUID is req.AccountId.
// Admin only.
//
// Errors:
//   - Whatever requireAdmin returns, for non-admin callers.
//   - InvalidArgument if account_id is empty.
//   - NotFound if the account cannot be loaded.
//   - Internal if clearing TOTP fails.
func (a *authServiceServer) RemoveTOTP(ctx context.Context, req *mciasv1.RemoveTOTPRequest) (*mciasv1.RemoveTOTPResponse, error) {
	if denied := a.s.requireAdmin(ctx); denied != nil {
		return nil, denied
	}

	targetUUID := req.AccountId
	if targetUUID == "" {
		return nil, status.Error(codes.InvalidArgument, "account_id is required")
	}

	target, err := a.s.db.GetAccountByUUID(targetUUID)
	if err != nil {
		return nil, status.Error(codes.NotFound, "account not found")
	}
	if err = a.s.db.ClearTOTP(target.ID); err != nil {
		return nil, status.Error(codes.Internal, "internal error")
	}

	// The affected account goes in the third (presumably target) audit slot;
	// no actor is recorded.
	a.s.db.WriteAuditEvent(model.EventTOTPRemoved, nil, &target.ID, peerIP(ctx), "") //nolint:errcheck // audit failure is non-fatal
	return &mciasv1.RemoveTOTPResponse{}, nil
}
// peerIP extracts the client IP from gRPC peer context.
// peerIP returns the client's IP address from the gRPC peer information on
// ctx, for use in audit events.
//
// It returns "" when ctx carries no peer or the peer has no address. When the
// address is not in host:port form (e.g. a Unix socket path), the full address
// string is returned unchanged.
func peerIP(ctx context.Context) string {
	p, ok := peer.FromContext(ctx)
	// Guard against a nil *Peer or nil Addr. Peer info may be attached without
	// an address, and calling String on a nil Addr would panic.
	if !ok || p == nil || p.Addr == nil {
		return ""
	}
	addr := p.Addr.String()
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return addr
	}
	return host
}