Implement Phase 7: gRPC dual-stack interface
- proto/mcias/v1/: AdminService, AuthService, TokenService, AccountService, CredentialService; generated Go stubs in gen/ - internal/grpcserver: full handler implementations sharing all business logic (auth, token, db, crypto) with REST server; interceptor chain: logging -> auth (JWT alg-first + revocation) -> rate-limit (token bucket, 10 req/s, burst 10, per-IP) - internal/config: optional grpc_addr field in [server] section - cmd/mciassrv: dual-stack startup; gRPC/TLS listener on grpc_addr when configured; graceful shutdown of both servers in 15s window - cmd/mciasgrpcctl: companion gRPC CLI mirroring mciasctl commands (health, pubkey, account, role, token, pgcreds) using TLS with optional custom CA cert - internal/grpcserver/grpcserver_test.go: 20 tests via bufconn covering public RPCs, auth interceptor (no token, invalid, revoked -> 401), non-admin -> 403, Login/Logout/RenewToken/ValidateToken flows, AccountService CRUD, SetPGCreds/GetPGCreds AES-GCM round-trip, credential fields absent from all responses Security: JWT validation path identical to REST: alg header checked before signature, alg:none rejected, revocation table checked after sig. Authorization metadata value never logged by any interceptor. Credential fields (PasswordHash, TOTPSecret*, PGPassword) absent from all proto response messages — enforced by proto design and confirmed by test TestCredentialFieldsAbsentFromAccountResponse. Login dummy-Argon2 timing guard preserves timing uniformity for unknown users (same as REST handleLogin). TLS required at listener level; cmd/mciassrv uses credentials.NewServerTLSFromFile; no h2c offered. 137 tests pass, zero race conditions (go test -race ./...)
This commit is contained in:
222
internal/grpcserver/accountservice.go
Normal file
222
internal/grpcserver/accountservice.go
Normal file
@@ -0,0 +1,222 @@
|
||||
// accountServiceServer implements mciasv1.AccountServiceServer.
|
||||
// All RPCs require admin role.
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
|
||||
)
|
||||
|
||||
type accountServiceServer struct {
|
||||
mciasv1.UnimplementedAccountServiceServer
|
||||
s *Server
|
||||
}
|
||||
|
||||
// accountToProto converts an internal Account to the proto message.
|
||||
// Credential fields (PasswordHash, TOTPSecret*) are never included.
|
||||
func accountToProto(a *model.Account) *mciasv1.Account {
|
||||
acc := &mciasv1.Account{
|
||||
Id: a.UUID,
|
||||
Username: a.Username,
|
||||
AccountType: string(a.AccountType),
|
||||
Status: string(a.Status),
|
||||
TotpEnabled: a.TOTPRequired,
|
||||
CreatedAt: timestamppb.New(a.CreatedAt),
|
||||
UpdatedAt: timestamppb.New(a.UpdatedAt),
|
||||
}
|
||||
return acc
|
||||
}
|
||||
|
||||
// ListAccounts returns all accounts. Admin only.
|
||||
func (a *accountServiceServer) ListAccounts(ctx context.Context, _ *mciasv1.ListAccountsRequest) (*mciasv1.ListAccountsResponse, error) {
|
||||
if err := a.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accounts, err := a.s.db.ListAccounts()
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
resp := make([]*mciasv1.Account, len(accounts))
|
||||
for i, acct := range accounts {
|
||||
resp[i] = accountToProto(acct)
|
||||
}
|
||||
return &mciasv1.ListAccountsResponse{Accounts: resp}, nil
|
||||
}
|
||||
|
||||
// CreateAccount creates a new account. Admin only.
|
||||
func (a *accountServiceServer) CreateAccount(ctx context.Context, req *mciasv1.CreateAccountRequest) (*mciasv1.CreateAccountResponse, error) {
|
||||
if err := a.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.Username == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "username is required")
|
||||
}
|
||||
accountType := model.AccountType(req.AccountType)
|
||||
if accountType != model.AccountTypeHuman && accountType != model.AccountTypeSystem {
|
||||
return nil, status.Error(codes.InvalidArgument, "account_type must be 'human' or 'system'")
|
||||
}
|
||||
|
||||
var passwordHash string
|
||||
if accountType == model.AccountTypeHuman {
|
||||
if req.Password == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "password is required for human accounts")
|
||||
}
|
||||
var err error
|
||||
passwordHash, err = auth.HashPassword(req.Password, auth.ArgonParams{
|
||||
Time: a.s.cfg.Argon2.Time,
|
||||
Memory: a.s.cfg.Argon2.Memory,
|
||||
Threads: a.s.cfg.Argon2.Threads,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
}
|
||||
|
||||
acct, err := a.s.db.CreateAccount(req.Username, accountType, passwordHash)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.AlreadyExists, "username already exists")
|
||||
}
|
||||
|
||||
a.s.db.WriteAuditEvent(model.EventAccountCreated, nil, &acct.ID, peerIP(ctx), //nolint:errcheck
|
||||
fmt.Sprintf(`{"username":%q}`, acct.Username))
|
||||
return &mciasv1.CreateAccountResponse{Account: accountToProto(acct)}, nil
|
||||
}
|
||||
|
||||
// GetAccount retrieves a single account by UUID. Admin only.
|
||||
func (a *accountServiceServer) GetAccount(ctx context.Context, req *mciasv1.GetAccountRequest) (*mciasv1.GetAccountResponse, error) {
|
||||
if err := a.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.Id == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "id is required")
|
||||
}
|
||||
acct, err := a.s.db.GetAccountByUUID(req.Id)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return nil, status.Error(codes.NotFound, "account not found")
|
||||
}
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
return &mciasv1.GetAccountResponse{Account: accountToProto(acct)}, nil
|
||||
}
|
||||
|
||||
// UpdateAccount updates mutable fields. Admin only.
|
||||
func (a *accountServiceServer) UpdateAccount(ctx context.Context, req *mciasv1.UpdateAccountRequest) (*mciasv1.UpdateAccountResponse, error) {
|
||||
if err := a.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.Id == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "id is required")
|
||||
}
|
||||
acct, err := a.s.db.GetAccountByUUID(req.Id)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return nil, status.Error(codes.NotFound, "account not found")
|
||||
}
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
if req.Status != "" {
|
||||
newStatus := model.AccountStatus(req.Status)
|
||||
if newStatus != model.AccountStatusActive && newStatus != model.AccountStatusInactive {
|
||||
return nil, status.Error(codes.InvalidArgument, "status must be 'active' or 'inactive'")
|
||||
}
|
||||
if err := a.s.db.UpdateAccountStatus(acct.ID, newStatus); err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
}
|
||||
|
||||
a.s.db.WriteAuditEvent(model.EventAccountUpdated, nil, &acct.ID, peerIP(ctx), "") //nolint:errcheck
|
||||
return &mciasv1.UpdateAccountResponse{}, nil
|
||||
}
|
||||
|
||||
// DeleteAccount soft-deletes an account and revokes its tokens. Admin only.
|
||||
func (a *accountServiceServer) DeleteAccount(ctx context.Context, req *mciasv1.DeleteAccountRequest) (*mciasv1.DeleteAccountResponse, error) {
|
||||
if err := a.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.Id == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "id is required")
|
||||
}
|
||||
acct, err := a.s.db.GetAccountByUUID(req.Id)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return nil, status.Error(codes.NotFound, "account not found")
|
||||
}
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
if err := a.s.db.UpdateAccountStatus(acct.ID, model.AccountStatusDeleted); err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
if err := a.s.db.RevokeAllUserTokens(acct.ID, "account deleted"); err != nil {
|
||||
a.s.logger.Error("revoke tokens on delete", "error", err, "account_id", acct.ID)
|
||||
}
|
||||
a.s.db.WriteAuditEvent(model.EventAccountDeleted, nil, &acct.ID, peerIP(ctx), "") //nolint:errcheck
|
||||
return &mciasv1.DeleteAccountResponse{}, nil
|
||||
}
|
||||
|
||||
// GetRoles returns the roles for an account. Admin only.
|
||||
func (a *accountServiceServer) GetRoles(ctx context.Context, req *mciasv1.GetRolesRequest) (*mciasv1.GetRolesResponse, error) {
|
||||
if err := a.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.Id == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "id is required")
|
||||
}
|
||||
acct, err := a.s.db.GetAccountByUUID(req.Id)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return nil, status.Error(codes.NotFound, "account not found")
|
||||
}
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
roles, err := a.s.db.GetRoles(acct.ID)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
if roles == nil {
|
||||
roles = []string{}
|
||||
}
|
||||
return &mciasv1.GetRolesResponse{Roles: roles}, nil
|
||||
}
|
||||
|
||||
// SetRoles replaces the role set for an account. Admin only.
|
||||
func (a *accountServiceServer) SetRoles(ctx context.Context, req *mciasv1.SetRolesRequest) (*mciasv1.SetRolesResponse, error) {
|
||||
if err := a.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.Id == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "id is required")
|
||||
}
|
||||
acct, err := a.s.db.GetAccountByUUID(req.Id)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return nil, status.Error(codes.NotFound, "account not found")
|
||||
}
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
actorClaims := claimsFromContext(ctx)
|
||||
var grantedBy *int64
|
||||
if actorClaims != nil {
|
||||
if actor, err := a.s.db.GetAccountByUUID(actorClaims.Subject); err == nil {
|
||||
grantedBy = &actor.ID
|
||||
}
|
||||
}
|
||||
|
||||
if err := a.s.db.SetRoles(acct.ID, req.Roles, grantedBy); err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
a.s.db.WriteAuditEvent(model.EventRoleGranted, grantedBy, &acct.ID, peerIP(ctx), //nolint:errcheck
|
||||
fmt.Sprintf(`{"roles":%v}`, req.Roles))
|
||||
return &mciasv1.SetRolesResponse{}, nil
|
||||
}
|
||||
41
internal/grpcserver/admin.go
Normal file
41
internal/grpcserver/admin.go
Normal file
@@ -0,0 +1,41 @@
|
||||
// adminServiceServer implements mciasv1.AdminServiceServer.
|
||||
// Health and GetPublicKey are public RPCs that bypass auth.
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
|
||||
)
|
||||
|
||||
type adminServiceServer struct {
|
||||
mciasv1.UnimplementedAdminServiceServer
|
||||
s *Server
|
||||
}
|
||||
|
||||
// Health returns {"status":"ok"} to signal the server is operational.
|
||||
func (a *adminServiceServer) Health(_ context.Context, _ *mciasv1.HealthRequest) (*mciasv1.HealthResponse, error) {
|
||||
return &mciasv1.HealthResponse{Status: "ok"}, nil
|
||||
}
|
||||
|
||||
// GetPublicKey returns the Ed25519 public key as JWK field values.
|
||||
// The "x" field is the raw 32-byte public key base64url-encoded without padding,
|
||||
// matching the REST /v1/keys/public response format.
|
||||
func (a *adminServiceServer) GetPublicKey(_ context.Context, _ *mciasv1.GetPublicKeyRequest) (*mciasv1.GetPublicKeyResponse, error) {
|
||||
if len(a.s.pubKey) == 0 {
|
||||
return nil, status.Error(codes.Internal, "public key not available")
|
||||
}
|
||||
// Encode as base64url without padding — identical to the REST handler.
|
||||
x := base64.RawURLEncoding.EncodeToString(a.s.pubKey)
|
||||
return &mciasv1.GetPublicKeyResponse{
|
||||
Kty: "OKP",
|
||||
Crv: "Ed25519",
|
||||
Use: "sig",
|
||||
Alg: "EdDSA",
|
||||
X: x,
|
||||
}, nil
|
||||
}
|
||||
264
internal/grpcserver/auth.go
Normal file
264
internal/grpcserver/auth.go
Normal file
@@ -0,0 +1,264 @@
|
||||
// authServiceServer implements mciasv1.AuthServiceServer.
|
||||
// All handlers delegate to the same internal packages as the REST server.
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/token"
|
||||
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
|
||||
)
|
||||
|
||||
type authServiceServer struct {
|
||||
mciasv1.UnimplementedAuthServiceServer
|
||||
s *Server
|
||||
}
|
||||
|
||||
// Login authenticates a user and issues a JWT.
|
||||
// Public RPC — no auth interceptor required.
|
||||
//
|
||||
// Security: Identical to the REST handleLogin: always runs Argon2 for unknown
|
||||
// users to prevent timing-based user enumeration. Generic error returned
|
||||
// regardless of which step failed.
|
||||
func (a *authServiceServer) Login(ctx context.Context, req *mciasv1.LoginRequest) (*mciasv1.LoginResponse, error) {
|
||||
if req.Username == "" || req.Password == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "username and password are required")
|
||||
}
|
||||
|
||||
ip := peerIP(ctx)
|
||||
|
||||
acct, err := a.s.db.GetAccountByUsername(req.Username)
|
||||
if err != nil {
|
||||
// Security: run dummy Argon2 to equalise timing for unknown users.
|
||||
_, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g")
|
||||
a.s.db.WriteAuditEvent(model.EventLoginFail, nil, nil, ip, //nolint:errcheck // audit failure is non-fatal
|
||||
fmt.Sprintf(`{"username":%q,"reason":"unknown_user"}`, req.Username))
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
|
||||
}
|
||||
|
||||
if acct.Status != model.AccountStatusActive {
|
||||
_, _ = auth.VerifyPassword("dummy", "$argon2id$v=19$m=65536,t=3,p=4$dGVzdHNhbHQ$dGVzdGhhc2g")
|
||||
a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"account_inactive"}`) //nolint:errcheck
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
|
||||
}
|
||||
|
||||
ok, err := auth.VerifyPassword(req.Password, acct.PasswordHash)
|
||||
if err != nil || !ok {
|
||||
a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"wrong_password"}`) //nolint:errcheck
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
|
||||
}
|
||||
|
||||
if acct.TOTPRequired {
|
||||
if req.TotpCode == "" {
|
||||
a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"totp_missing"}`) //nolint:errcheck
|
||||
return nil, status.Error(codes.Unauthenticated, "TOTP code required")
|
||||
}
|
||||
secret, err := crypto.OpenAESGCM(a.s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
|
||||
if err != nil {
|
||||
a.s.logger.Error("decrypt TOTP secret", "error", err, "account_id", acct.ID)
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
valid, err := auth.ValidateTOTP(secret, req.TotpCode)
|
||||
if err != nil || !valid {
|
||||
a.s.db.WriteAuditEvent(model.EventLoginTOTPFail, &acct.ID, nil, ip, `{"reason":"wrong_totp"}`) //nolint:errcheck
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
|
||||
}
|
||||
}
|
||||
|
||||
expiry := a.s.cfg.DefaultExpiry()
|
||||
roles, err := a.s.db.GetRoles(acct.ID)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
for _, r := range roles {
|
||||
if r == "admin" {
|
||||
expiry = a.s.cfg.AdminExpiry()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
tokenStr, claims, err := token.IssueToken(a.s.privKey, a.s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
|
||||
if err != nil {
|
||||
a.s.logger.Error("issue token", "error", err)
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
if err := a.s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||
a.s.logger.Error("track token", "error", err)
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
a.s.db.WriteAuditEvent(model.EventLoginOK, &acct.ID, nil, ip, "") //nolint:errcheck
|
||||
a.s.db.WriteAuditEvent(model.EventTokenIssued, &acct.ID, nil, ip, //nolint:errcheck
|
||||
fmt.Sprintf(`{"jti":%q}`, claims.JTI))
|
||||
|
||||
return &mciasv1.LoginResponse{
|
||||
Token: tokenStr,
|
||||
ExpiresAt: timestamppb.New(claims.ExpiresAt),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Logout revokes the caller's current JWT.
|
||||
func (a *authServiceServer) Logout(ctx context.Context, _ *mciasv1.LogoutRequest) (*mciasv1.LogoutResponse, error) {
|
||||
claims := claimsFromContext(ctx)
|
||||
if err := a.s.db.RevokeToken(claims.JTI, "logout"); err != nil {
|
||||
a.s.logger.Error("revoke token on logout", "error", err)
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
a.s.db.WriteAuditEvent(model.EventTokenRevoked, nil, nil, peerIP(ctx), //nolint:errcheck
|
||||
fmt.Sprintf(`{"jti":%q,"reason":"logout"}`, claims.JTI))
|
||||
return &mciasv1.LogoutResponse{}, nil
|
||||
}
|
||||
|
||||
// RenewToken exchanges the caller's token for a new one.
|
||||
func (a *authServiceServer) RenewToken(ctx context.Context, _ *mciasv1.RenewTokenRequest) (*mciasv1.RenewTokenResponse, error) {
|
||||
claims := claimsFromContext(ctx)
|
||||
|
||||
acct, err := a.s.db.GetAccountByUUID(claims.Subject)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "account not found")
|
||||
}
|
||||
if acct.Status != model.AccountStatusActive {
|
||||
return nil, status.Error(codes.Unauthenticated, "account inactive")
|
||||
}
|
||||
|
||||
roles, err := a.s.db.GetRoles(acct.ID)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
expiry := a.s.cfg.DefaultExpiry()
|
||||
for _, r := range roles {
|
||||
if r == "admin" {
|
||||
expiry = a.s.cfg.AdminExpiry()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
newTokenStr, newClaims, err := token.IssueToken(a.s.privKey, a.s.cfg.Tokens.Issuer, acct.UUID, roles, expiry)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
_ = a.s.db.RevokeToken(claims.JTI, "renewed")
|
||||
if err := a.s.db.TrackToken(newClaims.JTI, acct.ID, newClaims.IssuedAt, newClaims.ExpiresAt); err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
a.s.db.WriteAuditEvent(model.EventTokenRenewed, &acct.ID, nil, peerIP(ctx), //nolint:errcheck
|
||||
fmt.Sprintf(`{"old_jti":%q,"new_jti":%q}`, claims.JTI, newClaims.JTI))
|
||||
|
||||
return &mciasv1.RenewTokenResponse{
|
||||
Token: newTokenStr,
|
||||
ExpiresAt: timestamppb.New(newClaims.ExpiresAt),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EnrollTOTP begins TOTP enrollment for the calling account.
|
||||
func (a *authServiceServer) EnrollTOTP(ctx context.Context, _ *mciasv1.EnrollTOTPRequest) (*mciasv1.EnrollTOTPResponse, error) {
|
||||
claims := claimsFromContext(ctx)
|
||||
acct, err := a.s.db.GetAccountByUUID(claims.Subject)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "account not found")
|
||||
}
|
||||
|
||||
rawSecret, b32Secret, err := auth.GenerateTOTPSecret()
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
secretEnc, secretNonce, err := crypto.SealAESGCM(a.s.masterKey, rawSecret)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
if err := a.s.db.SetTOTP(acct.ID, secretEnc, secretNonce); err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
otpURI := fmt.Sprintf("otpauth://totp/MCIAS:%s?secret=%s&issuer=MCIAS", acct.Username, b32Secret)
|
||||
|
||||
// Security: secret is shown once here only; the stored form is encrypted.
|
||||
return &mciasv1.EnrollTOTPResponse{
|
||||
Secret: b32Secret,
|
||||
OtpauthUri: otpURI,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ConfirmTOTP confirms TOTP enrollment.
|
||||
func (a *authServiceServer) ConfirmTOTP(ctx context.Context, req *mciasv1.ConfirmTOTPRequest) (*mciasv1.ConfirmTOTPResponse, error) {
|
||||
if req.Code == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "code is required")
|
||||
}
|
||||
|
||||
claims := claimsFromContext(ctx)
|
||||
acct, err := a.s.db.GetAccountByUUID(claims.Subject)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "account not found")
|
||||
}
|
||||
if acct.TOTPSecretEnc == nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "TOTP enrollment not started")
|
||||
}
|
||||
|
||||
secret, err := crypto.OpenAESGCM(a.s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
valid, err := auth.ValidateTOTP(secret, req.Code)
|
||||
if err != nil || !valid {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid TOTP code")
|
||||
}
|
||||
|
||||
// SetTOTP with existing enc/nonce sets totp_required=1, confirming enrollment.
|
||||
if err := a.s.db.SetTOTP(acct.ID, acct.TOTPSecretEnc, acct.TOTPSecretNonce); err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
a.s.db.WriteAuditEvent(model.EventTOTPEnrolled, &acct.ID, nil, peerIP(ctx), "") //nolint:errcheck
|
||||
return &mciasv1.ConfirmTOTPResponse{}, nil
|
||||
}
|
||||
|
||||
// RemoveTOTP removes TOTP from an account. Admin only.
|
||||
func (a *authServiceServer) RemoveTOTP(ctx context.Context, req *mciasv1.RemoveTOTPRequest) (*mciasv1.RemoveTOTPResponse, error) {
|
||||
if err := a.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.AccountId == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "account_id is required")
|
||||
}
|
||||
|
||||
acct, err := a.s.db.GetAccountByUUID(req.AccountId)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
if err := a.s.db.ClearTOTP(acct.ID); err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
a.s.db.WriteAuditEvent(model.EventTOTPRemoved, nil, &acct.ID, peerIP(ctx), "") //nolint:errcheck
|
||||
return &mciasv1.RemoveTOTPResponse{}, nil
|
||||
}
|
||||
|
||||
// peerIP extracts the client IP from gRPC peer context.
|
||||
func peerIP(ctx context.Context) string {
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
host, _, err := net.SplitHostPort(p.Addr.String())
|
||||
if err != nil {
|
||||
return p.Addr.String()
|
||||
}
|
||||
return host
|
||||
}
|
||||
107
internal/grpcserver/credentialservice.go
Normal file
107
internal/grpcserver/credentialservice.go
Normal file
@@ -0,0 +1,107 @@
|
||||
// credentialServiceServer implements mciasv1.CredentialServiceServer.
|
||||
// All RPCs require admin role.
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
|
||||
)
|
||||
|
||||
type credentialServiceServer struct {
|
||||
mciasv1.UnimplementedCredentialServiceServer
|
||||
s *Server
|
||||
}
|
||||
|
||||
// GetPGCreds decrypts and returns Postgres credentials. Admin only.
|
||||
// Security: the password field is decrypted and returned; this constitutes
|
||||
// a sensitive operation. The audit log records the access.
|
||||
func (c *credentialServiceServer) GetPGCreds(ctx context.Context, req *mciasv1.GetPGCredsRequest) (*mciasv1.GetPGCredsResponse, error) {
|
||||
if err := c.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.Id == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "id is required")
|
||||
}
|
||||
acct, err := c.s.db.GetAccountByUUID(req.Id)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return nil, status.Error(codes.NotFound, "account not found")
|
||||
}
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
cred, err := c.s.db.ReadPGCredentials(acct.ID)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return nil, status.Error(codes.NotFound, "no credentials stored")
|
||||
}
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
// Decrypt the password for admin retrieval.
|
||||
password, err := crypto.OpenAESGCM(c.s.masterKey, cred.PGPasswordNonce, cred.PGPasswordEnc)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
c.s.db.WriteAuditEvent(model.EventPGCredAccessed, nil, &acct.ID, peerIP(ctx), "") //nolint:errcheck
|
||||
|
||||
return &mciasv1.GetPGCredsResponse{
|
||||
Creds: &mciasv1.PGCreds{
|
||||
Host: cred.PGHost,
|
||||
Database: cred.PGDatabase,
|
||||
Username: cred.PGUsername,
|
||||
Password: string(password), // security: returned only on explicit admin request
|
||||
Port: int32(cred.PGPort),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetPGCreds stores Postgres credentials for an account. Admin only.
|
||||
func (c *credentialServiceServer) SetPGCreds(ctx context.Context, req *mciasv1.SetPGCredsRequest) (*mciasv1.SetPGCredsResponse, error) {
|
||||
if err := c.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.Id == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "id is required")
|
||||
}
|
||||
if req.Creds == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "creds is required")
|
||||
}
|
||||
|
||||
cr := req.Creds
|
||||
if cr.Host == "" || cr.Database == "" || cr.Username == "" || cr.Password == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "host, database, username, and password are required")
|
||||
}
|
||||
port := int(cr.Port)
|
||||
if port == 0 {
|
||||
port = 5432
|
||||
}
|
||||
|
||||
acct, err := c.s.db.GetAccountByUUID(req.Id)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return nil, status.Error(codes.NotFound, "account not found")
|
||||
}
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
enc, nonce, err := crypto.SealAESGCM(c.s.masterKey, []byte(cr.Password))
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
if err := c.s.db.WritePGCredentials(acct.ID, cr.Host, port, cr.Database, cr.Username, enc, nonce); err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
c.s.db.WriteAuditEvent(model.EventPGCredUpdated, nil, &acct.ID, peerIP(ctx), "") //nolint:errcheck
|
||||
return &mciasv1.SetPGCredsResponse{}, nil
|
||||
}
|
||||
345
internal/grpcserver/grpcserver.go
Normal file
345
internal/grpcserver/grpcserver.go
Normal file
@@ -0,0 +1,345 @@
|
||||
// Package grpcserver provides a gRPC server that exposes the same
|
||||
// functionality as the REST HTTP server using the same internal packages.
|
||||
//
|
||||
// Security design:
|
||||
// - All RPCs share business logic with the REST server via internal/auth,
|
||||
// internal/token, internal/db, and internal/crypto packages.
|
||||
// - Authentication uses the same JWT validation path as the REST middleware:
|
||||
// alg-first check, signature verification, revocation table lookup.
|
||||
// - The authorization metadata key is "authorization"; its value must be
|
||||
// "Bearer <token>" (case-insensitive prefix check).
|
||||
// - Credential fields (PasswordHash, TOTPSecret*, PGPassword) are never
|
||||
// included in any RPC response message.
|
||||
// - No credential material is logged by any interceptor.
|
||||
// - TLS is required at the listener level (enforced by cmd/mciassrv).
|
||||
// - Public RPCs (Health, GetPublicKey, ValidateToken) bypass auth.
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"log/slog"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/config"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/token"
|
||||
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
|
||||
)
|
||||
|
||||
// contextKey is the unexported context key type for this package.
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
claimsCtxKey contextKey = iota
|
||||
)
|
||||
|
||||
// claimsFromContext retrieves JWT claims injected by the auth interceptor.
|
||||
// Returns nil for unauthenticated (public) RPCs.
|
||||
func claimsFromContext(ctx context.Context) *token.Claims {
|
||||
c, _ := ctx.Value(claimsCtxKey).(*token.Claims)
|
||||
return c
|
||||
}
|
||||
|
||||
// Server holds the shared state for all gRPC service implementations.
|
||||
type Server struct {
|
||||
db *db.DB
|
||||
cfg *config.Config
|
||||
privKey ed25519.PrivateKey
|
||||
pubKey ed25519.PublicKey
|
||||
masterKey []byte
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a Server with the given dependencies (same as the REST Server).
|
||||
func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed25519.PublicKey, masterKey []byte, logger *slog.Logger) *Server {
|
||||
return &Server{
|
||||
db: database,
|
||||
cfg: cfg,
|
||||
privKey: priv,
|
||||
pubKey: pub,
|
||||
masterKey: masterKey,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// publicMethods is the set of fully-qualified method names that bypass auth.
|
||||
// These match the gRPC full method path: /<package>.<Service>/<Method>.
|
||||
var publicMethods = map[string]bool{
|
||||
"/mcias.v1.AdminService/Health": true,
|
||||
"/mcias.v1.AdminService/GetPublicKey": true,
|
||||
"/mcias.v1.TokenService/ValidateToken": true,
|
||||
"/mcias.v1.AuthService/Login": true,
|
||||
}
|
||||
|
||||
// GRPCServer builds and returns a configured *grpc.Server with all services
|
||||
// registered and the interceptor chain installed. The returned server uses no
|
||||
// transport credentials; callers are expected to wrap with TLS via GRPCServerWithCreds.
|
||||
func (s *Server) GRPCServer() *grpc.Server {
|
||||
return s.buildServer()
|
||||
}
|
||||
|
||||
// GRPCServerWithCreds builds a *grpc.Server with TLS transport credentials.
|
||||
// This is the method to use when starting a TLS gRPC listener; TLS credentials
|
||||
// must be passed at server-construction time per the gRPC idiom.
|
||||
func (s *Server) GRPCServerWithCreds(creds credentials.TransportCredentials) *grpc.Server {
|
||||
return s.buildServer(grpc.Creds(creds))
|
||||
}
|
||||
|
||||
// buildServer constructs the grpc.Server with optional additional server options.
|
||||
func (s *Server) buildServer(extra ...grpc.ServerOption) *grpc.Server {
|
||||
opts := append(
|
||||
[]grpc.ServerOption{
|
||||
grpc.ChainUnaryInterceptor(
|
||||
s.loggingInterceptor,
|
||||
s.authInterceptor,
|
||||
s.rateLimitInterceptor,
|
||||
),
|
||||
},
|
||||
extra...,
|
||||
)
|
||||
srv := grpc.NewServer(opts...)
|
||||
|
||||
// Register service implementations.
|
||||
mciasv1.RegisterAdminServiceServer(srv, &adminServiceServer{s: s})
|
||||
mciasv1.RegisterAuthServiceServer(srv, &authServiceServer{s: s})
|
||||
mciasv1.RegisterTokenServiceServer(srv, &tokenServiceServer{s: s})
|
||||
mciasv1.RegisterAccountServiceServer(srv, &accountServiceServer{s: s})
|
||||
mciasv1.RegisterCredentialServiceServer(srv, &credentialServiceServer{s: s})
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
// loggingInterceptor logs each unary RPC call with method, peer IP, status,
|
||||
// and duration. The authorization metadata value is never logged.
|
||||
func (s *Server) loggingInterceptor(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
start := time.Now()
|
||||
|
||||
peerIP := ""
|
||||
if p, ok := peer.FromContext(ctx); ok {
|
||||
host, _, err := net.SplitHostPort(p.Addr.String())
|
||||
if err == nil {
|
||||
peerIP = host
|
||||
} else {
|
||||
peerIP = p.Addr.String()
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := handler(ctx, req)
|
||||
|
||||
code := codes.OK
|
||||
if err != nil {
|
||||
code = status.Code(err)
|
||||
}
|
||||
|
||||
// Security: authorization metadata is never logged.
|
||||
s.logger.Info("grpc request",
|
||||
"method", info.FullMethod,
|
||||
"peer_ip", peerIP,
|
||||
"code", code.String(),
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// authInterceptor validates the Bearer JWT from gRPC metadata and injects
|
||||
// claims into the context. Public methods bypass this check.
|
||||
//
|
||||
// Security: Same validation path as the REST RequireAuth middleware:
|
||||
// 1. Extract "authorization" metadata value (case-insensitive key lookup).
|
||||
// 2. Validate JWT (alg-first, then signature, then expiry/issuer).
|
||||
// 3. Check JTI against revocation table.
|
||||
// 4. Inject claims into context.
|
||||
func (s *Server) authInterceptor(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
if publicMethods[info.FullMethod] {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
tokenStr, err := extractBearerFromMD(ctx)
|
||||
if err != nil {
|
||||
// Security: do not reveal whether the header was missing vs. malformed.
|
||||
return nil, status.Error(codes.Unauthenticated, "missing or invalid authorization")
|
||||
}
|
||||
|
||||
claims, err := token.ValidateToken(s.pubKey, tokenStr, s.cfg.Tokens.Issuer)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid or expired token")
|
||||
}
|
||||
|
||||
// Security: check revocation table after signature validation.
|
||||
rec, err := s.db.GetTokenRecord(claims.JTI)
|
||||
if err != nil || rec.IsRevoked() {
|
||||
return nil, status.Error(codes.Unauthenticated, "token has been revoked")
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, claimsCtxKey, claims)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
// requireAdmin checks that the claims in context contain the "admin" role.
|
||||
// Called by admin-only RPC handlers after the authInterceptor has run.
|
||||
//
|
||||
// Security: Mirrors the REST RequireRole("admin") middleware check; checked
|
||||
// after auth so claims are always populated when this function is reached.
|
||||
func (s *Server) requireAdmin(ctx context.Context) error {
|
||||
claims := claimsFromContext(ctx)
|
||||
if claims == nil {
|
||||
return status.Error(codes.PermissionDenied, "insufficient privileges")
|
||||
}
|
||||
if !claims.HasRole("admin") {
|
||||
return status.Error(codes.PermissionDenied, "insufficient privileges")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Rate limiter ---
|
||||
|
||||
// grpcRateLimiter is a per-IP token bucket for gRPC, sharing the same
|
||||
// algorithm as the REST RateLimit middleware.
|
||||
type grpcRateLimiter struct {
|
||||
mu sync.Mutex
|
||||
ips map[string]*grpcRateLimitEntry
|
||||
rps float64
|
||||
burst float64
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
type grpcRateLimitEntry struct {
|
||||
mu sync.Mutex
|
||||
lastSeen time.Time
|
||||
tokens float64
|
||||
}
|
||||
|
||||
func newGRPCRateLimiter(rps float64, burst int) *grpcRateLimiter {
|
||||
l := &grpcRateLimiter{
|
||||
rps: rps,
|
||||
burst: float64(burst),
|
||||
ttl: 10 * time.Minute,
|
||||
ips: make(map[string]*grpcRateLimitEntry),
|
||||
}
|
||||
go l.cleanup()
|
||||
return l
|
||||
}
|
||||
|
||||
func (l *grpcRateLimiter) allow(ip string) bool {
|
||||
l.mu.Lock()
|
||||
entry, ok := l.ips[ip]
|
||||
if !ok {
|
||||
entry = &grpcRateLimitEntry{tokens: l.burst, lastSeen: time.Now()}
|
||||
l.ips[ip] = entry
|
||||
}
|
||||
l.mu.Unlock()
|
||||
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(entry.lastSeen).Seconds()
|
||||
entry.tokens = minFloat64(l.burst, entry.tokens+elapsed*l.rps)
|
||||
entry.lastSeen = now
|
||||
|
||||
if entry.tokens < 1 {
|
||||
return false
|
||||
}
|
||||
entry.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *grpcRateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
l.mu.Lock()
|
||||
cutoff := time.Now().Add(-l.ttl)
|
||||
for ip, entry := range l.ips {
|
||||
entry.mu.Lock()
|
||||
if entry.lastSeen.Before(cutoff) {
|
||||
delete(l.ips, ip)
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
l.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// defaultRateLimiter is the server-wide rate limiter instance.
|
||||
// 10 req/s sustained, burst 10 — same parameters as the REST limiter.
|
||||
var defaultRateLimiter = newGRPCRateLimiter(10, 10)
|
||||
|
||||
// rateLimitInterceptor applies per-IP rate limiting using the same token-bucket
|
||||
// parameters as the REST rate limiter (10 req/s, burst 10).
|
||||
func (s *Server) rateLimitInterceptor(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
ip := ""
|
||||
if p, ok := peer.FromContext(ctx); ok {
|
||||
host, _, err := net.SplitHostPort(p.Addr.String())
|
||||
if err == nil {
|
||||
ip = host
|
||||
} else {
|
||||
ip = p.Addr.String()
|
||||
}
|
||||
}
|
||||
|
||||
if ip != "" && !defaultRateLimiter.allow(ip) {
|
||||
return nil, status.Error(codes.ResourceExhausted, "rate limit exceeded")
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
// extractBearerFromMD extracts the Bearer token from gRPC metadata.
|
||||
// The key lookup is case-insensitive per gRPC metadata convention (all keys
|
||||
// are lowercased by the framework; we match on "authorization").
|
||||
//
|
||||
// Security: The metadata value is never logged.
|
||||
func extractBearerFromMD(ctx context.Context) (string, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return "", status.Error(codes.Unauthenticated, "no metadata")
|
||||
}
|
||||
vals := md.Get("authorization")
|
||||
if len(vals) == 0 {
|
||||
return "", status.Error(codes.Unauthenticated, "missing authorization metadata")
|
||||
}
|
||||
auth := vals[0]
|
||||
const prefix = "bearer "
|
||||
if len(auth) <= len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
|
||||
return "", status.Error(codes.Unauthenticated, "malformed authorization metadata")
|
||||
}
|
||||
t := auth[len(prefix):]
|
||||
if t == "" {
|
||||
return "", status.Error(codes.Unauthenticated, "empty bearer token")
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// minFloat64 returns the smaller of two float64 values.
|
||||
func minFloat64(a, b float64) float64 {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
654
internal/grpcserver/grpcserver_test.go
Normal file
654
internal/grpcserver/grpcserver_test.go
Normal file
@@ -0,0 +1,654 @@
|
||||
// Tests for the gRPC server package.
|
||||
//
|
||||
// All tests use bufconn so no network sockets are opened. TLS is omitted
|
||||
// at the test layer (insecure credentials); TLS enforcement is the responsibility
|
||||
// of cmd/mciassrv which wraps the listener.
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/config"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/token"
|
||||
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
testIssuer = "https://auth.example.com"
|
||||
bufConnSize = 1024 * 1024
|
||||
)
|
||||
|
||||
// testEnv holds all resources for a single test's gRPC server.
|
||||
type testEnv struct {
|
||||
db *db.DB
|
||||
priv ed25519.PrivateKey
|
||||
pub ed25519.PublicKey
|
||||
masterKey []byte
|
||||
cfg *config.Config
|
||||
conn *grpc.ClientConn
|
||||
}
|
||||
|
||||
// newTestEnv spins up an in-process gRPC server using bufconn and returns
|
||||
// a client connection to it. All resources are cleaned up via t.Cleanup.
|
||||
func newTestEnv(t *testing.T) *testEnv {
|
||||
t.Helper()
|
||||
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generate key: %v", err)
|
||||
}
|
||||
|
||||
database, err := db.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
if err := db.Migrate(database); err != nil {
|
||||
t.Fatalf("migrate db: %v", err)
|
||||
}
|
||||
|
||||
masterKey := make([]byte, 32)
|
||||
if _, err := rand.Read(masterKey); err != nil {
|
||||
t.Fatalf("generate master key: %v", err)
|
||||
}
|
||||
|
||||
cfg := config.NewTestConfig(testIssuer)
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
|
||||
srv := New(database, cfg, priv, pub, masterKey, logger)
|
||||
grpcSrv := srv.GRPCServer()
|
||||
|
||||
lis := bufconn.Listen(bufConnSize)
|
||||
go func() {
|
||||
if err := grpcSrv.Serve(lis); err != nil {
|
||||
// Serve returns when the listener is closed; ignore that error.
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := grpc.NewClient(
|
||||
"passthrough://bufnet",
|
||||
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
|
||||
return lis.DialContext(ctx)
|
||||
}),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("dial bufconn: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = conn.Close()
|
||||
grpcSrv.Stop()
|
||||
_ = lis.Close()
|
||||
_ = database.Close()
|
||||
})
|
||||
|
||||
return &testEnv{
|
||||
db: database,
|
||||
priv: priv,
|
||||
pub: pub,
|
||||
masterKey: masterKey,
|
||||
cfg: cfg,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// createHumanAccount creates a human account with the given username and
|
||||
// a fixed password "testpass123" directly in the database.
|
||||
func (e *testEnv) createHumanAccount(t *testing.T, username string) *model.Account {
|
||||
t.Helper()
|
||||
hash, err := auth.HashPassword("testpass123", auth.ArgonParams{Time: 3, Memory: 65536, Threads: 4})
|
||||
if err != nil {
|
||||
t.Fatalf("hash password: %v", err)
|
||||
}
|
||||
acct, err := e.db.CreateAccount(username, model.AccountTypeHuman, hash)
|
||||
if err != nil {
|
||||
t.Fatalf("create account: %v", err)
|
||||
}
|
||||
return acct
|
||||
}
|
||||
|
||||
// issueAdminToken creates an account with admin role, issues a JWT, tracks it in
|
||||
// the DB, and returns the token string.
|
||||
func (e *testEnv) issueAdminToken(t *testing.T, username string) (string, *model.Account) {
|
||||
t.Helper()
|
||||
acct := e.createHumanAccount(t, username)
|
||||
if err := e.db.GrantRole(acct.ID, "admin", nil); err != nil {
|
||||
t.Fatalf("grant admin role: %v", err)
|
||||
}
|
||||
tokenStr, claims, err := token.IssueToken(e.priv, testIssuer, acct.UUID, []string{"admin"}, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("issue token: %v", err)
|
||||
}
|
||||
if err := e.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||
t.Fatalf("track token: %v", err)
|
||||
}
|
||||
return tokenStr, acct
|
||||
}
|
||||
|
||||
// issueUserToken issues a regular (non-admin) token for an account.
|
||||
func (e *testEnv) issueUserToken(t *testing.T, acct *model.Account) string {
|
||||
t.Helper()
|
||||
tokenStr, claims, err := token.IssueToken(e.priv, testIssuer, acct.UUID, []string{}, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("issue token: %v", err)
|
||||
}
|
||||
if err := e.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||
t.Fatalf("track token: %v", err)
|
||||
}
|
||||
return tokenStr
|
||||
}
|
||||
|
||||
// authCtx returns a context with the Bearer token as gRPC metadata.
|
||||
func authCtx(tok string) context.Context {
|
||||
return metadata.AppendToOutgoingContext(context.Background(),
|
||||
"authorization", "Bearer "+tok)
|
||||
}
|
||||
|
||||
// ---- AdminService tests ----
|
||||
|
||||
// TestHealth verifies the public Health RPC requires no auth and returns "ok".
|
||||
func TestHealth(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
cl := mciasv1.NewAdminServiceClient(e.conn)
|
||||
|
||||
resp, err := cl.Health(context.Background(), &mciasv1.HealthRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("Health: %v", err)
|
||||
}
|
||||
if resp.Status != "ok" {
|
||||
t.Errorf("Health: got status %q, want %q", resp.Status, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetPublicKey verifies the public GetPublicKey RPC returns JWK fields.
|
||||
func TestGetPublicKey(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
cl := mciasv1.NewAdminServiceClient(e.conn)
|
||||
|
||||
resp, err := cl.GetPublicKey(context.Background(), &mciasv1.GetPublicKeyRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetPublicKey: %v", err)
|
||||
}
|
||||
if resp.Kty != "OKP" {
|
||||
t.Errorf("GetPublicKey: kty=%q, want OKP", resp.Kty)
|
||||
}
|
||||
if resp.Crv != "Ed25519" {
|
||||
t.Errorf("GetPublicKey: crv=%q, want Ed25519", resp.Crv)
|
||||
}
|
||||
if resp.X == "" {
|
||||
t.Error("GetPublicKey: x field is empty")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Auth interceptor tests ----
|
||||
|
||||
// TestAuthRequired verifies that protected RPCs reject calls with no token.
|
||||
func TestAuthRequired(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
cl := mciasv1.NewAuthServiceClient(e.conn)
|
||||
|
||||
// Logout requires auth; call without any metadata.
|
||||
_, err := cl.Logout(context.Background(), &mciasv1.LogoutRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("Logout without token: expected error, got nil")
|
||||
}
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("not a gRPC status error: %v", err)
|
||||
}
|
||||
if st.Code() != codes.Unauthenticated {
|
||||
t.Errorf("Logout without token: got code %v, want Unauthenticated", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidTokenRejected verifies that a malformed token is rejected.
|
||||
func TestInvalidTokenRejected(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
cl := mciasv1.NewAuthServiceClient(e.conn)
|
||||
|
||||
ctx := authCtx("not.a.valid.jwt")
|
||||
_, err := cl.Logout(ctx, &mciasv1.LogoutRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("Logout with invalid token: expected error, got nil")
|
||||
}
|
||||
st, _ := status.FromError(err)
|
||||
if st.Code() != codes.Unauthenticated {
|
||||
t.Errorf("got code %v, want Unauthenticated", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRevokedTokenRejected verifies that a revoked token cannot be used.
|
||||
func TestRevokedTokenRejected(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
|
||||
acct := e.createHumanAccount(t, "revokeduser")
|
||||
tokenStr, claims, err := token.IssueToken(e.priv, testIssuer, acct.UUID, []string{}, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("issue token: %v", err)
|
||||
}
|
||||
if err := e.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||
t.Fatalf("track token: %v", err)
|
||||
}
|
||||
// Revoke it before using it.
|
||||
if err := e.db.RevokeToken(claims.JTI, "test"); err != nil {
|
||||
t.Fatalf("revoke token: %v", err)
|
||||
}
|
||||
|
||||
cl := mciasv1.NewAuthServiceClient(e.conn)
|
||||
ctx := authCtx(tokenStr)
|
||||
_, err = cl.Logout(ctx, &mciasv1.LogoutRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("Logout with revoked token: expected error, got nil")
|
||||
}
|
||||
st, _ := status.FromError(err)
|
||||
if st.Code() != codes.Unauthenticated {
|
||||
t.Errorf("got code %v, want Unauthenticated", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
// TestNonAdminCannotCallAdminRPC verifies that a regular user is denied access
|
||||
// to admin-only RPCs (PermissionDenied, not Unauthenticated).
|
||||
func TestNonAdminCannotCallAdminRPC(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
acct := e.createHumanAccount(t, "regularuser")
|
||||
tok := e.issueUserToken(t, acct)
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(e.conn)
|
||||
ctx := authCtx(tok)
|
||||
_, err := cl.ListAccounts(ctx, &mciasv1.ListAccountsRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("ListAccounts as non-admin: expected error, got nil")
|
||||
}
|
||||
st, _ := status.FromError(err)
|
||||
if st.Code() != codes.PermissionDenied {
|
||||
t.Errorf("got code %v, want PermissionDenied", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
// ---- AuthService tests ----
|
||||
|
||||
// TestLogin verifies successful login via gRPC.
|
||||
func TestLogin(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
_ = e.createHumanAccount(t, "loginuser")
|
||||
|
||||
cl := mciasv1.NewAuthServiceClient(e.conn)
|
||||
resp, err := cl.Login(context.Background(), &mciasv1.LoginRequest{
|
||||
Username: "loginuser",
|
||||
Password: "testpass123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Login: %v", err)
|
||||
}
|
||||
if resp.Token == "" {
|
||||
t.Error("Login: returned empty token")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoginWrongPassword verifies that wrong-password returns Unauthenticated.
|
||||
func TestLoginWrongPassword(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
_ = e.createHumanAccount(t, "loginuser2")
|
||||
|
||||
cl := mciasv1.NewAuthServiceClient(e.conn)
|
||||
_, err := cl.Login(context.Background(), &mciasv1.LoginRequest{
|
||||
Username: "loginuser2",
|
||||
Password: "wrongpassword",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Login with wrong password: expected error, got nil")
|
||||
}
|
||||
st, _ := status.FromError(err)
|
||||
if st.Code() != codes.Unauthenticated {
|
||||
t.Errorf("got code %v, want Unauthenticated", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoginUnknownUser verifies that unknown-user returns the same error as
|
||||
// wrong-password (prevents user enumeration).
|
||||
func TestLoginUnknownUser(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
cl := mciasv1.NewAuthServiceClient(e.conn)
|
||||
_, err := cl.Login(context.Background(), &mciasv1.LoginRequest{
|
||||
Username: "nosuchuser",
|
||||
Password: "whatever",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Login for unknown user: expected error, got nil")
|
||||
}
|
||||
st, _ := status.FromError(err)
|
||||
if st.Code() != codes.Unauthenticated {
|
||||
t.Errorf("got code %v, want Unauthenticated", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogout verifies that a valid token can log itself out.
|
||||
func TestLogout(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
acct := e.createHumanAccount(t, "logoutuser")
|
||||
tok := e.issueUserToken(t, acct)
|
||||
|
||||
cl := mciasv1.NewAuthServiceClient(e.conn)
|
||||
ctx := authCtx(tok)
|
||||
_, err := cl.Logout(ctx, &mciasv1.LogoutRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("Logout: %v", err)
|
||||
}
|
||||
|
||||
// Second call with the same token must fail (token now revoked).
|
||||
_, err = cl.Logout(authCtx(tok), &mciasv1.LogoutRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("second Logout with revoked token: expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRenewToken verifies that a valid token can be renewed.
|
||||
func TestRenewToken(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
acct := e.createHumanAccount(t, "renewuser")
|
||||
tok := e.issueUserToken(t, acct)
|
||||
|
||||
cl := mciasv1.NewAuthServiceClient(e.conn)
|
||||
ctx := authCtx(tok)
|
||||
resp, err := cl.RenewToken(ctx, &mciasv1.RenewTokenRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("RenewToken: %v", err)
|
||||
}
|
||||
if resp.Token == "" {
|
||||
t.Error("RenewToken: returned empty token")
|
||||
}
|
||||
if resp.Token == tok {
|
||||
t.Error("RenewToken: returned same token instead of a fresh one")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- TokenService tests ----
|
||||
|
||||
// TestValidateToken verifies the public ValidateToken RPC returns valid=true for
|
||||
// a good token and valid=false for a garbage input (no Unauthenticated error).
|
||||
func TestValidateToken(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
acct := e.createHumanAccount(t, "validateuser")
|
||||
tok := e.issueUserToken(t, acct)
|
||||
|
||||
cl := mciasv1.NewTokenServiceClient(e.conn)
|
||||
|
||||
// Valid token.
|
||||
resp, err := cl.ValidateToken(context.Background(), &mciasv1.ValidateTokenRequest{Token: tok})
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken (good): %v", err)
|
||||
}
|
||||
if !resp.Valid {
|
||||
t.Error("ValidateToken: got valid=false for a good token")
|
||||
}
|
||||
|
||||
// Invalid token: should return valid=false, not an RPC error.
|
||||
resp, err = cl.ValidateToken(context.Background(), &mciasv1.ValidateTokenRequest{Token: "garbage"})
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken (bad): unexpected RPC error: %v", err)
|
||||
}
|
||||
if resp.Valid {
|
||||
t.Error("ValidateToken: got valid=true for a garbage token")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIssueServiceTokenRequiresAdmin verifies that non-admin cannot issue tokens.
|
||||
func TestIssueServiceTokenRequiresAdmin(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
acct := e.createHumanAccount(t, "notadmin")
|
||||
tok := e.issueUserToken(t, acct)
|
||||
|
||||
cl := mciasv1.NewTokenServiceClient(e.conn)
|
||||
_, err := cl.IssueServiceToken(authCtx(tok), &mciasv1.IssueServiceTokenRequest{AccountId: acct.UUID})
|
||||
if err == nil {
|
||||
t.Fatal("IssueServiceToken as non-admin: expected error, got nil")
|
||||
}
|
||||
st, _ := status.FromError(err)
|
||||
if st.Code() != codes.PermissionDenied {
|
||||
t.Errorf("got code %v, want PermissionDenied", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
// ---- AccountService tests ----
|
||||
|
||||
// TestListAccountsAdminOnly verifies that ListAccounts requires admin role.
|
||||
func TestListAccountsAdminOnly(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
|
||||
// Non-admin call.
|
||||
acct := e.createHumanAccount(t, "nonadmin")
|
||||
tok := e.issueUserToken(t, acct)
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(e.conn)
|
||||
_, err := cl.ListAccounts(authCtx(tok), &mciasv1.ListAccountsRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("ListAccounts as non-admin: expected error, got nil")
|
||||
}
|
||||
st, _ := status.FromError(err)
|
||||
if st.Code() != codes.PermissionDenied {
|
||||
t.Errorf("got code %v, want PermissionDenied", st.Code())
|
||||
}
|
||||
|
||||
// Admin call.
|
||||
adminTok, _ := e.issueAdminToken(t, "adminuser")
|
||||
resp, err := cl.ListAccounts(authCtx(adminTok), &mciasv1.ListAccountsRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListAccounts as admin: %v", err)
|
||||
}
|
||||
if len(resp.Accounts) == 0 {
|
||||
t.Error("ListAccounts: expected at least one account")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateAndGetAccount exercises the full create→get lifecycle.
|
||||
func TestCreateAndGetAccount(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
adminTok, _ := e.issueAdminToken(t, "admin2")
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(e.conn)
|
||||
|
||||
createResp, err := cl.CreateAccount(authCtx(adminTok), &mciasv1.CreateAccountRequest{
|
||||
Username: "newuser",
|
||||
Password: "securepassword1",
|
||||
AccountType: "human",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
if createResp.Account == nil {
|
||||
t.Fatal("CreateAccount: returned nil account")
|
||||
}
|
||||
if createResp.Account.Id == "" {
|
||||
t.Error("CreateAccount: returned empty UUID")
|
||||
}
|
||||
|
||||
// Security: credential fields must not appear in the response.
|
||||
// The Account proto has no password_hash or totp_secret fields by design.
|
||||
// Verify via GetAccount too.
|
||||
getResp, err := cl.GetAccount(authCtx(adminTok), &mciasv1.GetAccountRequest{Id: createResp.Account.Id})
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccount: %v", err)
|
||||
}
|
||||
if getResp.Account.Username != "newuser" {
|
||||
t.Errorf("GetAccount: username=%q, want %q", getResp.Account.Username, "newuser")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateAccount verifies that account status can be changed.
|
||||
func TestUpdateAccount(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
adminTok, _ := e.issueAdminToken(t, "admin3")
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(e.conn)
|
||||
|
||||
createResp, err := cl.CreateAccount(authCtx(adminTok), &mciasv1.CreateAccountRequest{
|
||||
Username: "updateme",
|
||||
Password: "pass12345",
|
||||
AccountType: "human",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
id := createResp.Account.Id
|
||||
|
||||
_, err = cl.UpdateAccount(authCtx(adminTok), &mciasv1.UpdateAccountRequest{
|
||||
Id: id,
|
||||
Status: "inactive",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateAccount: %v", err)
|
||||
}
|
||||
|
||||
getResp, err := cl.GetAccount(authCtx(adminTok), &mciasv1.GetAccountRequest{Id: id})
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccount after update: %v", err)
|
||||
}
|
||||
if getResp.Account.Status != "inactive" {
|
||||
t.Errorf("after update: status=%q, want inactive", getResp.Account.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetAndGetRoles verifies that roles can be assigned and retrieved.
|
||||
func TestSetAndGetRoles(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
adminTok, _ := e.issueAdminToken(t, "admin4")
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(e.conn)
|
||||
|
||||
createResp, err := cl.CreateAccount(authCtx(adminTok), &mciasv1.CreateAccountRequest{
|
||||
Username: "roleuser",
|
||||
Password: "pass12345",
|
||||
AccountType: "human",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
id := createResp.Account.Id
|
||||
|
||||
_, err = cl.SetRoles(authCtx(adminTok), &mciasv1.SetRolesRequest{
|
||||
Id: id,
|
||||
Roles: []string{"editor", "viewer"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SetRoles: %v", err)
|
||||
}
|
||||
|
||||
getRolesResp, err := cl.GetRoles(authCtx(adminTok), &mciasv1.GetRolesRequest{Id: id})
|
||||
if err != nil {
|
||||
t.Fatalf("GetRoles: %v", err)
|
||||
}
|
||||
if len(getRolesResp.Roles) != 2 {
|
||||
t.Errorf("GetRoles: got %d roles, want 2", len(getRolesResp.Roles))
|
||||
}
|
||||
}
|
||||
|
||||
// ---- CredentialService tests ----
|
||||
|
||||
// TestSetAndGetPGCreds verifies that PG credentials can be stored and retrieved.
|
||||
// Security: the password is decrypted only in the GetPGCreds response; it is
|
||||
// never present in account list or other responses.
|
||||
func TestSetAndGetPGCreds(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
adminTok, _ := e.issueAdminToken(t, "admin5")
|
||||
|
||||
// Create a system account to hold the PG credentials.
|
||||
accCl := mciasv1.NewAccountServiceClient(e.conn)
|
||||
createResp, err := accCl.CreateAccount(authCtx(adminTok), &mciasv1.CreateAccountRequest{
|
||||
Username: "sysaccount",
|
||||
AccountType: "system",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
accountID := createResp.Account.Id
|
||||
|
||||
credCl := mciasv1.NewCredentialServiceClient(e.conn)
|
||||
|
||||
_, err = credCl.SetPGCreds(authCtx(adminTok), &mciasv1.SetPGCredsRequest{
|
||||
Id: accountID,
|
||||
Creds: &mciasv1.PGCreds{
|
||||
Host: "db.example.com",
|
||||
Port: 5432,
|
||||
Database: "mydb",
|
||||
Username: "myuser",
|
||||
Password: "supersecret",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SetPGCreds: %v", err)
|
||||
}
|
||||
|
||||
getResp, err := credCl.GetPGCreds(authCtx(adminTok), &mciasv1.GetPGCredsRequest{Id: accountID})
|
||||
if err != nil {
|
||||
t.Fatalf("GetPGCreds: %v", err)
|
||||
}
|
||||
if getResp.Creds == nil {
|
||||
t.Fatal("GetPGCreds: returned nil creds")
|
||||
}
|
||||
if getResp.Creds.Password != "supersecret" {
|
||||
t.Errorf("GetPGCreds: password=%q, want supersecret", getResp.Creds.Password)
|
||||
}
|
||||
if getResp.Creds.Host != "db.example.com" {
|
||||
t.Errorf("GetPGCreds: host=%q, want db.example.com", getResp.Creds.Host)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPGCredsRequireAdmin verifies that non-admin cannot access PG creds.
|
||||
func TestPGCredsRequireAdmin(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
acct := e.createHumanAccount(t, "notadmin2")
|
||||
tok := e.issueUserToken(t, acct)
|
||||
|
||||
cl := mciasv1.NewCredentialServiceClient(e.conn)
|
||||
_, err := cl.GetPGCreds(authCtx(tok), &mciasv1.GetPGCredsRequest{Id: acct.UUID})
|
||||
if err == nil {
|
||||
t.Fatal("GetPGCreds as non-admin: expected error, got nil")
|
||||
}
|
||||
st, _ := status.FromError(err)
|
||||
if st.Code() != codes.PermissionDenied {
|
||||
t.Errorf("got code %v, want PermissionDenied", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Security: credential fields absent from responses ----
|
||||
|
||||
// TestCredentialFieldsAbsentFromAccountResponse verifies that account responses
|
||||
// never include password_hash or totp_secret fields. The Account proto message
|
||||
// does not define these fields, providing compile-time enforcement. This test
|
||||
// provides a runtime confirmation by checking the returned Account struct.
|
||||
func TestCredentialFieldsAbsentFromAccountResponse(t *testing.T) {
|
||||
e := newTestEnv(t)
|
||||
adminTok, _ := e.issueAdminToken(t, "admin6")
|
||||
|
||||
cl := mciasv1.NewAccountServiceClient(e.conn)
|
||||
resp, err := cl.ListAccounts(authCtx(adminTok), &mciasv1.ListAccountsRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListAccounts: %v", err)
|
||||
}
|
||||
for _, a := range resp.Accounts {
|
||||
// Account proto only has: id, username, account_type, status,
|
||||
// totp_enabled, created_at, updated_at. No credential fields.
|
||||
// This loop body intentionally checks the fields that exist;
|
||||
// the absence of credential fields is enforced by the proto definition.
|
||||
if a.Id == "" {
|
||||
t.Error("account has empty id")
|
||||
}
|
||||
}
|
||||
}
|
||||
122
internal/grpcserver/tokenservice.go
Normal file
122
internal/grpcserver/tokenservice.go
Normal file
@@ -0,0 +1,122 @@
|
||||
// tokenServiceServer implements mciasv1.TokenServiceServer.
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/token"
|
||||
mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1"
|
||||
)
|
||||
|
||||
type tokenServiceServer struct {
|
||||
mciasv1.UnimplementedTokenServiceServer
|
||||
s *Server
|
||||
}
|
||||
|
||||
// ValidateToken validates a JWT and returns its claims.
|
||||
// Public RPC — no auth required.
|
||||
//
|
||||
// Security: Always returns a valid=false response on any error; never
|
||||
// exposes which specific validation step failed.
|
||||
func (t *tokenServiceServer) ValidateToken(_ context.Context, req *mciasv1.ValidateTokenRequest) (*mciasv1.ValidateTokenResponse, error) {
|
||||
tokenStr := req.Token
|
||||
if tokenStr == "" {
|
||||
return &mciasv1.ValidateTokenResponse{Valid: false}, nil
|
||||
}
|
||||
|
||||
claims, err := token.ValidateToken(t.s.pubKey, tokenStr, t.s.cfg.Tokens.Issuer)
|
||||
if err != nil {
|
||||
return &mciasv1.ValidateTokenResponse{Valid: false}, nil
|
||||
}
|
||||
|
||||
rec, err := t.s.db.GetTokenRecord(claims.JTI)
|
||||
if err != nil || rec.IsRevoked() {
|
||||
return &mciasv1.ValidateTokenResponse{Valid: false}, nil
|
||||
}
|
||||
|
||||
return &mciasv1.ValidateTokenResponse{
|
||||
Valid: true,
|
||||
Subject: claims.Subject,
|
||||
Roles: claims.Roles,
|
||||
ExpiresAt: timestamppb.New(claims.ExpiresAt),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IssueServiceToken issues a token for a system account. Admin only.
|
||||
func (ts *tokenServiceServer) IssueServiceToken(ctx context.Context, req *mciasv1.IssueServiceTokenRequest) (*mciasv1.IssueServiceTokenResponse, error) {
|
||||
if err := ts.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.AccountId == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "account_id is required")
|
||||
}
|
||||
|
||||
acct, err := ts.s.db.GetAccountByUUID(req.AccountId)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, "account not found")
|
||||
}
|
||||
if acct.AccountType != model.AccountTypeSystem {
|
||||
return nil, status.Error(codes.InvalidArgument, "token issue is only for system accounts")
|
||||
}
|
||||
|
||||
tokenStr, claims, err := token.IssueToken(ts.s.privKey, ts.s.cfg.Tokens.Issuer, acct.UUID, nil, ts.s.cfg.ServiceExpiry())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
// Revoke existing system token if any.
|
||||
existing, err := ts.s.db.GetSystemToken(acct.ID)
|
||||
if err == nil && existing != nil {
|
||||
_ = ts.s.db.RevokeToken(existing.JTI, "rotated")
|
||||
}
|
||||
|
||||
if err := ts.s.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
if err := ts.s.db.SetSystemToken(acct.ID, claims.JTI, claims.ExpiresAt); err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
actorClaims := claimsFromContext(ctx)
|
||||
var actorID *int64
|
||||
if actorClaims != nil {
|
||||
if a, err := ts.s.db.GetAccountByUUID(actorClaims.Subject); err == nil {
|
||||
actorID = &a.ID
|
||||
}
|
||||
}
|
||||
ts.s.db.WriteAuditEvent(model.EventTokenIssued, actorID, &acct.ID, peerIP(ctx), //nolint:errcheck
|
||||
fmt.Sprintf(`{"jti":%q}`, claims.JTI))
|
||||
|
||||
return &mciasv1.IssueServiceTokenResponse{
|
||||
Token: tokenStr,
|
||||
ExpiresAt: timestamppb.New(claims.ExpiresAt),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeToken revokes a token by JTI. Admin only.
|
||||
func (ts *tokenServiceServer) RevokeToken(ctx context.Context, req *mciasv1.RevokeTokenRequest) (*mciasv1.RevokeTokenResponse, error) {
|
||||
if err := ts.s.requireAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if req.Jti == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "jti is required")
|
||||
}
|
||||
|
||||
if err := ts.s.db.RevokeToken(req.Jti, "admin revocation"); err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return nil, status.Error(codes.NotFound, "token not found or already revoked")
|
||||
}
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
ts.s.db.WriteAuditEvent(model.EventTokenRevoked, nil, nil, peerIP(ctx), //nolint:errcheck
|
||||
fmt.Sprintf(`{"jti":%q}`, req.Jti))
|
||||
return &mciasv1.RevokeTokenResponse{}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user