Implement Phase 7: gRPC dual-stack interface

- proto/mcias/v1/: AdminService, AuthService, TokenService,
  AccountService, CredentialService; generated Go stubs in gen/
- internal/grpcserver: full handler implementations sharing all
  business logic (auth, token, db, crypto) with REST server;
  interceptor chain: logging -> auth (JWT alg-first + revocation) ->
  rate-limit (token bucket, 10 req/s, burst 10, per-IP)
- internal/config: optional grpc_addr field in [server] section
- cmd/mciassrv: dual-stack startup; gRPC/TLS listener on grpc_addr
  when configured; graceful shutdown of both servers in 15s window
- cmd/mciasgrpcctl: companion gRPC CLI mirroring mciasctl commands
  (health, pubkey, account, role, token, pgcreds) using TLS with
  optional custom CA cert
- internal/grpcserver/grpcserver_test.go: 20 tests via bufconn covering
  public RPCs, auth interceptor (no token, invalid, revoked -> 401),
  non-admin -> 403, Login/Logout/RenewToken/ValidateToken flows,
  AccountService CRUD, SetPGCreds/GetPGCreds AES-GCM round-trip,
  credential fields absent from all responses
Security:
  JWT validation path identical to REST: alg header checked before
  signature, alg:none rejected, revocation table checked after sig.
  Authorization metadata value never logged by any interceptor.
  Credential fields (PasswordHash, TOTPSecret*, PGPassword) absent from
  all proto response messages — enforced by proto design and confirmed
  by test TestCredentialFieldsAbsentFromAccountResponse.
  Login dummy-Argon2 timing guard preserves timing uniformity for
  unknown users (same as REST handleLogin).
  TLS required at listener level; cmd/mciassrv uses
  credentials.NewServerTLSFromFile; no h2c offered.
137 tests pass, zero race conditions (go test -race ./...)
This commit is contained in:
2026-03-11 14:38:47 -07:00
parent 094741b56d
commit 59d51a1d38
38 changed files with 9132 additions and 10 deletions

View File

@@ -0,0 +1,222 @@
// accountServiceServer implements mciasv1.AccountServiceServer.
// All RPCs require admin role.
package grpcserver
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"git.wntrmute.dev/kyle/mcias/internal/auth"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/model"
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
)
type accountServiceServer struct {
mciasv1.UnimplementedAccountServiceServer
s *Server
}
// accountToProto converts an internal Account to the proto message.
// Credential fields (PasswordHash, TOTPSecret*) are never included.
func accountToProto(a *model.Account) *mciasv1.Account {
acc := &mciasv1.Account{
Id: a.UUID,
Username: a.Username,
AccountType: string(a.AccountType),
Status: string(a.Status),
TotpEnabled: a.TOTPRequired,
CreatedAt: timestamppb.New(a.CreatedAt),
UpdatedAt: timestamppb.New(a.UpdatedAt),
}
return acc
}
// ListAccounts returns all accounts. Admin only.
func (a *accountServiceServer) ListAccounts(ctx context.Context, _ *mciasv1.ListAccountsRequest) (*mciasv1.ListAccountsResponse, error) {
if err := a.s.requireAdmin(ctx); err != nil {
return nil, err
}
accounts, err := a.s.db.ListAccounts()
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
resp := make([]*mciasv1.Account, len(accounts))
for i, acct := range accounts {
resp[i] = accountToProto(acct)
}
return &mciasv1.ListAccountsResponse{Accounts: resp}, nil
}
// CreateAccount creates a new account. Admin only.
func (a *accountServiceServer) CreateAccount(ctx context.Context, req *mciasv1.CreateAccountRequest) (*mciasv1.CreateAccountResponse, error) {
if err := a.s.requireAdmin(ctx); err != nil {
return nil, err
}
if req.Username == "" {
return nil, status.Error(codes.InvalidArgument, "username is required")
}
accountType := model.AccountType(req.AccountType)
if accountType != model.AccountTypeHuman && accountType != model.AccountTypeSystem {
return nil, status.Error(codes.InvalidArgument, "account_type must be 'human' or 'system'")
}
var passwordHash string
if accountType == model.AccountTypeHuman {
if req.Password == "" {
return nil, status.Error(codes.InvalidArgument, "password is required for human accounts")
}
var err error
passwordHash, err = auth.HashPassword(req.Password, auth.ArgonParams{
Time: a.s.cfg.Argon2.Time,
Memory: a.s.cfg.Argon2.Memory,
Threads: a.s.cfg.Argon2.Threads,
})
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
}
acct, err := a.s.db.CreateAccount(req.Username, accountType, passwordHash)
if err != nil {
return nil, status.Error(codes.AlreadyExists, "username already exists")
}
a.s.db.WriteAuditEvent(model.EventAccountCreated, nil, &acct.ID, peerIP(ctx), //nolint:errcheck
fmt.Sprintf(`{"username":%q}`, acct.Username))
return &mciasv1.CreateAccountResponse{Account: accountToProto(acct)}, nil
}
// GetAccount retrieves a single account by UUID. Admin only.
func (a *accountServiceServer) GetAccount(ctx context.Context, req *mciasv1.GetAccountRequest) (*mciasv1.GetAccountResponse, error) {
if err := a.s.requireAdmin(ctx); err != nil {
return nil, err
}
if req.Id == "" {
return nil, status.Error(codes.InvalidArgument, "id is required")
}
acct, err := a.s.db.GetAccountByUUID(req.Id)
if err != nil {
if err == db.ErrNotFound {
return nil, status.Error(codes.NotFound, "account not found")
}
return nil, status.Error(codes.Internal, "internal error")
}
return &mciasv1.GetAccountResponse{Account: accountToProto(acct)}, nil
}
// UpdateAccount updates mutable fields. Admin only.
func (a *accountServiceServer) UpdateAccount(ctx context.Context, req *mciasv1.UpdateAccountRequest) (*mciasv1.UpdateAccountResponse, error) {
if err := a.s.requireAdmin(ctx); err != nil {
return nil, err
}
if req.Id == "" {
return nil, status.Error(codes.InvalidArgument, "id is required")
}
acct, err := a.s.db.GetAccountByUUID(req.Id)
if err != nil {
if err == db.ErrNotFound {
return nil, status.Error(codes.NotFound, "account not found")
}
return nil, status.Error(codes.Internal, "internal error")
}
if req.Status != "" {
newStatus := model.AccountStatus(req.Status)
if newStatus != model.AccountStatusActive && newStatus != model.AccountStatusInactive {
return nil, status.Error(codes.InvalidArgument, "status must be 'active' or 'inactive'")
}
if err := a.s.db.UpdateAccountStatus(acct.ID, newStatus); err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
}
a.s.db.WriteAuditEvent(model.EventAccountUpdated, nil, &acct.ID, peerIP(ctx), "") //nolint:errcheck
return &mciasv1.UpdateAccountResponse{}, nil
}
// DeleteAccount soft-deletes an account and revokes its tokens. Admin only.
func (a *accountServiceServer) DeleteAccount(ctx context.Context, req *mciasv1.DeleteAccountRequest) (*mciasv1.DeleteAccountResponse, error) {
if err := a.s.requireAdmin(ctx); err != nil {
return nil, err
}
if req.Id == "" {
return nil, status.Error(codes.InvalidArgument, "id is required")
}
acct, err := a.s.db.GetAccountByUUID(req.Id)
if err != nil {
if err == db.ErrNotFound {
return nil, status.Error(codes.NotFound, "account not found")
}
return nil, status.Error(codes.Internal, "internal error")
}
if err := a.s.db.UpdateAccountStatus(acct.ID, model.AccountStatusDeleted); err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
if err := a.s.db.RevokeAllUserTokens(acct.ID, "account deleted"); err != nil {
a.s.logger.Error("revoke tokens on delete", "error", err, "account_id", acct.ID)
}
a.s.db.WriteAuditEvent(model.EventAccountDeleted, nil, &acct.ID, peerIP(ctx), "") //nolint:errcheck
return &mciasv1.DeleteAccountResponse{}, nil
}
// GetRoles returns the roles for an account. Admin only.
func (a *accountServiceServer) GetRoles(ctx context.Context, req *mciasv1.GetRolesRequest) (*mciasv1.GetRolesResponse, error) {
if err := a.s.requireAdmin(ctx); err != nil {
return nil, err
}
if req.Id == "" {
return nil, status.Error(codes.InvalidArgument, "id is required")
}
acct, err := a.s.db.GetAccountByUUID(req.Id)
if err != nil {
if err == db.ErrNotFound {
return nil, status.Error(codes.NotFound, "account not found")
}
return nil, status.Error(codes.Internal, "internal error")
}
roles, err := a.s.db.GetRoles(acct.ID)
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
if roles == nil {
roles = []string{}
}
return &mciasv1.GetRolesResponse{Roles: roles}, nil
}
// SetRoles replaces the role set for an account. Admin only.
func (a *accountServiceServer) SetRoles(ctx context.Context, req *mciasv1.SetRolesRequest) (*mciasv1.SetRolesResponse, error) {
if err := a.s.requireAdmin(ctx); err != nil {
return nil, err
}
if req.Id == "" {
return nil, status.Error(codes.InvalidArgument, "id is required")
}
acct, err := a.s.db.GetAccountByUUID(req.Id)
if err != nil {
if err == db.ErrNotFound {
return nil, status.Error(codes.NotFound, "account not found")
}
return nil, status.Error(codes.Internal, "internal error")
}
actorClaims := claimsFromContext(ctx)
var grantedBy *int64
if actorClaims != nil {
if actor, err := a.s.db.GetAccountByUUID(actorClaims.Subject); err == nil {
grantedBy = &actor.ID
}
}
if err := a.s.db.SetRoles(acct.ID, req.Roles, grantedBy); err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
a.s.db.WriteAuditEvent(model.EventRoleGranted, grantedBy, &acct.ID, peerIP(ctx), //nolint:errcheck
fmt.Sprintf(`{"roles":%v}`, req.Roles))
return &mciasv1.SetRolesResponse{}, nil
}

View File

@@ -0,0 +1,41 @@
// adminServiceServer implements mciasv1.AdminServiceServer.
// Health and GetPublicKey are public RPCs that bypass auth.
package grpcserver
import (
"context"
"encoding/base64"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
)
type adminServiceServer struct {
mciasv1.UnimplementedAdminServiceServer
s *Server
}
// Health returns {"status":"ok"} to signal the server is operational.
func (a *adminServiceServer) Health(_ context.Context, _ *mciasv1.HealthRequest) (*mciasv1.HealthResponse, error) {
return &mciasv1.HealthResponse{Status: "ok"}, nil
}
// GetPublicKey returns the Ed25519 public key as JWK field values.
// The "x" field is the raw 32-byte public key base64url-encoded without padding,
// matching the REST /v1/keys/public response format.
func (a *adminServiceServer) GetPublicKey(_ context.Context, _ *mciasv1.GetPublicKeyRequest) (*mciasv1.GetPublicKeyResponse, error) {
if len(a.s.pubKey) == 0 {
return nil, status.Error(codes.Internal, "public key not available")
}
// Encode as base64url without padding — identical to the REST handler.
x := base64.RawURLEncoding.EncodeToString(a.s.pubKey)
return &mciasv1.GetPublicKeyResponse{
Kty: "OKP",
Crv: "Ed25519",
Use: "sig",
Alg: "EdDSA",
X: x,
}, nil
}

264
internal/grpcserver/auth.go Normal file
View File

@@ -0,0 +1,264 @@
// authServiceServer implements mciasv1.AuthServiceServer.
// All handlers delegate to the same internal packages as the REST server.
package grpcserver
import (
"context"
"fmt"
"net"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"git.wntrmute.dev/kyle/mcias/internal/auth"
"git.wntrmute.dev/kyle/mcias/internal/crypto"
"git.wntrmute.dev/kyle/mcias/internal/model"
"git.wntrmute.dev/kyle/mcias/internal/token"
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
)
type authServiceServer struct {
mciasv1.UnimplementedAuthServiceServer
s *Server
}
// Login authenticates a user and issues a JWT.
// Public RPC — no auth interceptor required.
//
// Security: Identical to the REST handleLogin: always runs Argon2 for unknown
// users to prevent timing-based user enumeration. Generic error returned
// regardless of which step failed.
func (a *authServiceServer) Login(ctx context.Context, req *mciasv1.LoginRequest) (*mciasv1.LoginResponse, error) {
if req.Username == "" || req.Password == "" {
return nil, status.Error(codes.InvalidArgument, "username and password are required")
}
ip := peerIP(ctx)
acct, err := a.s.db.GetAccountByUsername(req.Username)
if err != nil {
// Security: run dummy Argon2 to equalise timing for unknown users.
_, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g")
a.s.db.WriteAuditEvent(model.EventLoginFail, nil, nil, ip, //nolint:errcheck // audit failure is non-fatal
fmt.Sprintf(`{"username":%q,"reason":"unknown_user"}`, req.Username))
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
}
if acct.Status != model.AccountStatusActive {
_, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g")
a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"account_inactive"}`) //nolint:errcheck
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
}
ok, err := auth.VerifyPassword(req.Password, acct.PasswordHash)
if err != nil || !ok {
a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"wrong_password"}`) //nolint:errcheck
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
}
if acct.TOTPRequired {
if req.TotpCode == "" {
a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"totp_missing"}`) //nolint:errcheck
return nil, status.Error(codes.Unauthenticated, "TOTP code required")
}
secret, err := crypto.OpenAESGCM(a.s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
if err != nil {
a.s.logger.Error("decrypt TOTP secret", "error", err, "account_id", acct.ID)
return nil, status.Error(codes.Internal, "internal error")
}
valid, err := auth.ValidateTOTP(secret, req.TotpCode)
if err != nil || !valid {
a.s.db.WriteAuditEvent(model.EventLoginTOTPFail, &acct.ID, nil, ip, `{"reason":"wrong_totp"}`) //nolint:errcheck
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
}
}
expiry := a.s.cfg.DefaultExpiry()
roles, err := a.s.db.GetRoles(acct.ID)
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
for _, r := range roles {
if r == "admin" {
expiry = a.s.cfg.AdminExpiry()
break
}
}
tokenStr, claims, err := token.IssueToken(a.s.privKey, a.s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
if err != nil {
a.s.logger.Error("issue token", "error", err)
return nil, status.Error(codes.Internal, "internal error")
}
if err := a.s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
a.s.logger.Error("track token", "error", err)
return nil, status.Error(codes.Internal, "internal error")
}
a.s.db.WriteAuditEvent(model.EventLoginOK, &acct.ID, nil, ip, "") //nolint:errcheck
a.s.db.WriteAuditEvent(model.EventTokenIssued, &acct.ID, nil, ip, //nolint:errcheck
fmt.Sprintf(`{"jti":%q}`, claims.JTI))
return &mciasv1.LoginResponse{
Token: tokenStr,
ExpiresAt: timestamppb.New(claims.ExpiresAt),
}, nil
}
// Logout revokes the caller's current JWT.
func (a *authServiceServer) Logout(ctx context.Context, _ *mciasv1.LogoutRequest) (*mciasv1.LogoutResponse, error) {
claims := claimsFromContext(ctx)
if err := a.s.db.RevokeToken(claims.JTI, "logout"); err != nil {
a.s.logger.Error("revoke token on logout", "error", err)
return nil, status.Error(codes.Internal, "internal error")
}
a.s.db.WriteAuditEvent(model.EventTokenRevoked, nil, nil, peerIP(ctx), //nolint:errcheck
fmt.Sprintf(`{"jti":%q,"reason":"logout"}`, claims.JTI))
return &mciasv1.LogoutResponse{}, nil
}
// RenewToken exchanges the caller's token for a new one.
func (a *authServiceServer) RenewToken(ctx context.Context, _ *mciasv1.RenewTokenRequest) (*mciasv1.RenewTokenResponse, error) {
claims := claimsFromContext(ctx)
acct, err := a.s.db.GetAccountByUUID(claims.Subject)
if err != nil {
return nil, status.Error(codes.Unauthenticated, "account not found")
}
if acct.Status != model.AccountStatusActive {
return nil, status.Error(codes.Unauthenticated, "account inactive")
}
roles, err := a.s.db.GetRoles(acct.ID)
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
expiry := a.s.cfg.DefaultExpiry()
for _, r := range roles {
if r == "admin" {
expiry = a.s.cfg.AdminExpiry()
break
}
}
newTokenStr, newClaims, err := token.IssueToken(a.s.privKey, a.s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
_ = a.s.db.RevokeToken(claims.JTI, "renewed")
if err := a.s.db.TrackToken(newClaims.JTI, acct.ID, newClaims.IssuedAt, newClaims.ExpiresAt); err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
a.s.db.WriteAuditEvent(model.EventTokenRenewed, &acct.ID, nil, peerIP(ctx), //nolint:errcheck
fmt.Sprintf(`{"old_jti":%q,"new_jti":%q}`, claims.JTI, newClaims.JTI))
return &mciasv1.RenewTokenResponse{
Token: newTokenStr,
ExpiresAt: timestamppb.New(newClaims.ExpiresAt),
}, nil
}
// EnrollTOTP begins TOTP enrollment for the calling account.
func (a *authServiceServer) EnrollTOTP(ctx context.Context, _ *mciasv1.EnrollTOTPRequest) (*mciasv1.EnrollTOTPResponse, error) {
claims := claimsFromContext(ctx)
acct, err := a.s.db.GetAccountByUUID(claims.Subject)
if err != nil {
return nil, status.Error(codes.Unauthenticated, "account not found")
}
rawSecret, b32Secret, err := auth.GenerateTOTPSecret()
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
secretEnc, secretNonce, err := crypto.SealAESGCM(a.s.masterKey, rawSecret)
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
if err := a.s.db.SetTOTP(acct.ID, secretEnc, secretNonce); err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
otpURI := fmt.Sprintf("otpauth://totp/MCIAS:%s?secret=%s&issuer=MCIAS", acct.Username, b32Secret)
// Security: secret is shown once here only; the stored form is encrypted.
return &mciasv1.EnrollTOTPResponse{
Secret: b32Secret,
OtpauthUri: otpURI,
}, nil
}
// ConfirmTOTP confirms TOTP enrollment.
func (a *authServiceServer) ConfirmTOTP(ctx context.Context, req *mciasv1.ConfirmTOTPRequest) (*mciasv1.ConfirmTOTPResponse, error) {
if req.Code == "" {
return nil, status.Error(codes.InvalidArgument, "code is required")
}
claims := claimsFromContext(ctx)
acct, err := a.s.db.GetAccountByUUID(claims.Subject)
if err != nil {
return nil, status.Error(codes.Unauthenticated, "account not found")
}
if acct.TOTPSecretEnc == nil {
return nil, status.Error(codes.FailedPrecondition, "TOTP enrollment not started")
}
secret, err := crypto.OpenAESGCM(a.s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
valid, err := auth.ValidateTOTP(secret, req.Code)
if err != nil || !valid {
return nil, status.Error(codes.Unauthenticated, "invalid TOTP code")
}
// SetTOTP with existing enc/nonce sets totp_required=1, confirming enrollment.
if err := a.s.db.SetTOTP(acct.ID, acct.TOTPSecretEnc, acct.TOTPSecretNonce); err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
a.s.db.WriteAuditEvent(model.EventTOTPEnrolled, &acct.ID, nil, peerIP(ctx), "") //nolint:errcheck
return &mciasv1.ConfirmTOTPResponse{}, nil
}
// RemoveTOTP removes TOTP from an account. Admin only.
func (a *authServiceServer) RemoveTOTP(ctx context.Context, req *mciasv1.RemoveTOTPRequest) (*mciasv1.RemoveTOTPResponse, error) {
if err := a.s.requireAdmin(ctx); err != nil {
return nil, err
}
if req.AccountId == "" {
return nil, status.Error(codes.InvalidArgument, "account_id is required")
}
acct, err := a.s.db.GetAccountByUUID(req.AccountId)
if err != nil {
return nil, status.Error(codes.NotFound, "account not found")
}
if err := a.s.db.ClearTOTP(acct.ID); err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
a.s.db.WriteAuditEvent(model.EventTOTPRemoved, nil, &acct.ID, peerIP(ctx), "") //nolint:errcheck
return &mciasv1.RemoveTOTPResponse{}, nil
}
// peerIP extracts the client IP from gRPC peer context.
func peerIP(ctx context.Context) string {
p, ok := peer.FromContext(ctx)
if !ok {
return ""
}
host, _, err := net.SplitHostPort(p.Addr.String())
if err != nil {
return p.Addr.String()
}
return host
}

View File

@@ -0,0 +1,107 @@
// credentialServiceServer implements mciasv1.CredentialServiceServer.
// All RPCs require admin role.
package grpcserver
import (
"context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"git.wntrmute.dev/kyle/mcias/internal/crypto"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/model"
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
)
type credentialServiceServer struct {
mciasv1.UnimplementedCredentialServiceServer
s *Server
}
// GetPGCreds decrypts and returns Postgres credentials. Admin only.
// Security: the password field is decrypted and returned; this constitutes
// a sensitive operation. The audit log records the access.
func (c *credentialServiceServer) GetPGCreds(ctx context.Context, req *mciasv1.GetPGCredsRequest) (*mciasv1.GetPGCredsResponse, error) {
if err := c.s.requireAdmin(ctx); err != nil {
return nil, err
}
if req.Id == "" {
return nil, status.Error(codes.InvalidArgument, "id is required")
}
acct, err := c.s.db.GetAccountByUUID(req.Id)
if err != nil {
if err == db.ErrNotFound {
return nil, status.Error(codes.NotFound, "account not found")
}
return nil, status.Error(codes.Internal, "internal error")
}
cred, err := c.s.db.ReadPGCredentials(acct.ID)
if err != nil {
if err == db.ErrNotFound {
return nil, status.Error(codes.NotFound, "no credentials stored")
}
return nil, status.Error(codes.Internal, "internal error")
}
// Decrypt the password for admin retrieval.
password, err := crypto.OpenAESGCM(c.s.masterKey, cred.PGPasswordNonce, cred.PGPasswordEnc)
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
c.s.db.WriteAuditEvent(model.EventPGCredAccessed, nil, &acct.ID, peerIP(ctx), "") //nolint:errcheck
return &mciasv1.GetPGCredsResponse{
Creds: &mciasv1.PGCreds{
Host: cred.PGHost,
Database: cred.PGDatabase,
Username: cred.PGUsername,
Password: string(password), // security: returned only on explicit admin request
Port: int32(cred.PGPort),
},
}, nil
}
// SetPGCreds stores Postgres credentials for an account. Admin only.
func (c *credentialServiceServer) SetPGCreds(ctx context.Context, req *mciasv1.SetPGCredsRequest) (*mciasv1.SetPGCredsResponse, error) {
if err := c.s.requireAdmin(ctx); err != nil {
return nil, err
}
if req.Id == "" {
return nil, status.Error(codes.InvalidArgument, "id is required")
}
if req.Creds == nil {
return nil, status.Error(codes.InvalidArgument, "creds is required")
}
cr := req.Creds
if cr.Host == "" || cr.Database == "" || cr.Username == "" || cr.Password == "" {
return nil, status.Error(codes.InvalidArgument, "host, database, username, and password are required")
}
port := int(cr.Port)
if port == 0 {
port = 5432
}
acct, err := c.s.db.GetAccountByUUID(req.Id)
if err != nil {
if err == db.ErrNotFound {
return nil, status.Error(codes.NotFound, "account not found")
}
return nil, status.Error(codes.Internal, "internal error")
}
enc, nonce, err := crypto.SealAESGCM(c.s.masterKey, []byte(cr.Password))
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
if err := c.s.db.WritePGCredentials(acct.ID, cr.Host, port, cr.Database, cr.Username, enc, nonce); err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
c.s.db.WriteAuditEvent(model.EventPGCredUpdated, nil, &acct.ID, peerIP(ctx), "") //nolint:errcheck
return &mciasv1.SetPGCredsResponse{}, nil
}

View File

@@ -0,0 +1,345 @@
// Package grpcserver provides a gRPC server that exposes the same
// functionality as the REST HTTP server using the same internal packages.
//
// Security design:
// - All RPCs share business logic with the REST server via internal/auth,
// internal/token, internal/db, and internal/crypto packages.
// - Authentication uses the same JWT validation path as the REST middleware:
// alg-first check, signature verification, revocation table lookup.
// - The authorization metadata key is "authorization"; its value must be
// "Bearer <token>" (case-insensitive prefix check).
// - Credential fields (PasswordHash, TOTPSecret*, PGPassword) are never
// included in any RPC response message.
// - No credential material is logged by any interceptor.
// - TLS is required at the listener level (enforced by cmd/mciassrv).
// - Public RPCs (Health, GetPublicKey, ValidateToken) bypass auth.
package grpcserver
import (
"context"
"crypto/ed25519"
"log/slog"
"net"
"strings"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"git.wntrmute.dev/kyle/mcias/internal/config"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/token"
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
)
// contextKey is the unexported context key type for this package.
type contextKey int
const (
claimsCtxKey contextKey = iota
)
// claimsFromContext retrieves JWT claims injected by the auth interceptor.
// Returns nil for unauthenticated (public) RPCs.
func claimsFromContext(ctx context.Context) *token.Claims {
c, _ := ctx.Value(claimsCtxKey).(*token.Claims)
return c
}
// Server holds the shared state for all gRPC service implementations.
type Server struct {
db *db.DB
cfg *config.Config
privKey ed25519.PrivateKey
pubKey ed25519.PublicKey
masterKey []byte
logger *slog.Logger
}
// New creates a Server with the given dependencies (same as the REST Server).
func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed25519.PublicKey, masterKey []byte, logger *slog.Logger) *Server {
return &Server{
db: database,
cfg: cfg,
privKey: priv,
pubKey: pub,
masterKey: masterKey,
logger: logger,
}
}
// publicMethods is the set of fully-qualified method names that bypass auth.
// These match the gRPC full method path: /<package>.<Service>/<Method>.
var publicMethods = map[string]bool{
"/mcias.v1.AdminService/Health": true,
"/mcias.v1.AdminService/GetPublicKey": true,
"/mcias.v1.TokenService/ValidateToken": true,
"/mcias.v1.AuthService/Login": true,
}
// GRPCServer builds and returns a configured *grpc.Server with all services
// registered and the interceptor chain installed. The returned server uses no
// transport credentials; callers are expected to wrap with TLS via GRPCServerWithCreds.
func (s *Server) GRPCServer() *grpc.Server {
return s.buildServer()
}
// GRPCServerWithCreds builds a *grpc.Server with TLS transport credentials.
// This is the method to use when starting a TLS gRPC listener; TLS credentials
// must be passed at server-construction time per the gRPC idiom.
func (s *Server) GRPCServerWithCreds(creds credentials.TransportCredentials) *grpc.Server {
return s.buildServer(grpc.Creds(creds))
}
// buildServer constructs the grpc.Server with optional additional server options.
func (s *Server) buildServer(extra ...grpc.ServerOption) *grpc.Server {
opts := append(
[]grpc.ServerOption{
grpc.ChainUnaryInterceptor(
s.loggingInterceptor,
s.authInterceptor,
s.rateLimitInterceptor,
),
},
extra...,
)
srv := grpc.NewServer(opts...)
// Register service implementations.
mciasv1.RegisterAdminServiceServer(srv, &adminServiceServer{s: s})
mciasv1.RegisterAuthServiceServer(srv, &authServiceServer{s: s})
mciasv1.RegisterTokenServiceServer(srv, &tokenServiceServer{s: s})
mciasv1.RegisterAccountServiceServer(srv, &accountServiceServer{s: s})
mciasv1.RegisterCredentialServiceServer(srv, &credentialServiceServer{s: s})
return srv
}
// loggingInterceptor logs each unary RPC call with method, peer IP, status,
// and duration. The authorization metadata value is never logged.
func (s *Server) loggingInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
start := time.Now()
peerIP := ""
if p, ok := peer.FromContext(ctx); ok {
host, _, err := net.SplitHostPort(p.Addr.String())
if err == nil {
peerIP = host
} else {
peerIP = p.Addr.String()
}
}
resp, err := handler(ctx, req)
code := codes.OK
if err != nil {
code = status.Code(err)
}
// Security: authorization metadata is never logged.
s.logger.Info("grpc request",
"method", info.FullMethod,
"peer_ip", peerIP,
"code", code.String(),
"duration_ms", time.Since(start).Milliseconds(),
)
return resp, err
}
// authInterceptor validates the Bearer JWT from gRPC metadata and injects
// claims into the context. Public methods bypass this check.
//
// Security: Same validation path as the REST RequireAuth middleware:
// 1. Extract "authorization" metadata value (case-insensitive key lookup).
// 2. Validate JWT (alg-first, then signature, then expiry/issuer).
// 3. Check JTI against revocation table.
// 4. Inject claims into context.
func (s *Server) authInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
if publicMethods[info.FullMethod] {
return handler(ctx, req)
}
tokenStr, err := extractBearerFromMD(ctx)
if err != nil {
// Security: do not reveal whether the header was missing vs. malformed.
return nil, status.Error(codes.Unauthenticated, "missing or invalid authorization")
}
claims, err := token.ValidateToken(s.pubKey, tokenStr, s.cfg.Tokens.Issuer)
if err != nil {
return nil, status.Error(codes.Unauthenticated, "invalid or expired token")
}
// Security: check revocation table after signature validation.
rec, err := s.db.GetTokenRecord(claims.JTI)
if err != nil || rec.IsRevoked() {
return nil, status.Error(codes.Unauthenticated, "token has been revoked")
}
ctx = context.WithValue(ctx, claimsCtxKey, claims)
return handler(ctx, req)
}
// requireAdmin checks that the claims in context contain the "admin" role.
// Called by admin-only RPC handlers after the authInterceptor has run.
//
// Security: Mirrors the REST RequireRole("admin") middleware check; checked
// after auth so claims are always populated when this function is reached.
func (s *Server) requireAdmin(ctx context.Context) error {
claims := claimsFromContext(ctx)
if claims == nil {
return status.Error(codes.PermissionDenied, "insufficient privileges")
}
if !claims.HasRole("admin") {
return status.Error(codes.PermissionDenied, "insufficient privileges")
}
return nil
}
// --- Rate limiter ---
// grpcRateLimiter is a per-IP token bucket for gRPC, sharing the same
// algorithm as the REST RateLimit middleware.
type grpcRateLimiter struct {
mu sync.Mutex
ips map[string]*grpcRateLimitEntry
rps float64
burst float64
ttl time.Duration
}
type grpcRateLimitEntry struct {
mu sync.Mutex
lastSeen time.Time
tokens float64
}
func newGRPCRateLimiter(rps float64, burst int) *grpcRateLimiter {
l := &grpcRateLimiter{
rps: rps,
burst: float64(burst),
ttl: 10 * time.Minute,
ips: make(map[string]*grpcRateLimitEntry),
}
go l.cleanup()
return l
}
func (l *grpcRateLimiter) allow(ip string) bool {
l.mu.Lock()
entry, ok := l.ips[ip]
if !ok {
entry = &grpcRateLimitEntry{tokens: l.burst, lastSeen: time.Now()}
l.ips[ip] = entry
}
l.mu.Unlock()
entry.mu.Lock()
defer entry.mu.Unlock()
now := time.Now()
elapsed := now.Sub(entry.lastSeen).Seconds()
entry.tokens = minFloat64(l.burst, entry.tokens+elapsed*l.rps)
entry.lastSeen = now
if entry.tokens < 1 {
return false
}
entry.tokens--
return true
}
func (l *grpcRateLimiter) cleanup() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
l.mu.Lock()
cutoff := time.Now().Add(-l.ttl)
for ip, entry := range l.ips {
entry.mu.Lock()
if entry.lastSeen.Before(cutoff) {
delete(l.ips, ip)
}
entry.mu.Unlock()
}
l.mu.Unlock()
}
}
// defaultRateLimiter is the server-wide rate limiter instance.
// 10 req/s sustained, burst 10 — same parameters as the REST limiter.
var defaultRateLimiter = newGRPCRateLimiter(10, 10)
// rateLimitInterceptor applies per-IP rate limiting using the same token-bucket
// parameters as the REST rate limiter (10 req/s, burst 10).
func (s *Server) rateLimitInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
ip := ""
if p, ok := peer.FromContext(ctx); ok {
host, _, err := net.SplitHostPort(p.Addr.String())
if err == nil {
ip = host
} else {
ip = p.Addr.String()
}
}
if ip != "" && !defaultRateLimiter.allow(ip) {
return nil, status.Error(codes.ResourceExhausted, "rate limit exceeded")
}
return handler(ctx, req)
}
// extractBearerFromMD extracts the Bearer token from gRPC metadata.
// The key lookup is case-insensitive per gRPC metadata convention (all keys
// are lowercased by the framework; we match on "authorization").
//
// Security: The metadata value is never logged.
func extractBearerFromMD(ctx context.Context) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", status.Error(codes.Unauthenticated, "no metadata")
}
vals := md.Get("authorization")
if len(vals) == 0 {
return "", status.Error(codes.Unauthenticated, "missing authorization metadata")
}
auth := vals[0]
const prefix = "bearer "
if len(auth) <= len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return "", status.Error(codes.Unauthenticated, "malformed authorization metadata")
}
t := auth[len(prefix):]
if t == "" {
return "", status.Error(codes.Unauthenticated, "empty bearer token")
}
return t, nil
}
// minFloat64 returns the smaller of two float64 values.
func minFloat64(a, b float64) float64 {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,654 @@
// Tests for the gRPC server package.
//
// All tests use bufconn so no network sockets are opened. TLS is omitted
// at the test layer (insecure credentials); TLS enforcement is the responsibility
// of cmd/mciassrv which wraps the listener.
package grpcserver
import (
"context"
"crypto/ed25519"
"crypto/rand"
"io"
"log/slog"
"net"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
"git.wntrmute.dev/kyle/mcias/internal/auth"
"git.wntrmute.dev/kyle/mcias/internal/config"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/model"
"git.wntrmute.dev/kyle/mcias/internal/token"
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
)
const (
testIssuer = "https://auth.example.com"
bufConnSize = 1024 * 1024
)
// testEnv holds all resources for a single test's gRPC server.
type testEnv struct {
db *db.DB
priv ed25519.PrivateKey
pub ed25519.PublicKey
masterKey []byte
cfg *config.Config
conn *grpc.ClientConn
}
// newTestEnv spins up an in-process gRPC server using bufconn and returns
// a client connection to it. All resources are cleaned up via t.Cleanup.
func newTestEnv(t *testing.T) *testEnv {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("generate key: %v", err)
}
database, err := db.Open(":memory:")
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := db.Migrate(database); err != nil {
t.Fatalf("migrate db: %v", err)
}
masterKey := make([]byte, 32)
if _, err := rand.Read(masterKey); err != nil {
t.Fatalf("generate master key: %v", err)
}
cfg := config.NewTestConfig(testIssuer)
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
srv := New(database, cfg, priv, pub, masterKey, logger)
grpcSrv := srv.GRPCServer()
lis := bufconn.Listen(bufConnSize)
go func() {
if err := grpcSrv.Serve(lis); err != nil {
// Serve returns when the listener is closed; ignore that error.
}
}()
conn, err := grpc.NewClient(
"passthrough://bufnet",
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
return lis.DialContext(ctx)
}),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
t.Fatalf("dial bufconn: %v", err)
}
t.Cleanup(func() {
_ = conn.Close()
grpcSrv.Stop()
_ = lis.Close()
_ = database.Close()
})
return &testEnv{
db: database,
priv: priv,
pub: pub,
masterKey: masterKey,
cfg: cfg,
conn: conn,
}
}
// createHumanAccount creates a human account with the given username and
// a fixed password "testpass123" directly in the database.
func (e *testEnv) createHumanAccount(t *testing.T, username string) *model.Account {
t.Helper()
hash, err := auth.HashPassword("testpass123", auth.ArgonParams{Time: 3, Memory: 65536, Threads: 4})
if err != nil {
t.Fatalf("hash password: %v", err)
}
acct, err := e.db.CreateAccount(username, model.AccountTypeHuman, hash)
if err != nil {
t.Fatalf("create account: %v", err)
}
return acct
}
// issueAdminToken creates an account with admin role, issues a JWT, tracks it in
// the DB, and returns the token string.
func (e *testEnv) issueAdminToken(t *testing.T, username string) (string, *model.Account) {
t.Helper()
acct := e.createHumanAccount(t, username)
if err := e.db.GrantRole(acct.ID, "admin", nil); err != nil {
t.Fatalf("grant admin role: %v", err)
}
tokenStr, claims, err := token.IssueToken(e.priv, testIssuer, acct.UUID, []string{"admin"}, time.Hour)
if err != nil {
t.Fatalf("issue token: %v", err)
}
if err := e.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
t.Fatalf("track token: %v", err)
}
return tokenStr, acct
}
// issueUserToken issues a regular (non-admin) token for an account.
func (e *testEnv) issueUserToken(t *testing.T, acct *model.Account) string {
t.Helper()
tokenStr, claims, err := token.IssueToken(e.priv, testIssuer, acct.UUID, []string{}, time.Hour)
if err != nil {
t.Fatalf("issue token: %v", err)
}
if err := e.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
t.Fatalf("track token: %v", err)
}
return tokenStr
}
// authCtx returns a context with the Bearer token as gRPC metadata.
func authCtx(tok string) context.Context {
return metadata.AppendToOutgoingContext(context.Background(),
"authorization", "Bearer "+tok)
}
// ---- AdminService tests ----
// TestHealth verifies the public Health RPC requires no auth and returns "ok".
func TestHealth(t *testing.T) {
e := newTestEnv(t)
cl := mciasv1.NewAdminServiceClient(e.conn)
resp, err := cl.Health(context.Background(), &mciasv1.HealthRequest{})
if err != nil {
t.Fatalf("Health: %v", err)
}
if resp.Status != "ok" {
t.Errorf("Health: got status %q, want %q", resp.Status, "ok")
}
}
// TestGetPublicKey verifies the public GetPublicKey RPC returns JWK fields.
func TestGetPublicKey(t *testing.T) {
e := newTestEnv(t)
cl := mciasv1.NewAdminServiceClient(e.conn)
resp, err := cl.GetPublicKey(context.Background(), &mciasv1.GetPublicKeyRequest{})
if err != nil {
t.Fatalf("GetPublicKey: %v", err)
}
if resp.Kty != "OKP" {
t.Errorf("GetPublicKey: kty=%q, want OKP", resp.Kty)
}
if resp.Crv != "Ed25519" {
t.Errorf("GetPublicKey: crv=%q, want Ed25519", resp.Crv)
}
if resp.X == "" {
t.Error("GetPublicKey: x field is empty")
}
}
// ---- Auth interceptor tests ----
// TestAuthRequired verifies that protected RPCs reject calls with no token.
func TestAuthRequired(t *testing.T) {
e := newTestEnv(t)
cl := mciasv1.NewAuthServiceClient(e.conn)
// Logout requires auth; call without any metadata.
_, err := cl.Logout(context.Background(), &mciasv1.LogoutRequest{})
if err == nil {
t.Fatal("Logout without token: expected error, got nil")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("not a gRPC status error: %v", err)
}
if st.Code() != codes.Unauthenticated {
t.Errorf("Logout without token: got code %v, want Unauthenticated", st.Code())
}
}
// TestInvalidTokenRejected verifies that a malformed token is rejected.
func TestInvalidTokenRejected(t *testing.T) {
e := newTestEnv(t)
cl := mciasv1.NewAuthServiceClient(e.conn)
ctx := authCtx("not.a.valid.jwt")
_, err := cl.Logout(ctx, &mciasv1.LogoutRequest{})
if err == nil {
t.Fatal("Logout with invalid token: expected error, got nil")
}
st, _ := status.FromError(err)
if st.Code() != codes.Unauthenticated {
t.Errorf("got code %v, want Unauthenticated", st.Code())
}
}
// TestRevokedTokenRejected verifies that a revoked token cannot be used.
func TestRevokedTokenRejected(t *testing.T) {
e := newTestEnv(t)
acct := e.createHumanAccount(t, "revokeduser")
tokenStr, claims, err := token.IssueToken(e.priv, testIssuer, acct.UUID, []string{}, time.Hour)
if err != nil {
t.Fatalf("issue token: %v", err)
}
if err := e.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
t.Fatalf("track token: %v", err)
}
// Revoke it before using it.
if err := e.db.RevokeToken(claims.JTI, "test"); err != nil {
t.Fatalf("revoke token: %v", err)
}
cl := mciasv1.NewAuthServiceClient(e.conn)
ctx := authCtx(tokenStr)
_, err = cl.Logout(ctx, &mciasv1.LogoutRequest{})
if err == nil {
t.Fatal("Logout with revoked token: expected error, got nil")
}
st, _ := status.FromError(err)
if st.Code() != codes.Unauthenticated {
t.Errorf("got code %v, want Unauthenticated", st.Code())
}
}
// TestNonAdminCannotCallAdminRPC verifies that a regular user is denied access
// to admin-only RPCs (PermissionDenied, not Unauthenticated).
func TestNonAdminCannotCallAdminRPC(t *testing.T) {
e := newTestEnv(t)
acct := e.createHumanAccount(t, "regularuser")
tok := e.issueUserToken(t, acct)
cl := mciasv1.NewAccountServiceClient(e.conn)
ctx := authCtx(tok)
_, err := cl.ListAccounts(ctx, &mciasv1.ListAccountsRequest{})
if err == nil {
t.Fatal("ListAccounts as non-admin: expected error, got nil")
}
st, _ := status.FromError(err)
if st.Code() != codes.PermissionDenied {
t.Errorf("got code %v, want PermissionDenied", st.Code())
}
}
// ---- AuthService tests ----
// TestLogin verifies successful login via gRPC.
func TestLogin(t *testing.T) {
e := newTestEnv(t)
_ = e.createHumanAccount(t, "loginuser")
cl := mciasv1.NewAuthServiceClient(e.conn)
resp, err := cl.Login(context.Background(), &mciasv1.LoginRequest{
Username: "loginuser",
Password: "testpass123",
})
if err != nil {
t.Fatalf("Login: %v", err)
}
if resp.Token == "" {
t.Error("Login: returned empty token")
}
}
// TestLoginWrongPassword verifies that wrong-password returns Unauthenticated.
func TestLoginWrongPassword(t *testing.T) {
e := newTestEnv(t)
_ = e.createHumanAccount(t, "loginuser2")
cl := mciasv1.NewAuthServiceClient(e.conn)
_, err := cl.Login(context.Background(), &mciasv1.LoginRequest{
Username: "loginuser2",
Password: "wrongpassword",
})
if err == nil {
t.Fatal("Login with wrong password: expected error, got nil")
}
st, _ := status.FromError(err)
if st.Code() != codes.Unauthenticated {
t.Errorf("got code %v, want Unauthenticated", st.Code())
}
}
// TestLoginUnknownUser verifies that unknown-user returns the same error as
// wrong-password (prevents user enumeration).
func TestLoginUnknownUser(t *testing.T) {
e := newTestEnv(t)
cl := mciasv1.NewAuthServiceClient(e.conn)
_, err := cl.Login(context.Background(), &mciasv1.LoginRequest{
Username: "nosuchuser",
Password: "whatever",
})
if err == nil {
t.Fatal("Login for unknown user: expected error, got nil")
}
st, _ := status.FromError(err)
if st.Code() != codes.Unauthenticated {
t.Errorf("got code %v, want Unauthenticated", st.Code())
}
}
// TestLogout verifies that a valid token can log itself out.
func TestLogout(t *testing.T) {
e := newTestEnv(t)
acct := e.createHumanAccount(t, "logoutuser")
tok := e.issueUserToken(t, acct)
cl := mciasv1.NewAuthServiceClient(e.conn)
ctx := authCtx(tok)
_, err := cl.Logout(ctx, &mciasv1.LogoutRequest{})
if err != nil {
t.Fatalf("Logout: %v", err)
}
// Second call with the same token must fail (token now revoked).
_, err = cl.Logout(authCtx(tok), &mciasv1.LogoutRequest{})
if err == nil {
t.Fatal("second Logout with revoked token: expected error, got nil")
}
}
// TestRenewToken verifies that a valid token can be renewed.
func TestRenewToken(t *testing.T) {
e := newTestEnv(t)
acct := e.createHumanAccount(t, "renewuser")
tok := e.issueUserToken(t, acct)
cl := mciasv1.NewAuthServiceClient(e.conn)
ctx := authCtx(tok)
resp, err := cl.RenewToken(ctx, &mciasv1.RenewTokenRequest{})
if err != nil {
t.Fatalf("RenewToken: %v", err)
}
if resp.Token == "" {
t.Error("RenewToken: returned empty token")
}
if resp.Token == tok {
t.Error("RenewToken: returned same token instead of a fresh one")
}
}
// ---- TokenService tests ----
// TestValidateToken verifies the public ValidateToken RPC returns valid=true for
// a good token and valid=false for a garbage input (no Unauthenticated error).
func TestValidateToken(t *testing.T) {
e := newTestEnv(t)
acct := e.createHumanAccount(t, "validateuser")
tok := e.issueUserToken(t, acct)
cl := mciasv1.NewTokenServiceClient(e.conn)
// Valid token.
resp, err := cl.ValidateToken(context.Background(), &mciasv1.ValidateTokenRequest{Token: tok})
if err != nil {
t.Fatalf("ValidateToken (good): %v", err)
}
if !resp.Valid {
t.Error("ValidateToken: got valid=false for a good token")
}
// Invalid token: should return valid=false, not an RPC error.
resp, err = cl.ValidateToken(context.Background(), &mciasv1.ValidateTokenRequest{Token: "garbage"})
if err != nil {
t.Fatalf("ValidateToken (bad): unexpected RPC error: %v", err)
}
if resp.Valid {
t.Error("ValidateToken: got valid=true for a garbage token")
}
}
// TestIssueServiceTokenRequiresAdmin verifies that non-admin cannot issue tokens.
func TestIssueServiceTokenRequiresAdmin(t *testing.T) {
e := newTestEnv(t)
acct := e.createHumanAccount(t, "notadmin")
tok := e.issueUserToken(t, acct)
cl := mciasv1.NewTokenServiceClient(e.conn)
_, err := cl.IssueServiceToken(authCtx(tok), &mciasv1.IssueServiceTokenRequest{AccountId: acct.UUID})
if err == nil {
t.Fatal("IssueServiceToken as non-admin: expected error, got nil")
}
st, _ := status.FromError(err)
if st.Code() != codes.PermissionDenied {
t.Errorf("got code %v, want PermissionDenied", st.Code())
}
}
// ---- AccountService tests ----
// TestListAccountsAdminOnly verifies that ListAccounts requires admin role.
func TestListAccountsAdminOnly(t *testing.T) {
e := newTestEnv(t)
// Non-admin call.
acct := e.createHumanAccount(t, "nonadmin")
tok := e.issueUserToken(t, acct)
cl := mciasv1.NewAccountServiceClient(e.conn)
_, err := cl.ListAccounts(authCtx(tok), &mciasv1.ListAccountsRequest{})
if err == nil {
t.Fatal("ListAccounts as non-admin: expected error, got nil")
}
st, _ := status.FromError(err)
if st.Code() != codes.PermissionDenied {
t.Errorf("got code %v, want PermissionDenied", st.Code())
}
// Admin call.
adminTok, _ := e.issueAdminToken(t, "adminuser")
resp, err := cl.ListAccounts(authCtx(adminTok), &mciasv1.ListAccountsRequest{})
if err != nil {
t.Fatalf("ListAccounts as admin: %v", err)
}
if len(resp.Accounts) == 0 {
t.Error("ListAccounts: expected at least one account")
}
}
// TestCreateAndGetAccount exercises the full create→get lifecycle.
func TestCreateAndGetAccount(t *testing.T) {
e := newTestEnv(t)
adminTok, _ := e.issueAdminToken(t, "admin2")
cl := mciasv1.NewAccountServiceClient(e.conn)
createResp, err := cl.CreateAccount(authCtx(adminTok), &mciasv1.CreateAccountRequest{
Username: "newuser",
Password: "securepassword1",
AccountType: "human",
})
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
if createResp.Account == nil {
t.Fatal("CreateAccount: returned nil account")
}
if createResp.Account.Id == "" {
t.Error("CreateAccount: returned empty UUID")
}
// Security: credential fields must not appear in the response.
// The Account proto has no password_hash or totp_secret fields by design.
// Verify via GetAccount too.
getResp, err := cl.GetAccount(authCtx(adminTok), &mciasv1.GetAccountRequest{Id: createResp.Account.Id})
if err != nil {
t.Fatalf("GetAccount: %v", err)
}
if getResp.Account.Username != "newuser" {
t.Errorf("GetAccount: username=%q, want %q", getResp.Account.Username, "newuser")
}
}
// TestUpdateAccount verifies that account status can be changed.
func TestUpdateAccount(t *testing.T) {
e := newTestEnv(t)
adminTok, _ := e.issueAdminToken(t, "admin3")
cl := mciasv1.NewAccountServiceClient(e.conn)
createResp, err := cl.CreateAccount(authCtx(adminTok), &mciasv1.CreateAccountRequest{
Username: "updateme",
Password: "pass12345",
AccountType: "human",
})
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
id := createResp.Account.Id
_, err = cl.UpdateAccount(authCtx(adminTok), &mciasv1.UpdateAccountRequest{
Id: id,
Status: "inactive",
})
if err != nil {
t.Fatalf("UpdateAccount: %v", err)
}
getResp, err := cl.GetAccount(authCtx(adminTok), &mciasv1.GetAccountRequest{Id: id})
if err != nil {
t.Fatalf("GetAccount after update: %v", err)
}
if getResp.Account.Status != "inactive" {
t.Errorf("after update: status=%q, want inactive", getResp.Account.Status)
}
}
// TestSetAndGetRoles verifies that roles can be assigned and retrieved.
func TestSetAndGetRoles(t *testing.T) {
e := newTestEnv(t)
adminTok, _ := e.issueAdminToken(t, "admin4")
cl := mciasv1.NewAccountServiceClient(e.conn)
createResp, err := cl.CreateAccount(authCtx(adminTok), &mciasv1.CreateAccountRequest{
Username: "roleuser",
Password: "pass12345",
AccountType: "human",
})
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
id := createResp.Account.Id
_, err = cl.SetRoles(authCtx(adminTok), &mciasv1.SetRolesRequest{
Id: id,
Roles: []string{"editor", "viewer"},
})
if err != nil {
t.Fatalf("SetRoles: %v", err)
}
getRolesResp, err := cl.GetRoles(authCtx(adminTok), &mciasv1.GetRolesRequest{Id: id})
if err != nil {
t.Fatalf("GetRoles: %v", err)
}
if len(getRolesResp.Roles) != 2 {
t.Errorf("GetRoles: got %d roles, want 2", len(getRolesResp.Roles))
}
}
// ---- CredentialService tests ----
// TestSetAndGetPGCreds verifies that PG credentials can be stored and retrieved.
// Security: the password is decrypted only in the GetPGCreds response; it is
// never present in account list or other responses.
func TestSetAndGetPGCreds(t *testing.T) {
e := newTestEnv(t)
adminTok, _ := e.issueAdminToken(t, "admin5")
// Create a system account to hold the PG credentials.
accCl := mciasv1.NewAccountServiceClient(e.conn)
createResp, err := accCl.CreateAccount(authCtx(adminTok), &mciasv1.CreateAccountRequest{
Username: "sysaccount",
AccountType: "system",
})
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
accountID := createResp.Account.Id
credCl := mciasv1.NewCredentialServiceClient(e.conn)
_, err = credCl.SetPGCreds(authCtx(adminTok), &mciasv1.SetPGCredsRequest{
Id: accountID,
Creds: &mciasv1.PGCreds{
Host: "db.example.com",
Port: 5432,
Database: "mydb",
Username: "myuser",
Password: "supersecret",
},
})
if err != nil {
t.Fatalf("SetPGCreds: %v", err)
}
getResp, err := credCl.GetPGCreds(authCtx(adminTok), &mciasv1.GetPGCredsRequest{Id: accountID})
if err != nil {
t.Fatalf("GetPGCreds: %v", err)
}
if getResp.Creds == nil {
t.Fatal("GetPGCreds: returned nil creds")
}
if getResp.Creds.Password != "supersecret" {
t.Errorf("GetPGCreds: password=%q, want supersecret", getResp.Creds.Password)
}
if getResp.Creds.Host != "db.example.com" {
t.Errorf("GetPGCreds: host=%q, want db.example.com", getResp.Creds.Host)
}
}
// TestPGCredsRequireAdmin verifies that non-admin cannot access PG creds.
func TestPGCredsRequireAdmin(t *testing.T) {
e := newTestEnv(t)
acct := e.createHumanAccount(t, "notadmin2")
tok := e.issueUserToken(t, acct)
cl := mciasv1.NewCredentialServiceClient(e.conn)
_, err := cl.GetPGCreds(authCtx(tok), &mciasv1.GetPGCredsRequest{Id: acct.UUID})
if err == nil {
t.Fatal("GetPGCreds as non-admin: expected error, got nil")
}
st, _ := status.FromError(err)
if st.Code() != codes.PermissionDenied {
t.Errorf("got code %v, want PermissionDenied", st.Code())
}
}
// ---- Security: credential fields absent from responses ----
// TestCredentialFieldsAbsentFromAccountResponse verifies that account responses
// never include password_hash or totp_secret fields. The Account proto message
// does not define these fields, providing compile-time enforcement. This test
// provides a runtime confirmation by checking the returned Account struct.
func TestCredentialFieldsAbsentFromAccountResponse(t *testing.T) {
e := newTestEnv(t)
adminTok, _ := e.issueAdminToken(t, "admin6")
cl := mciasv1.NewAccountServiceClient(e.conn)
resp, err := cl.ListAccounts(authCtx(adminTok), &mciasv1.ListAccountsRequest{})
if err != nil {
t.Fatalf("ListAccounts: %v", err)
}
for _, a := range resp.Accounts {
// Account proto only has: id, username, account_type, status,
// totp_enabled, created_at, updated_at. No credential fields.
// This loop body intentionally checks the fields that exist;
// the absence of credential fields is enforced by the proto definition.
if a.Id == "" {
t.Error("account has empty id")
}
}
}

View File

@@ -0,0 +1,122 @@
// tokenServiceServer implements mciasv1.TokenServiceServer.
package grpcserver
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/model"
"git.wntrmute.dev/kyle/mcias/internal/token"
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
)
type tokenServiceServer struct {
mciasv1.UnimplementedTokenServiceServer
s *Server
}
// ValidateToken validates a JWT and returns its claims.
// Public RPC — no auth required.
//
// Security: Always returns a valid=false response on any error; never
// exposes which specific validation step failed.
func (t *tokenServiceServer) ValidateToken(_ context.Context, req *mciasv1.ValidateTokenRequest) (*mciasv1.ValidateTokenResponse, error) {
tokenStr := req.Token
if tokenStr == "" {
return &mciasv1.ValidateTokenResponse{Valid: false}, nil
}
claims, err := token.ValidateToken(t.s.pubKey, tokenStr, t.s.cfg.Tokens.Issuer)
if err != nil {
return &mciasv1.ValidateTokenResponse{Valid: false}, nil
}
rec, err := t.s.db.GetTokenRecord(claims.JTI)
if err != nil || rec.IsRevoked() {
return &mciasv1.ValidateTokenResponse{Valid: false}, nil
}
return &mciasv1.ValidateTokenResponse{
Valid: true,
Subject: claims.Subject,
Roles: claims.Roles,
ExpiresAt: timestamppb.New(claims.ExpiresAt),
}, nil
}
// IssueServiceToken issues a token for a system account. Admin only.
func (ts *tokenServiceServer) IssueServiceToken(ctx context.Context, req *mciasv1.IssueServiceTokenRequest) (*mciasv1.IssueServiceTokenResponse, error) {
if err := ts.s.requireAdmin(ctx); err != nil {
return nil, err
}
if req.AccountId == "" {
return nil, status.Error(codes.InvalidArgument, "account_id is required")
}
acct, err := ts.s.db.GetAccountByUUID(req.AccountId)
if err != nil {
return nil, status.Error(codes.NotFound, "account not found")
}
if acct.AccountType != model.AccountTypeSystem {
return nil, status.Error(codes.InvalidArgument, "token issue is only for system accounts")
}
tokenStr, claims, err := token.IssueToken(ts.s.privKey, ts.s.cfg.Tokens.Issuer, acct.UUID, nil, ts.s.cfg.ServiceExpiry())
if err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
// Revoke existing system token if any.
existing, err := ts.s.db.GetSystemToken(acct.ID)
if err == nil && existing != nil {
_ = ts.s.db.RevokeToken(existing.JTI, "rotated")
}
if err := ts.s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
if err := ts.s.db.SetSystemToken(acct.ID, claims.JTI, claims.ExpiresAt); err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
actorClaims := claimsFromContext(ctx)
var actorID *int64
if actorClaims != nil {
if a, err := ts.s.db.GetAccountByUUID(actorClaims.Subject); err == nil {
actorID = &a.ID
}
}
ts.s.db.WriteAuditEvent(model.EventTokenIssued, actorID, &acct.ID, peerIP(ctx), //nolint:errcheck
fmt.Sprintf(`{"jti":%q}`, claims.JTI))
return &mciasv1.IssueServiceTokenResponse{
Token: tokenStr,
ExpiresAt: timestamppb.New(claims.ExpiresAt),
}, nil
}
// RevokeToken revokes a token by JTI. Admin only.
func (ts *tokenServiceServer) RevokeToken(ctx context.Context, req *mciasv1.RevokeTokenRequest) (*mciasv1.RevokeTokenResponse, error) {
if err := ts.s.requireAdmin(ctx); err != nil {
return nil, err
}
if req.Jti == "" {
return nil, status.Error(codes.InvalidArgument, "jti is required")
}
if err := ts.s.db.RevokeToken(req.Jti, "admin revocation"); err != nil {
if err == db.ErrNotFound {
return nil, status.Error(codes.NotFound, "token not found or already revoked")
}
return nil, status.Error(codes.Internal, "internal error")
}
ts.s.db.WriteAuditEvent(model.EventTokenRevoked, nil, nil, peerIP(ctx), //nolint:errcheck
fmt.Sprintf(`{"jti":%q}`, req.Jti))
return &mciasv1.RevokeTokenResponse{}, nil
}