package main

import (
	"bytes"
	"fmt"
	"io"
	"os"
	"strings"
	"testing"
	"time"

	"git.wntrmute.dev/kyle/mcias/internal/crypto"
	"git.wntrmute.dev/kyle/mcias/internal/db"
	"git.wntrmute.dev/kyle/mcias/internal/model"
)

// newTestTool creates a tool backed by an in-memory SQLite database with a
// freshly generated master key. The database is migrated to the latest schema
// and is closed automatically when the test finishes.
func newTestTool(t *testing.T) *tool {
	t.Helper()
	database, err := db.Open(":memory:")
	if err != nil {
		t.Fatalf("open test DB: %v", err)
	}
	// Register cleanup before migrating so the handle is released even if
	// migration fails and t.Fatalf aborts the test.
	t.Cleanup(func() { _ = database.Close() })
	if err := db.Migrate(database); err != nil {
		t.Fatalf("migrate test DB: %v", err)
	}

	// Use a random 32-byte master key for encryption tests.
	masterKey, err := crypto.RandomBytes(32)
	if err != nil {
		t.Fatalf("generate master key: %v", err)
	}
	return &tool{db: database, masterKey: masterKey}
}

// captureStdout runs fn with os.Stdout redirected to a pipe and returns
// everything fn wrote to stdout.
//
// The pipe is drained concurrently, so fn cannot block on a full pipe buffer
// when it produces a lot of output. os.Stdout is restored and the write end is
// closed even if fn panics or calls t.Fatal (runtime.Goexit). That way later
// tests are not left writing into a dead pipe. os.Exit inside fn cannot be
// intercepted.
func captureStdout(t *testing.T, fn func()) string {
	t.Helper()

	r, w, err := os.Pipe()
	if err != nil {
		t.Fatalf("create pipe: %v", err)
	}
	defer func() { _ = r.Close() }()

	type captured struct {
		out string
		err error
	}
	// Buffered so the reader goroutine never blocks on send, even if we
	// unwind (panic/Goexit) before receiving.
	done := make(chan captured, 1)
	go func() {
		var buf bytes.Buffer
		_, copyErr := io.Copy(&buf, r)
		done <- captured{out: buf.String(), err: copyErr}
	}()

	orig := os.Stdout
	os.Stdout = w
	func() {
		defer func() {
			os.Stdout = orig
			// Closing the write end delivers EOF to the reader goroutine.
			_ = w.Close()
		}()
		fn()
	}()

	res := <-done
	if res.err != nil {
		t.Fatalf("copy stdout: %v", res.err)
	}
	return res.out
}

// ---- schema tests ----

// TestSchemaVerifyUpToDate checks that schemaVerify reports a freshly
// migrated database as up to date.
func TestSchemaVerifyUpToDate(t *testing.T) {
	tool := newTestTool(t)

	// Capture output; schemaVerify calls exitCode1 if migrations pending,
	// but with a freshly migrated DB it should print "up-to-date".
	out := captureStdout(t, tool.schemaVerify)
	if !strings.Contains(out, "up-to-date") {
		t.Errorf("expected 'up-to-date' in output, got: %s", out)
	}
}

// ---- account tests ----

// TestAccountListEmpty checks the accountList output when no accounts exist.
func TestAccountListEmpty(t *testing.T) {
	tool := newTestTool(t)
	out := captureStdout(t, tool.accountList)
	if !strings.Contains(out, "no accounts") {
		t.Errorf("expected 'no accounts' in output, got: %s", out)
	}
}

// TestAccountCreateAndList checks that an account created through the DB
// layer is assigned a UUID and appears in accountList output.
func TestAccountCreateAndList(t *testing.T) {
	tool := newTestTool(t)

	// Create via DB method directly (accountCreate reads args via flags so
	// we test the DB path to avoid os.Exit on parse error).
	a, err := tool.db.CreateAccount("testuser", model.AccountTypeHuman, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	if a.UUID == "" {
		t.Error("expected UUID to be set")
	}

	out := captureStdout(t, tool.accountList)
	if !strings.Contains(out, "testuser") {
		t.Errorf("expected 'testuser' in list output, got: %s", out)
	}
}

// TestAccountGetByUUID checks that accountGet --id prints the account's
// username and type.
func TestAccountGetByUUID(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("getuser", model.AccountTypeSystem, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	out := captureStdout(t, func() {
		tool.accountGet([]string{"--id", a.UUID})
	})
	if !strings.Contains(out, "getuser") {
		t.Errorf("expected 'getuser' in get output, got: %s", out)
	}
	if !strings.Contains(out, "system") {
		t.Errorf("expected 'system' in get output, got: %s", out)
	}
}

// TestAccountSetStatus checks that accountSetStatus persists the new status.
func TestAccountSetStatus(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("statususer", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	captureStdout(t, func() {
		tool.accountSetStatus([]string{"--id", a.UUID, "--status", "inactive"})
	})

	updated, err := tool.db.GetAccountByUUID(a.UUID)
	if err != nil {
		t.Fatalf("get account after update: %v", err)
	}
	if updated.Status != model.AccountStatusInactive {
		t.Errorf("expected inactive status, got %s", updated.Status)
	}
}

// TestAccountResetTOTP checks that accountResetTOTP clears both the
// TOTP-required flag and the stored encrypted secret.
func TestAccountResetTOTP(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("totpuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	// Set TOTP fields.
	if err := tool.db.SetTOTP(a.ID, []byte("enc"), []byte("nonce")); err != nil {
		t.Fatalf("set TOTP: %v", err)
	}

	captureStdout(t, func() {
		tool.accountResetTOTP([]string{"--id", a.UUID})
	})

	updated, err := tool.db.GetAccountByUUID(a.UUID)
	if err != nil {
		t.Fatalf("get account after reset: %v", err)
	}
	if updated.TOTPRequired {
		t.Error("expected TOTP to be cleared")
	}
	if len(updated.TOTPSecretEnc) != 0 {
		t.Error("expected TOTP secret to be cleared")
	}
}

// ---- role tests ----

// TestRoleGrantAndList checks that roleGrant stores the role and roleList
// prints it.
func TestRoleGrantAndList(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("roleuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	captureStdout(t, func() {
		tool.roleGrant([]string{"--id", a.UUID, "--role", "admin"})
	})

	roles, err := tool.db.GetRoles(a.ID)
	if err != nil {
		t.Fatalf("get roles: %v", err)
	}
	if len(roles) != 1 || roles[0] != "admin" {
		t.Errorf("expected [admin], got %v", roles)
	}

	out := captureStdout(t, func() {
		tool.roleList([]string{"--id", a.UUID})
	})
	if !strings.Contains(out, "admin") {
		t.Errorf("expected 'admin' in role list, got: %s", out)
	}
}

// TestRoleRevoke checks that roleRevoke removes a previously granted role.
func TestRoleRevoke(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("revokeuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	if err := tool.db.GrantRole(a.ID, "editor", nil); err != nil {
		t.Fatalf("grant role: %v", err)
	}

	captureStdout(t, func() {
		tool.roleRevoke([]string{"--id", a.UUID, "--role", "editor"})
	})

	roles, err := tool.db.GetRoles(a.ID)
	if err != nil {
		t.Fatalf("get roles after revoke: %v", err)
	}
	if len(roles) != 0 {
		t.Errorf("expected no roles after revoke, got %v", roles)
	}
}

// ---- token tests ----

// TestTokenListAndRevoke checks that tokenList shows a tracked JTI and that
// tokenRevoke marks it revoked.
func TestTokenListAndRevoke(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("tokenuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	now := time.Now().UTC()
	if err := tool.db.TrackToken("test-jti-1", a.ID, now, now.Add(time.Hour)); err != nil {
		t.Fatalf("track token: %v", err)
	}

	out := captureStdout(t, func() {
		tool.tokenList([]string{"--id", a.UUID})
	})
	if !strings.Contains(out, "test-jti-1") {
		t.Errorf("expected jti in token list, got: %s", out)
	}

	captureStdout(t, func() {
		tool.tokenRevoke([]string{"--jti", "test-jti-1"})
	})

	rec, err := tool.db.GetTokenRecord("test-jti-1")
	if err != nil {
		t.Fatalf("get token record: %v", err)
	}
	if rec.RevokedAt == nil {
		t.Error("expected token to be revoked")
	}
}

// TestTokenRevokeAll checks that tokenRevokeAll revokes every token tracked
// for an account.
func TestTokenRevokeAll(t *testing.T) {
	const tokenCount = 3

	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("revokealluser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	now := time.Now().UTC()
	for i := 0; i < tokenCount; i++ {
		jti := fmt.Sprintf("bulk-jti-%d", i)
		if err := tool.db.TrackToken(jti, a.ID, now, now.Add(time.Hour)); err != nil {
			t.Fatalf("track token %d: %v", i, err)
		}
	}

	captureStdout(t, func() {
		tool.tokenRevokeAll([]string{"--id", a.UUID})
	})

	// Verify all tokens are revoked.
	records, err := tool.db.ListTokensForAccount(a.ID)
	if err != nil {
		t.Fatalf("list tokens: %v", err)
	}
	// Guard against a vacuous pass: with no records the loop below would
	// assert nothing.
	if len(records) != tokenCount {
		t.Fatalf("expected %d token records, got %d", tokenCount, len(records))
	}
	for _, r := range records {
		if r.RevokedAt == nil {
			t.Errorf("token %s should be revoked", r.JTI)
		}
	}
}

// TestPruneTokens checks that pruneTokens reports a pruned count and leaves
// unexpired tokens in place.
func TestPruneTokens(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("pruneuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	past := time.Now().Add(-2 * time.Hour).UTC()
	future := time.Now().Add(time.Hour).UTC()

	if err := tool.db.TrackToken("expired-jti", a.ID, past, past.Add(time.Minute)); err != nil {
		t.Fatalf("track expired token: %v", err)
	}
	if err := tool.db.TrackToken("valid-jti", a.ID, future.Add(-time.Minute), future); err != nil {
		t.Fatalf("track valid token: %v", err)
	}

	out := captureStdout(t, tool.pruneTokens)
	// NOTE(review): this only checks that the output contains the digit "1"
	// anywhere; tighten once the exact pruneTokens output format is pinned.
	if !strings.Contains(out, "1") {
		t.Errorf("expected 1 pruned in output, got: %s", out)
	}

	// Valid token should still exist.
	if _, err := tool.db.GetTokenRecord("valid-jti"); err != nil {
		t.Errorf("valid token should survive pruning: %v", err)
	}
}

// ---- audit tests ----

// TestAuditTail checks that auditTail prints recorded event types.
func TestAuditTail(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("audituser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	for i := 0; i < 5; i++ {
		if err := tool.db.WriteAuditEvent(model.EventLoginOK, &a.ID, nil, "", ""); err != nil {
			t.Fatalf("write audit event: %v", err)
		}
	}

	out := captureStdout(t, func() {
		tool.auditTail([]string{"--n", "3"})
	})

	// Output should contain the event type.
	if !strings.Contains(out, "login_ok") {
		t.Errorf("expected login_ok in tail output, got: %s", out)
	}
}

// TestAuditQueryByType checks that auditQuery --type filters out events of
// other types.
func TestAuditQueryByType(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("auditquery", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	if err := tool.db.WriteAuditEvent(model.EventLoginOK, &a.ID, nil, "", ""); err != nil {
		t.Fatalf("write login_ok: %v", err)
	}
	if err := tool.db.WriteAuditEvent(model.EventLoginFail, &a.ID, nil, "", ""); err != nil {
		t.Fatalf("write login_fail: %v", err)
	}

	out := captureStdout(t, func() {
		tool.auditQuery([]string{"--type", "login_fail"})
	})

	if !strings.Contains(out, "login_fail") {
		t.Errorf("expected login_fail in query output, got: %s", out)
	}
	if strings.Contains(out, "login_ok") {
		t.Errorf("unexpected login_ok in filtered query output, got: %s", out)
	}
}

// TestAuditQueryJSON checks that auditQuery --json emits JSON containing an
// event_type field.
func TestAuditQueryJSON(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("jsonaudit", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	if err := tool.db.WriteAuditEvent(model.EventLoginOK, &a.ID, nil, "", ""); err != nil {
		t.Fatalf("write event: %v", err)
	}

	out := captureStdout(t, func() {
		tool.auditQuery([]string{"--json"})
	})

	if !strings.Contains(out, `"event_type"`) {
		t.Errorf("expected JSON output with event_type, got: %s", out)
	}
}

// ---- pgcreds tests ----

// TestPGCredsSetAndGet checks that a Postgres password sealed with the tool's
// master key round-trips through the DB and decrypts to the original value.
func TestPGCredsSetAndGet(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("pguser", model.AccountTypeSystem, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	// Encrypt and store credentials directly using the tool's master key.
	password := "s3cr3t-pg-pass"
	enc, nonce, err := crypto.SealAESGCM(tool.masterKey, []byte(password))
	if err != nil {
		t.Fatalf("seal pgcreds: %v", err)
	}
	if err := tool.db.WritePGCredentials(a.ID, "db.example.com", 5432, "mydb", "myuser", enc, nonce); err != nil {
		t.Fatalf("write pg credentials: %v", err)
	}

	// pgCredsGet calls fatalf if decryption fails, which would abort the
	// test binary, so we test the round-trip via DB + crypto directly.
	cred, err := tool.db.ReadPGCredentials(a.ID)
	if err != nil {
		t.Fatalf("read pg credentials: %v", err)
	}

	plaintext, err := crypto.OpenAESGCM(tool.masterKey, cred.PGPasswordNonce, cred.PGPasswordEnc)
	if err != nil {
		t.Fatalf("decrypt pg password: %v", err)
	}
	if string(plaintext) != password {
		t.Errorf("expected password %q, got %q", password, string(plaintext))
	}
}

// TestPGCredsGetNotFound checks that reading Postgres credentials for an
// account without any returns an error.
func TestPGCredsGetNotFound(t *testing.T) {
	tool := newTestTool(t)
	a, err := tool.db.CreateAccount("nopguser", model.AccountTypeSystem, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	// ReadPGCredentials for account with no credentials should return ErrNotFound.
	// NOTE(review): this only asserts a non-nil error; if the db package exports
	// a not-found sentinel, consider checking it with errors.Is.
	_, err = tool.db.ReadPGCredentials(a.ID)
	if err == nil {
		t.Fatal("expected ErrNotFound, got nil")
	}
}