Files
mcias/cmd/mciasdb/mciasdb_test.go
Kyle Isom ec7c966ad2 trusted proxy, TOTP replay protection, new tests
- Trusted proxy config option for proxy-aware IP extraction
  used by rate limiting and audit logs; validates proxy IP
  before trusting X-Forwarded-For / X-Real-IP headers
- TOTP replay protection via counter-based validation to
  reject reused codes within the same time step (±30s)
- RateLimit middleware updated to extract client IP from
  proxy headers without IP spoofing risk
- New tests for ClientIP proxy logic (spoofed headers,
  fallback) and extended rate-limit proxy coverage
- HTMX error banner script integrated into web UI base
- .gitignore updated for mciasdb build artifact

Security: resolves CRIT-01 (TOTP replay attack) and
DEF-03 (proxy-unaware rate limiting); gRPC TOTP
enrollment aligned with REST via StorePendingTOTP

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-12 17:44:01 -07:00

441 lines
11 KiB
Go

package main
import (
"bytes"
"fmt"
"io"
"os"
"strings"
"testing"
"time"
"git.wntrmute.dev/kyle/mcias/internal/crypto"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/model"
)
// newTestTool creates a tool backed by an in-memory SQLite database with a
// freshly generated 32-byte master key. The database is migrated to the
// latest schema and is closed via t.Cleanup when the test finishes.
func newTestTool(t *testing.T) *tool {
	t.Helper()
	database, err := db.Open(":memory:")
	if err != nil {
		t.Fatalf("open test DB: %v", err)
	}
	// Register cleanup before any further step that can call t.Fatalf, so the
	// database is closed even when migration or key generation fails. The
	// close error is ignored: this is best-effort teardown of a test fixture.
	t.Cleanup(func() { _ = database.Close() })
	if err := db.Migrate(database); err != nil {
		t.Fatalf("migrate test DB: %v", err)
	}
	// Use a random 32-byte master key for encryption tests.
	masterKey, err := crypto.RandomBytes(32)
	if err != nil {
		t.Fatalf("generate master key: %v", err)
	}
	return &tool{db: database, masterKey: masterKey}
}
// captureStdout runs fn with os.Stdout redirected to a pipe and returns
// everything fn wrote to it.
//
// The pipe is drained by a background goroutine while fn runs, so fn cannot
// block once the OS pipe buffer fills. os.Stdout is restored and the write
// end of the pipe is closed even if fn panics or calls t.FailNow.
func captureStdout(t *testing.T, fn func()) string {
	t.Helper()
	orig := os.Stdout
	r, w, err := os.Pipe()
	if err != nil {
		t.Fatalf("create pipe: %v", err)
	}

	type captured struct {
		out string
		err error
	}
	// Buffered so the reader goroutine can always finish, even if fn panics
	// and nobody ever receives from done.
	done := make(chan captured, 1)
	go func() {
		var buf bytes.Buffer
		_, copyErr := io.Copy(&buf, r)
		_ = r.Close()
		done <- captured{out: buf.String(), err: copyErr}
	}()

	func() {
		os.Stdout = w
		defer func() {
			os.Stdout = orig
			// Closing the write end delivers EOF to the reader goroutine.
			_ = w.Close()
		}()
		fn()
	}()

	res := <-done
	if res.err != nil {
		t.Fatalf("copy stdout: %v", res.err)
	}
	return res.out
}
// ---- schema tests ----
// TestSchemaVerifyUpToDate checks that schemaVerify reports a freshly
// migrated database as current. schemaVerify calls exitCode1 when migrations
// are pending, so this relies on newTestTool having migrated the database.
func TestSchemaVerifyUpToDate(t *testing.T) {
	tl := newTestTool(t)
	got := captureStdout(t, func() { tl.schemaVerify() })
	if strings.Contains(got, "up-to-date") {
		return
	}
	t.Errorf("expected 'up-to-date' in output, got: %s", got)
}
// ---- account tests ----
// TestAccountListEmpty checks that accountList prints a "no accounts" notice
// when the database contains no accounts.
func TestAccountListEmpty(t *testing.T) {
	tl := newTestTool(t)
	if got := captureStdout(t, func() { tl.accountList() }); !strings.Contains(got, "no accounts") {
		t.Errorf("expected 'no accounts' in output, got: %s", got)
	}
}
// TestAccountCreateAndList creates an account and checks that it gets a UUID
// and shows up in accountList output.
func TestAccountCreateAndList(t *testing.T) {
	tl := newTestTool(t)
	// accountCreate reads its arguments via flags and can hit os.Exit on a
	// parse error, so the account is created through the DB layer instead.
	acct, err := tl.db.CreateAccount("testuser", model.AccountTypeHuman, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	if acct.UUID == "" {
		t.Error("expected UUID to be set")
	}
	listing := captureStdout(t, func() { tl.accountList() })
	if !strings.Contains(listing, "testuser") {
		t.Errorf("expected 'testuser' in list output, got: %s", listing)
	}
}
// TestAccountGetByUUID checks that accountGet, given --id, prints both the
// account's username and its account type.
func TestAccountGetByUUID(t *testing.T) {
	tl := newTestTool(t)
	acct, err := tl.db.CreateAccount("getuser", model.AccountTypeSystem, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	got := captureStdout(t, func() {
		tl.accountGet([]string{"--id", acct.UUID})
	})
	for _, want := range []string{"getuser", "system"} {
		if !strings.Contains(got, want) {
			t.Errorf("expected '%s' in get output, got: %s", want, got)
		}
	}
}
// TestAccountSetStatus checks that accountSetStatus persists the requested
// status for the account identified by --id.
func TestAccountSetStatus(t *testing.T) {
	tl := newTestTool(t)
	acct, err := tl.db.CreateAccount("statususer", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	args := []string{"--id", acct.UUID, "--status", "inactive"}
	captureStdout(t, func() { tl.accountSetStatus(args) })

	got, err := tl.db.GetAccountByUUID(acct.UUID)
	if err != nil {
		t.Fatalf("get account after update: %v", err)
	}
	if want := model.AccountStatusInactive; got.Status != want {
		t.Errorf("expected inactive status, got %s", got.Status)
	}
}
// TestAccountResetTOTP seeds TOTP data on an account, runs accountResetTOTP,
// and checks that the TOTP requirement and stored secret are both cleared.
func TestAccountResetTOTP(t *testing.T) {
	tl := newTestTool(t)
	acct, err := tl.db.CreateAccount("totpuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	// Seed placeholder secret and nonce bytes so there is something to reset.
	if err = tl.db.SetTOTP(acct.ID, []byte("enc"), []byte("nonce")); err != nil {
		t.Fatalf("set TOTP: %v", err)
	}
	captureStdout(t, func() { tl.accountResetTOTP([]string{"--id", acct.UUID}) })

	after, err := tl.db.GetAccountByUUID(acct.UUID)
	if err != nil {
		t.Fatalf("get account after reset: %v", err)
	}
	if after.TOTPRequired {
		t.Error("expected TOTP to be cleared")
	}
	if len(after.TOTPSecretEnc) > 0 {
		t.Error("expected TOTP secret to be cleared")
	}
}
// ---- role tests ----
// TestRoleGrantAndList grants a role via roleGrant, checks it is stored, and
// checks that roleList prints it.
func TestRoleGrantAndList(t *testing.T) {
	tl := newTestTool(t)
	acct, err := tl.db.CreateAccount("roleuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	captureStdout(t, func() {
		tl.roleGrant([]string{"--id", acct.UUID, "--role", "admin"})
	})

	granted, err := tl.db.GetRoles(acct.ID)
	if err != nil {
		t.Fatalf("get roles: %v", err)
	}
	if !(len(granted) == 1 && granted[0] == "admin") {
		t.Errorf("expected [admin], got %v", granted)
	}

	listing := captureStdout(t, func() {
		tl.roleList([]string{"--id", acct.UUID})
	})
	if !strings.Contains(listing, "admin") {
		t.Errorf("expected 'admin' in role list, got: %s", listing)
	}
}
// TestRoleRevoke seeds a role through the DB layer, revokes it with
// roleRevoke, and checks that the account is left with no roles.
func TestRoleRevoke(t *testing.T) {
	tl := newTestTool(t)
	acct, err := tl.db.CreateAccount("revokeuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	// Seed directly so only roleRevoke is exercised by the CLI path.
	if err := tl.db.GrantRole(acct.ID, "user", nil); err != nil {
		t.Fatalf("grant role: %v", err)
	}
	captureStdout(t, func() {
		tl.roleRevoke([]string{"--id", acct.UUID, "--role", "user"})
	})

	remaining, err := tl.db.GetRoles(acct.ID)
	if err != nil {
		t.Fatalf("get roles after revoke: %v", err)
	}
	if len(remaining) > 0 {
		t.Errorf("expected no roles after revoke, got %v", remaining)
	}
}
// ---- token tests ----
// TestTokenListAndRevoke tracks a single token, checks that tokenList shows
// its JTI, then revokes it with tokenRevoke and checks RevokedAt is set.
func TestTokenListAndRevoke(t *testing.T) {
	const jti = "test-jti-1"
	tl := newTestTool(t)
	acct, err := tl.db.CreateAccount("tokenuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	issued := time.Now().UTC()
	if err := tl.db.TrackToken(jti, acct.ID, issued, issued.Add(time.Hour)); err != nil {
		t.Fatalf("track token: %v", err)
	}

	listing := captureStdout(t, func() {
		tl.tokenList([]string{"--id", acct.UUID})
	})
	if !strings.Contains(listing, jti) {
		t.Errorf("expected jti in token list, got: %s", listing)
	}

	captureStdout(t, func() {
		tl.tokenRevoke([]string{"--jti", jti})
	})
	rec, err := tl.db.GetTokenRecord(jti)
	if err != nil {
		t.Fatalf("get token record: %v", err)
	}
	if rec.RevokedAt == nil {
		t.Error("expected token to be revoked")
	}
}
// TestTokenRevokeAll tracks several tokens for one account, runs
// tokenRevokeAll, and checks that every one of them is marked revoked.
func TestTokenRevokeAll(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("revokealluser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	const numTokens = 3
	now := time.Now().UTC()
	jtis := make([]string, 0, numTokens)
	for i := 0; i < numTokens; i++ {
		jti := fmt.Sprintf("bulk-jti-%d", i)
		if err := tool.db.TrackToken(jti, a.ID, now, now.Add(time.Hour)); err != nil {
			t.Fatalf("track token %d: %v", i, err)
		}
		jtis = append(jtis, jti)
	}
	captureStdout(t, func() {
		tool.tokenRevokeAll([]string{"--id", a.UUID})
	})
	// Look up each tracked token by JTI instead of ranging over a listing:
	// an empty listing would make a range-based check pass vacuously.
	for _, jti := range jtis {
		rec, err := tool.db.GetTokenRecord(jti)
		if err != nil {
			t.Fatalf("get token record %s: %v", jti, err)
		}
		if rec.RevokedAt == nil {
			t.Errorf("token %s should be revoked", jti)
		}
	}
}
// TestPruneTokens tracks one expired and one unexpired token, runs
// pruneTokens, and checks that the expired record is gone, the count is
// reported, and the unexpired record survives.
func TestPruneTokens(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("pruneuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	past := time.Now().Add(-2 * time.Hour).UTC()
	future := time.Now().Add(time.Hour).UTC()
	if err := tool.db.TrackToken("expired-jti", a.ID, past, past.Add(time.Minute)); err != nil {
		t.Fatalf("track expired token: %v", err)
	}
	if err := tool.db.TrackToken("valid-jti", a.ID, future.Add(-time.Minute), future); err != nil {
		t.Fatalf("track valid token: %v", err)
	}
	out := captureStdout(t, tool.pruneTokens)
	if !strings.Contains(out, "1") {
		t.Errorf("expected 1 pruned in output, got: %s", out)
	}
	// The output check alone is weak (any "1" in the text matches), so also
	// confirm the expired record itself was removed.
	if _, err := tool.db.GetTokenRecord("expired-jti"); err == nil {
		t.Error("expired token should have been pruned")
	}
	// Valid token should still exist.
	if _, err := tool.db.GetTokenRecord("valid-jti"); err != nil {
		t.Errorf("valid token should survive pruning: %v", err)
	}
}
// ---- audit tests ----
// TestAuditTail writes several login_ok audit events and checks that
// auditTail with --n 3 prints that event type.
func TestAuditTail(t *testing.T) {
	tl := newTestTool(t)
	acct, err := tl.db.CreateAccount("audituser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	// Write more events than are requested from the tail.
	for remaining := 5; remaining > 0; remaining-- {
		if err := tl.db.WriteAuditEvent(model.EventLoginOK, &acct.ID, nil, "", ""); err != nil {
			t.Fatalf("write audit event: %v", err)
		}
	}
	tail := captureStdout(t, func() {
		tl.auditTail([]string{"--n", "3"})
	})
	// Output should contain the event type.
	if !strings.Contains(tail, "login_ok") {
		t.Errorf("expected login_ok in tail output, got: %s", tail)
	}
}
// TestAuditQueryByType writes one login_ok and one login_fail event and checks
// that auditQuery --type login_fail shows only the matching event.
func TestAuditQueryByType(t *testing.T) {
	tl := newTestTool(t)
	acct, err := tl.db.CreateAccount("auditquery", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	if err := tl.db.WriteAuditEvent(model.EventLoginOK, &acct.ID, nil, "", ""); err != nil {
		t.Fatalf("write login_ok: %v", err)
	}
	if err := tl.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, "", ""); err != nil {
		t.Fatalf("write login_fail: %v", err)
	}

	filtered := captureStdout(t, func() {
		tl.auditQuery([]string{"--type", "login_fail"})
	})
	hasFail := strings.Contains(filtered, "login_fail")
	hasOK := strings.Contains(filtered, "login_ok")
	if !hasFail {
		t.Errorf("expected login_fail in query output, got: %s", filtered)
	}
	if hasOK {
		t.Errorf("unexpected login_ok in filtered query output, got: %s", filtered)
	}
}
// TestAuditQueryJSON checks that auditQuery --json emits output containing
// an "event_type" key.
func TestAuditQueryJSON(t *testing.T) {
	tl := newTestTool(t)
	acct, err := tl.db.CreateAccount("jsonaudit", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	if err := tl.db.WriteAuditEvent(model.EventLoginOK, &acct.ID, nil, "", ""); err != nil {
		t.Fatalf("write event: %v", err)
	}
	raw := captureStdout(t, func() { tl.auditQuery([]string{"--json"}) })
	const wantKey = `"event_type"`
	if !strings.Contains(raw, wantKey) {
		t.Errorf("expected JSON output with event_type, got: %s", raw)
	}
}
// ---- pgcreds tests ----
// TestPGCredsSetAndGet checks that a Postgres password sealed with the tool's
// master key survives a write/read round-trip through the database and
// decrypts back to the original plaintext.
func TestPGCredsSetAndGet(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("pguser", model.AccountTypeSystem, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	// Encrypt and store credentials directly using the tool's master key.
	password := "s3cr3t-pg-pass"
	enc, nonce, err := crypto.SealAESGCM(tool.masterKey, []byte(password))
	if err != nil {
		t.Fatalf("seal pgcreds: %v", err)
	}
	if err := tool.db.WritePGCredentials(a.ID, "db.example.com", 5432, "mydb", "myuser", enc, nonce); err != nil {
		t.Fatalf("write pg credentials: %v", err)
	}
	// The pgCredsGet command calls fatalf if decryption fails, so rather than
	// driving that command here, the round-trip is exercised through the DB
	// layer and the crypto package directly.
	cred, err := tool.db.ReadPGCredentials(a.ID)
	if err != nil {
		t.Fatalf("read pg credentials: %v", err)
	}
	plaintext, err := crypto.OpenAESGCM(tool.masterKey, cred.PGPasswordNonce, cred.PGPasswordEnc)
	if err != nil {
		t.Fatalf("decrypt pg password: %v", err)
	}
	if string(plaintext) != password {
		t.Errorf("expected password %q, got %q", password, string(plaintext))
	}
}
// TestPGCredsGetNotFound checks that ReadPGCredentials returns an error for an
// account that never had Postgres credentials written (expected to be
// ErrNotFound; only non-nil is asserted here).
func TestPGCredsGetNotFound(t *testing.T) {
	tl := newTestTool(t)
	acct, err := tl.db.CreateAccount("nopguser", model.AccountTypeSystem, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	if _, err := tl.db.ReadPGCredentials(acct.ID); err == nil {
		t.Fatal("expected ErrNotFound, got nil")
	}
}