package main

import (
	"bytes"
	"fmt"
	"io"
	"os"
	"strings"
	"testing"
	"time"

	"git.wntrmute.dev/kyle/mcias/internal/crypto"
	"git.wntrmute.dev/kyle/mcias/internal/db"
	"git.wntrmute.dev/kyle/mcias/internal/model"
)

// newTestTool creates a tool backed by an in-memory SQLite database with a
// freshly generated master key. The database is migrated to the latest schema
// and closed automatically when the test finishes.
func newTestTool(t *testing.T) *tool {
	t.Helper()

	database, err := db.Open(":memory:")
	if err != nil {
		t.Fatalf("open test DB: %v", err)
	}
	// Register the close before migrating so a failed migration does not
	// leave the handle open.
	t.Cleanup(func() { _ = database.Close() })

	if err := db.Migrate(database); err != nil {
		t.Fatalf("migrate test DB: %v", err)
	}

	// Use a random 32-byte master key for encryption tests.
	masterKey, err := crypto.RandomBytes(32)
	if err != nil {
		t.Fatalf("generate master key: %v", err)
	}

	return &tool{db: database, masterKey: masterKey}
}

// captureStdout runs fn with os.Stdout redirected into a pipe and returns
// everything fn wrote to it.
//
// The pipe is drained concurrently while fn runs, so fn cannot block on a
// full pipe buffer. os.Stdout is restored and both pipe ends are closed even
// when fn panics or ends the test early (for example via t.Fatal).
func captureStdout(t *testing.T, fn func()) string {
	t.Helper()

	r, w, err := os.Pipe()
	if err != nil {
		t.Fatalf("create pipe: %v", err)
	}
	defer func() { _ = r.Close() }()

	orig := os.Stdout
	os.Stdout = w

	type capture struct {
		out string
		err error
	}
	// Buffered so the reader goroutine can always deliver its result and
	// exit, even if this function unwinds before receiving it.
	done := make(chan capture, 1)
	go func() {
		var buf bytes.Buffer
		_, copyErr := io.Copy(&buf, r)
		done <- capture{out: buf.String(), err: copyErr}
	}()

	func() {
		// Restore stdout and signal EOF to the reader no matter how fn exits.
		defer func() {
			os.Stdout = orig
			_ = w.Close()
		}()
		fn()
	}()

	got := <-done
	if got.err != nil {
		t.Fatalf("copy stdout: %v", got.err)
	}
	return got.out
}

// ---- schema tests ----

// TestSchemaVerifyUpToDate checks that schemaVerify reports "up-to-date" for
// a freshly migrated database.
func TestSchemaVerifyUpToDate(t *testing.T) {
	tool := newTestTool(t)

	// Capture output; schemaVerify calls exitCode1 if migrations pending,
	// but with a freshly migrated DB it should print "up-to-date".
	out := captureStdout(t, tool.schemaVerify)
	if !strings.Contains(out, "up-to-date") {
		t.Errorf("expected 'up-to-date' in output, got: %s", out)
	}
}

// ---- account tests ----

// TestAccountListEmpty checks that accountList prints "no accounts" when the
// database holds no accounts.
func TestAccountListEmpty(t *testing.T) {
	tool := newTestTool(t)

	out := captureStdout(t, tool.accountList)
	if !strings.Contains(out, "no accounts") {
		t.Errorf("expected 'no accounts' in output, got: %s", out)
	}
}

// TestAccountCreateAndList creates an account through the DB layer and checks
// that accountList prints its username.
func TestAccountCreateAndList(t *testing.T) {
	tool := newTestTool(t)

	// Create via DB method directly (accountCreate reads args via flags so
	// we test the DB path to avoid os.Exit on parse error).
	a, err := tool.db.CreateAccount("testuser", model.AccountTypeHuman, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	if a.UUID == "" {
		t.Error("expected UUID to be set")
	}

	out := captureStdout(t, tool.accountList)
	if !strings.Contains(out, "testuser") {
		t.Errorf("expected 'testuser' in list output, got: %s", out)
	}
}

// TestAccountGetByUUID checks that accountGet prints the username and account
// type of the account selected with --id.
func TestAccountGetByUUID(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("getuser", model.AccountTypeSystem, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	out := captureStdout(t, func() {
		tool.accountGet([]string{"--id", a.UUID})
	})
	if !strings.Contains(out, "getuser") {
		t.Errorf("expected 'getuser' in get output, got: %s", out)
	}
	if !strings.Contains(out, "system") {
		t.Errorf("expected 'system' in get output, got: %s", out)
	}
}

// TestAccountSetStatus checks that accountSetStatus persists the requested
// status on the account.
func TestAccountSetStatus(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("statususer", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	captureStdout(t, func() {
		tool.accountSetStatus([]string{"--id", a.UUID, "--status", "inactive"})
	})

	updated, err := tool.db.GetAccountByUUID(a.UUID)
	if err != nil {
		t.Fatalf("get account after update: %v", err)
	}
	if updated.Status != model.AccountStatusInactive {
		t.Errorf("expected inactive status, got %s", updated.Status)
	}
}

// TestAccountResetTOTP checks that accountResetTOTP clears both the TOTP
// requirement flag and the stored encrypted secret.
func TestAccountResetTOTP(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("totpuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	// Set TOTP fields.
	if err := tool.db.SetTOTP(a.ID, []byte("enc"), []byte("nonce")); err != nil {
		t.Fatalf("set TOTP: %v", err)
	}

	captureStdout(t, func() {
		tool.accountResetTOTP([]string{"--id", a.UUID})
	})

	updated, err := tool.db.GetAccountByUUID(a.UUID)
	if err != nil {
		t.Fatalf("get account after reset: %v", err)
	}
	if updated.TOTPRequired {
		t.Error("expected TOTP to be cleared")
	}
	if len(updated.TOTPSecretEnc) != 0 {
		t.Error("expected TOTP secret to be cleared")
	}
}

// ---- role tests ----

// TestRoleGrantAndList grants a role through roleGrant, verifies it in the
// database, and checks that roleList prints it.
func TestRoleGrantAndList(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("roleuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	captureStdout(t, func() {
		tool.roleGrant([]string{"--id", a.UUID, "--role", "admin"})
	})

	roles, err := tool.db.GetRoles(a.ID)
	if err != nil {
		t.Fatalf("get roles: %v", err)
	}
	if len(roles) != 1 || roles[0] != "admin" {
		t.Errorf("expected [admin], got %v", roles)
	}

	out := captureStdout(t, func() {
		tool.roleList([]string{"--id", a.UUID})
	})
	if !strings.Contains(out, "admin") {
		t.Errorf("expected 'admin' in role list, got: %s", out)
	}
}

// TestRoleRevoke grants a role through the DB layer and checks that
// roleRevoke removes it.
func TestRoleRevoke(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("revokeuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	if err := tool.db.GrantRole(a.ID, "user", nil); err != nil {
		t.Fatalf("grant role: %v", err)
	}

	captureStdout(t, func() {
		tool.roleRevoke([]string{"--id", a.UUID, "--role", "user"})
	})

	roles, err := tool.db.GetRoles(a.ID)
	if err != nil {
		t.Fatalf("get roles after revoke: %v", err)
	}
	if len(roles) != 0 {
		t.Errorf("expected no roles after revoke, got %v", roles)
	}
}

// ---- token tests ----

// TestTokenListAndRevoke checks that tokenList prints a tracked token's JTI
// and that tokenRevoke marks that token as revoked.
func TestTokenListAndRevoke(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("tokenuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	now := time.Now().UTC()
	if err := tool.db.TrackToken("test-jti-1", a.ID, now, now.Add(time.Hour)); err != nil {
		t.Fatalf("track token: %v", err)
	}

	out := captureStdout(t, func() {
		tool.tokenList([]string{"--id", a.UUID})
	})
	if !strings.Contains(out, "test-jti-1") {
		t.Errorf("expected jti in token list, got: %s", out)
	}

	captureStdout(t, func() {
		tool.tokenRevoke([]string{"--jti", "test-jti-1"})
	})

	rec, err := tool.db.GetTokenRecord("test-jti-1")
	if err != nil {
		t.Fatalf("get token record: %v", err)
	}
	if rec.RevokedAt == nil {
		t.Error("expected token to be revoked")
	}
}

// TestTokenRevokeAll tracks three tokens for one account and checks that
// tokenRevokeAll revokes every one of them.
func TestTokenRevokeAll(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("revokealluser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	now := time.Now().UTC()
	for i := 0; i < 3; i++ {
		jti := fmt.Sprintf("bulk-jti-%d", i)
		if err := tool.db.TrackToken(jti, a.ID, now, now.Add(time.Hour)); err != nil {
			t.Fatalf("track token %d: %v", i, err)
		}
	}

	captureStdout(t, func() {
		tool.tokenRevokeAll([]string{"--id", a.UUID})
	})

	// Verify all tokens are revoked.
	records, err := tool.db.ListTokensForAccount(a.ID)
	if err != nil {
		t.Fatalf("list tokens: %v", err)
	}
	for _, r := range records {
		if r.RevokedAt == nil {
			t.Errorf("token %s should be revoked", r.JTI)
		}
	}
}

// TestPruneTokens tracks one expired and one still-valid token, runs
// pruneTokens, and checks that the output mentions "1" and that the valid
// token is still present.
func TestPruneTokens(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("pruneuser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	past := time.Now().Add(-2 * time.Hour).UTC()
	future := time.Now().Add(time.Hour).UTC()

	if err := tool.db.TrackToken("expired-jti", a.ID, past, past.Add(time.Minute)); err != nil {
		t.Fatalf("track expired token: %v", err)
	}
	if err := tool.db.TrackToken("valid-jti", a.ID, future.Add(-time.Minute), future); err != nil {
		t.Fatalf("track valid token: %v", err)
	}

	out := captureStdout(t, tool.pruneTokens)
	if !strings.Contains(out, "1") {
		t.Errorf("expected 1 pruned in output, got: %s", out)
	}

	// Valid token should still exist.
	if _, err := tool.db.GetTokenRecord("valid-jti"); err != nil {
		t.Errorf("valid token should survive pruning: %v", err)
	}
}

// ---- audit tests ----

// TestAuditTail writes five login_ok events and checks that auditTail with
// --n 3 prints the event type.
func TestAuditTail(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("audituser", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	for i := 0; i < 5; i++ {
		if err := tool.db.WriteAuditEvent(model.EventLoginOK, &a.ID, nil, "", ""); err != nil {
			t.Fatalf("write audit event: %v", err)
		}
	}

	out := captureStdout(t, func() {
		tool.auditTail([]string{"--n", "3"})
	})

	// Output should contain the event type.
	if !strings.Contains(out, "login_ok") {
		t.Errorf("expected login_ok in tail output, got: %s", out)
	}
}

// TestAuditQueryByType checks that auditQuery with --type prints only events
// of the requested type.
func TestAuditQueryByType(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("auditquery", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	if err := tool.db.WriteAuditEvent(model.EventLoginOK, &a.ID, nil, "", ""); err != nil {
		t.Fatalf("write login_ok: %v", err)
	}
	if err := tool.db.WriteAuditEvent(model.EventLoginFail, &a.ID, nil, "", ""); err != nil {
		t.Fatalf("write login_fail: %v", err)
	}

	out := captureStdout(t, func() {
		tool.auditQuery([]string{"--type", "login_fail"})
	})
	if !strings.Contains(out, "login_fail") {
		t.Errorf("expected login_fail in query output, got: %s", out)
	}
	if strings.Contains(out, "login_ok") {
		t.Errorf("unexpected login_ok in filtered query output, got: %s", out)
	}
}

// TestAuditQueryJSON checks that auditQuery with --json emits output
// containing an "event_type" key.
func TestAuditQueryJSON(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("jsonaudit", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	if err := tool.db.WriteAuditEvent(model.EventLoginOK, &a.ID, nil, "", ""); err != nil {
		t.Fatalf("write event: %v", err)
	}

	out := captureStdout(t, func() {
		tool.auditQuery([]string{"--json"})
	})
	if !strings.Contains(out, `"event_type"`) {
		t.Errorf("expected JSON output with event_type, got: %s", out)
	}
}

// ---- pgcreds tests ----

// TestPGCredsSetAndGet seals a Postgres password under the tool's master key,
// stores it, reads it back, and checks that it decrypts to the original value.
func TestPGCredsSetAndGet(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("pguser", model.AccountTypeSystem, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	// Encrypt and store credentials directly using the tool's master key.
	password := "s3cr3t-pg-pass"
	enc, nonce, err := crypto.SealAESGCM(tool.masterKey, []byte(password))
	if err != nil {
		t.Fatalf("seal pgcreds: %v", err)
	}
	if err := tool.db.WritePGCredentials(a.ID, "db.example.com", 5432, "mydb", "myuser", enc, nonce); err != nil {
		t.Fatalf("write pg credentials: %v", err)
	}

	// pgCredsGet calls fatalf if decryption fails, so the round trip is
	// tested via DB + crypto directly instead of through the command.
	cred, err := tool.db.ReadPGCredentials(a.ID)
	if err != nil {
		t.Fatalf("read pg credentials: %v", err)
	}
	plaintext, err := crypto.OpenAESGCM(tool.masterKey, cred.PGPasswordNonce, cred.PGPasswordEnc)
	if err != nil {
		t.Fatalf("decrypt pg password: %v", err)
	}
	if string(plaintext) != password {
		t.Errorf("expected password %q, got %q", password, string(plaintext))
	}
}

// TestPGCredsGetNotFound checks that reading credentials for an account that
// has none returns an error.
func TestPGCredsGetNotFound(t *testing.T) {
	tool := newTestTool(t)

	a, err := tool.db.CreateAccount("nopguser", model.AccountTypeSystem, "")
	if err != nil {
		t.Fatalf("create account: %v", err)
	}

	// ReadPGCredentials for account with no credentials should return ErrNotFound.
	// NOTE(review): only a non-nil error is asserted here, not the specific
	// sentinel — tighten with errors.Is once the sentinel's name is confirmed.
	_, err = tool.db.ReadPGCredentials(a.ID)
	if err == nil {
		t.Fatal("expected ErrNotFound, got nil")
	}
}

// ---- rekey command tests ----

// TestRekeyCommandRoundTrip exercises runRekey end-to-end with real AES-256-GCM
// encryption and actual Argon2id key derivation. It verifies that all secrets
// (signing key, TOTP, pg password) remain accessible after rekey and that the
// old master key no longer decrypts the re-encrypted values.
//
// Note: Argon2id derivation (time=3, memory=128 MiB) makes this test slow (~2 s).
func TestRekeyCommandRoundTrip(t *testing.T) { tool := newTestTool(t) // ── Setup: signing key encrypted under old master key ── _, privKey, err := crypto.GenerateEd25519KeyPair() if err != nil { t.Fatalf("generate key pair: %v", err) } sigKeyPEM, err := crypto.MarshalPrivateKeyPEM(privKey) if err != nil { t.Fatalf("marshal key: %v", err) } sigEnc, sigNonce, err := crypto.SealAESGCM(tool.masterKey, sigKeyPEM) if err != nil { t.Fatalf("seal signing key: %v", err) } if err := tool.db.WriteServerConfig(sigEnc, sigNonce); err != nil { t.Fatalf("write server config: %v", err) } // WriteMasterKeySalt so ReadServerConfig has a valid salt row. oldSalt, err := crypto.NewSalt() if err != nil { t.Fatalf("gen salt: %v", err) } if err := tool.db.WriteMasterKeySalt(oldSalt); err != nil { t.Fatalf("write salt: %v", err) } // ── Setup: account with TOTP ── a, err := tool.db.CreateAccount("rekeyuser", "human", "") if err != nil { t.Fatalf("create account: %v", err) } totpSecret := []byte("JBSWY3DPEHPK3PXP") totpEnc, totpNonce, err := crypto.SealAESGCM(tool.masterKey, totpSecret) if err != nil { t.Fatalf("seal totp: %v", err) } if err := tool.db.SetTOTP(a.ID, totpEnc, totpNonce); err != nil { t.Fatalf("set totp: %v", err) } // ── Setup: pg credentials ── pgPass := []byte("pgpassword123") pgEnc, pgNonce, err := crypto.SealAESGCM(tool.masterKey, pgPass) if err != nil { t.Fatalf("seal pg pass: %v", err) } if err := tool.db.WritePGCredentials(a.ID, "localhost", 5432, "mydb", "myuser", pgEnc, pgNonce); err != nil { t.Fatalf("write pg creds: %v", err) } // ── Pipe new passphrase twice into stdin ── const newPassphrase = "new-master-passphrase-for-test" r, w, err := os.Pipe() if err != nil { t.Fatalf("create stdin pipe: %v", err) } origStdin := os.Stdin os.Stdin = r t.Cleanup(func() { os.Stdin = origStdin }) if _, err := fmt.Fprintf(w, "%s\n%s\n", newPassphrase, newPassphrase); err != nil { t.Fatalf("write stdin: %v", err) } _ = w.Close() // ── Execute rekey ── tool.runRekey(nil) // ── 
Derive new key from stored salt + new passphrase ── newSalt, err := tool.db.ReadMasterKeySalt() if err != nil { t.Fatalf("read new salt: %v", err) } newKey, err := crypto.DeriveKey(newPassphrase, newSalt) if err != nil { t.Fatalf("derive new key: %v", err) } defer func() { for i := range newKey { newKey[i] = 0 } }() // Signing key must decrypt with new key. newSigEnc, newSigNonce, err := tool.db.ReadServerConfig() if err != nil { t.Fatalf("read server config after rekey: %v", err) } decPEM, err := crypto.OpenAESGCM(newKey, newSigNonce, newSigEnc) if err != nil { t.Fatalf("decrypt signing key with new key: %v", err) } if string(decPEM) != string(sigKeyPEM) { t.Error("signing key PEM mismatch after rekey") } // Old key must NOT decrypt the re-encrypted signing key. // Security: adversarial check that old key is invalidated. if _, err := crypto.OpenAESGCM(tool.masterKey, newSigNonce, newSigEnc); err == nil { t.Error("old key still decrypts signing key after rekey — ciphertext was not replaced") } // TOTP must decrypt with new key. updatedAcct, err := tool.db.GetAccountByUUID(a.UUID) if err != nil { t.Fatalf("get account after rekey: %v", err) } decTOTP, err := crypto.OpenAESGCM(newKey, updatedAcct.TOTPSecretNonce, updatedAcct.TOTPSecretEnc) if err != nil { t.Fatalf("decrypt TOTP with new key: %v", err) } if string(decTOTP) != string(totpSecret) { t.Errorf("TOTP mismatch: got %q, want %q", decTOTP, totpSecret) } // pg password must decrypt with new key. updatedCred, err := tool.db.ReadPGCredentials(a.ID) if err != nil { t.Fatalf("read pg creds after rekey: %v", err) } decPG, err := crypto.OpenAESGCM(newKey, updatedCred.PGPasswordNonce, updatedCred.PGPasswordEnc) if err != nil { t.Fatalf("decrypt pg password with new key: %v", err) } if string(decPG) != string(pgPass) { t.Errorf("pg password mismatch: got %q, want %q", decPG, pgPass) } }