- internal/db/accounts.go: add ListAccountsWithTOTP, ListAllPGCredentials, TOTPRekeyRow, PGRekeyRow, and Rekey — atomic transaction that replaces master_key_salt, signing_key_enc/nonce, all TOTP enc/nonce, and all pg_password enc/nonce in one SQLite BEGIN/COMMIT - cmd/mciasdb/rekey.go: runRekey — decrypts all secrets under old master key, prompts for new passphrase (with confirmation), derives new key from fresh Argon2id salt, re-encrypts everything, and commits atomically - cmd/mciasdb/main.go: wire "rekey" command + update usage - Tests: DB-layer tests for ListAccountsWithTOTP, ListAllPGCredentials, Rekey (happy path, empty DB, salt replacement); command-level TestRekeyCommandRoundTrip verifies full round-trip and adversarially confirms old key no longer decrypts after rekey Security: fresh random salt is always generated so a reused passphrase still produces an independent key; old and new master keys are zeroed via defer; no passphrase or key material appears in logs or audit events; the entire re-encryption is done in-memory before the single atomic DB write so the database is never in a mixed state. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
404 lines
12 KiB
Go
404 lines
12 KiB
Go
package db
|
|
|
|
import (
|
|
"testing"
|
|
"time"
|
|
|
|
"git.wntrmute.dev/kyle/mcias/internal/model"
|
|
)
|
|
|
|
// openTestDB is defined in db_test.go in this package; reused here.
|
|
|
|
func TestListTokensForAccount(t *testing.T) {
|
|
database := openTestDB(t)
|
|
|
|
acc, err := database.CreateAccount("tokenuser", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("create account: %v", err)
|
|
}
|
|
|
|
// No tokens yet.
|
|
records, err := database.ListTokensForAccount(acc.ID)
|
|
if err != nil {
|
|
t.Fatalf("list tokens (empty): %v", err)
|
|
}
|
|
if len(records) != 0 {
|
|
t.Fatalf("expected 0 tokens, got %d", len(records))
|
|
}
|
|
|
|
// Track two tokens.
|
|
now := time.Now().UTC()
|
|
if err := database.TrackToken("jti-aaa", acc.ID, now, now.Add(time.Hour)); err != nil {
|
|
t.Fatalf("track token 1: %v", err)
|
|
}
|
|
if err := database.TrackToken("jti-bbb", acc.ID, now.Add(time.Second), now.Add(2*time.Hour)); err != nil {
|
|
t.Fatalf("track token 2: %v", err)
|
|
}
|
|
|
|
records, err = database.ListTokensForAccount(acc.ID)
|
|
if err != nil {
|
|
t.Fatalf("list tokens: %v", err)
|
|
}
|
|
if len(records) != 2 {
|
|
t.Fatalf("expected 2 tokens, got %d", len(records))
|
|
}
|
|
// Newest first.
|
|
if records[0].JTI != "jti-bbb" {
|
|
t.Errorf("expected jti-bbb first, got %s", records[0].JTI)
|
|
}
|
|
if records[1].JTI != "jti-aaa" {
|
|
t.Errorf("expected jti-aaa second, got %s", records[1].JTI)
|
|
}
|
|
}
|
|
|
|
func TestListAuditEventsFilter(t *testing.T) {
|
|
database := openTestDB(t)
|
|
|
|
acc1, err := database.CreateAccount("audituser1", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("create account 1: %v", err)
|
|
}
|
|
acc2, err := database.CreateAccount("audituser2", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("create account 2: %v", err)
|
|
}
|
|
|
|
// Write events for both accounts with different types.
|
|
if err := database.WriteAuditEvent(model.EventLoginOK, &acc1.ID, nil, "1.2.3.4", ""); err != nil {
|
|
t.Fatalf("write audit event 1: %v", err)
|
|
}
|
|
if err := database.WriteAuditEvent(model.EventLoginFail, &acc2.ID, nil, "5.6.7.8", ""); err != nil {
|
|
t.Fatalf("write audit event 2: %v", err)
|
|
}
|
|
if err := database.WriteAuditEvent(model.EventTokenIssued, &acc1.ID, nil, "", ""); err != nil {
|
|
t.Fatalf("write audit event 3: %v", err)
|
|
}
|
|
|
|
// Filter by account.
|
|
events, err := database.ListAuditEvents(AuditQueryParams{AccountID: &acc1.ID})
|
|
if err != nil {
|
|
t.Fatalf("list by account: %v", err)
|
|
}
|
|
if len(events) != 2 {
|
|
t.Fatalf("expected 2 events for acc1, got %d", len(events))
|
|
}
|
|
|
|
// Filter by event type.
|
|
events, err = database.ListAuditEvents(AuditQueryParams{EventType: model.EventLoginFail})
|
|
if err != nil {
|
|
t.Fatalf("list by type: %v", err)
|
|
}
|
|
if len(events) != 1 {
|
|
t.Fatalf("expected 1 login_fail event, got %d", len(events))
|
|
}
|
|
|
|
// Filter by since (after all events).
|
|
future := time.Now().Add(time.Hour)
|
|
events, err = database.ListAuditEvents(AuditQueryParams{Since: &future})
|
|
if err != nil {
|
|
t.Fatalf("list by since (future): %v", err)
|
|
}
|
|
if len(events) != 0 {
|
|
t.Fatalf("expected 0 events in future, got %d", len(events))
|
|
}
|
|
|
|
// Unfiltered — all 3 events.
|
|
events, err = database.ListAuditEvents(AuditQueryParams{})
|
|
if err != nil {
|
|
t.Fatalf("list unfiltered: %v", err)
|
|
}
|
|
if len(events) != 3 {
|
|
t.Fatalf("expected 3 events unfiltered, got %d", len(events))
|
|
}
|
|
|
|
_ = acc2
|
|
}
|
|
|
|
func TestTailAuditEvents(t *testing.T) {
|
|
database := openTestDB(t)
|
|
|
|
acc, err := database.CreateAccount("tailuser", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("create account: %v", err)
|
|
}
|
|
|
|
// Write 5 events.
|
|
for i := 0; i < 5; i++ {
|
|
if err := database.WriteAuditEvent(model.EventLoginOK, &acc.ID, nil, "", ""); err != nil {
|
|
t.Fatalf("write audit event %d: %v", i, err)
|
|
}
|
|
}
|
|
|
|
// Tail 3 — should return the 3 most recent, oldest-first.
|
|
events, err := database.TailAuditEvents(3)
|
|
if err != nil {
|
|
t.Fatalf("tail audit events: %v", err)
|
|
}
|
|
if len(events) != 3 {
|
|
t.Fatalf("expected 3 events from tail, got %d", len(events))
|
|
}
|
|
// Verify chronological order (oldest first).
|
|
for i := 1; i < len(events); i++ {
|
|
if events[i].EventTime.Before(events[i-1].EventTime) {
|
|
// Allow equal times (written in same second).
|
|
if events[i].EventTime.Equal(events[i-1].EventTime) {
|
|
continue
|
|
}
|
|
t.Errorf("events not in chronological order at index %d", i)
|
|
}
|
|
}
|
|
|
|
// Tail more than exist — should return all 5.
|
|
events, err = database.TailAuditEvents(100)
|
|
if err != nil {
|
|
t.Fatalf("tail 100: %v", err)
|
|
}
|
|
if len(events) != 5 {
|
|
t.Fatalf("expected 5 from tail(100), got %d", len(events))
|
|
}
|
|
}
|
|
|
|
func TestListAuditEventsCombinedFilters(t *testing.T) {
|
|
database := openTestDB(t)
|
|
|
|
acc, err := database.CreateAccount("combo", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("create account: %v", err)
|
|
}
|
|
|
|
if err := database.WriteAuditEvent(model.EventLoginOK, &acc.ID, nil, "", ""); err != nil {
|
|
t.Fatalf("write event: %v", err)
|
|
}
|
|
|
|
// Combine account + type filters.
|
|
events, err := database.ListAuditEvents(AuditQueryParams{
|
|
AccountID: &acc.ID,
|
|
EventType: model.EventLoginOK,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("combined filter: %v", err)
|
|
}
|
|
if len(events) != 1 {
|
|
t.Fatalf("expected 1 event, got %d", len(events))
|
|
}
|
|
|
|
// Combine account + wrong type.
|
|
events, err = database.ListAuditEvents(AuditQueryParams{
|
|
AccountID: &acc.ID,
|
|
EventType: model.EventLoginFail,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("combined filter no match: %v", err)
|
|
}
|
|
if len(events) != 0 {
|
|
t.Fatalf("expected 0 events, got %d", len(events))
|
|
}
|
|
}
|
|
|
|
// ---- rekey helper tests ----
|
|
|
|
func TestListAccountsWithTOTP(t *testing.T) {
|
|
database := openTestDB(t)
|
|
|
|
// No accounts with TOTP yet.
|
|
accounts, err := database.ListAccountsWithTOTP()
|
|
if err != nil {
|
|
t.Fatalf("ListAccountsWithTOTP (empty): %v", err)
|
|
}
|
|
if len(accounts) != 0 {
|
|
t.Fatalf("expected 0 accounts, got %d", len(accounts))
|
|
}
|
|
|
|
// Create an account and store a TOTP secret.
|
|
a, err := database.CreateAccount("totpuser", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("create account: %v", err)
|
|
}
|
|
if err := database.SetTOTP(a.ID, []byte("enc"), []byte("nonce")); err != nil {
|
|
t.Fatalf("set TOTP: %v", err)
|
|
}
|
|
|
|
// Create another account without TOTP.
|
|
if _, err := database.CreateAccount("nototp", model.AccountTypeHuman, "hash"); err != nil {
|
|
t.Fatalf("create account: %v", err)
|
|
}
|
|
|
|
accounts, err = database.ListAccountsWithTOTP()
|
|
if err != nil {
|
|
t.Fatalf("ListAccountsWithTOTP: %v", err)
|
|
}
|
|
if len(accounts) != 1 {
|
|
t.Fatalf("expected 1 account with TOTP, got %d", len(accounts))
|
|
}
|
|
if accounts[0].ID != a.ID {
|
|
t.Errorf("expected account ID %d, got %d", a.ID, accounts[0].ID)
|
|
}
|
|
}
|
|
|
|
func TestListAllPGCredentials(t *testing.T) {
|
|
database := openTestDB(t)
|
|
|
|
creds, err := database.ListAllPGCredentials()
|
|
if err != nil {
|
|
t.Fatalf("ListAllPGCredentials (empty): %v", err)
|
|
}
|
|
if len(creds) != 0 {
|
|
t.Fatalf("expected 0 creds, got %d", len(creds))
|
|
}
|
|
|
|
a, err := database.CreateAccount("pguser", model.AccountTypeSystem, "")
|
|
if err != nil {
|
|
t.Fatalf("create account: %v", err)
|
|
}
|
|
if err := database.WritePGCredentials(a.ID, "host", 5432, "db", "user", []byte("enc"), []byte("nonce")); err != nil {
|
|
t.Fatalf("write pg credentials: %v", err)
|
|
}
|
|
|
|
creds, err = database.ListAllPGCredentials()
|
|
if err != nil {
|
|
t.Fatalf("ListAllPGCredentials: %v", err)
|
|
}
|
|
if len(creds) != 1 {
|
|
t.Fatalf("expected 1 credential, got %d", len(creds))
|
|
}
|
|
if creds[0].AccountID != a.ID {
|
|
t.Errorf("expected account ID %d, got %d", a.ID, creds[0].AccountID)
|
|
}
|
|
}
|
|
|
|
func TestRekey(t *testing.T) {
|
|
database := openTestDB(t)
|
|
|
|
// Set up: salt + signing key.
|
|
oldSalt := []byte("oldsaltoldsaltoldsaltoldsaltoldt") // 32 bytes
|
|
if err := database.WriteMasterKeySalt(oldSalt); err != nil {
|
|
t.Fatalf("write salt: %v", err)
|
|
}
|
|
if err := database.WriteServerConfig([]byte("oldenc"), []byte("oldnonce")); err != nil {
|
|
t.Fatalf("write server config: %v", err)
|
|
}
|
|
|
|
// Set up: account with TOTP.
|
|
a, err := database.CreateAccount("rekeyuser", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("create account: %v", err)
|
|
}
|
|
if err := database.SetTOTP(a.ID, []byte("totpenc"), []byte("totpnonce")); err != nil {
|
|
t.Fatalf("set TOTP: %v", err)
|
|
}
|
|
|
|
// Set up: pg credential.
|
|
if err := database.WritePGCredentials(a.ID, "host", 5432, "db", "user", []byte("pgenc"), []byte("pgnonce")); err != nil {
|
|
t.Fatalf("write pg creds: %v", err)
|
|
}
|
|
|
|
// Execute Rekey.
|
|
newSalt := []byte("newsaltnewsaltnewsaltnewsaltnews") // 32 bytes
|
|
totpRows := []TOTPRekeyRow{{AccountID: a.ID, Enc: []byte("newtotpenc"), Nonce: []byte("newtotpnonce")}}
|
|
pgCred, err := database.ReadPGCredentials(a.ID)
|
|
if err != nil {
|
|
t.Fatalf("read pg creds: %v", err)
|
|
}
|
|
pgRows := []PGRekeyRow{{CredentialID: pgCred.ID, Enc: []byte("newpgenc"), Nonce: []byte("newpgnonce")}}
|
|
|
|
if err := database.Rekey(newSalt, []byte("newenc"), []byte("newnonce"), totpRows, pgRows); err != nil {
|
|
t.Fatalf("Rekey: %v", err)
|
|
}
|
|
|
|
// Verify: salt replaced.
|
|
gotSalt, err := database.ReadMasterKeySalt()
|
|
if err != nil {
|
|
t.Fatalf("read salt after rekey: %v", err)
|
|
}
|
|
if string(gotSalt) != string(newSalt) {
|
|
t.Errorf("salt mismatch: got %q, want %q", gotSalt, newSalt)
|
|
}
|
|
|
|
// Verify: signing key replaced.
|
|
gotEnc, gotNonce, err := database.ReadServerConfig()
|
|
if err != nil {
|
|
t.Fatalf("read server config after rekey: %v", err)
|
|
}
|
|
if string(gotEnc) != "newenc" || string(gotNonce) != "newnonce" {
|
|
t.Errorf("signing key enc/nonce mismatch after rekey")
|
|
}
|
|
|
|
// Verify: TOTP replaced.
|
|
updatedAcct, err := database.GetAccountByID(a.ID)
|
|
if err != nil {
|
|
t.Fatalf("get account after rekey: %v", err)
|
|
}
|
|
if string(updatedAcct.TOTPSecretEnc) != "newtotpenc" || string(updatedAcct.TOTPSecretNonce) != "newtotpnonce" {
|
|
t.Errorf("TOTP enc/nonce mismatch after rekey: enc=%q nonce=%q",
|
|
updatedAcct.TOTPSecretEnc, updatedAcct.TOTPSecretNonce)
|
|
}
|
|
|
|
// Verify: pg credential replaced.
|
|
updatedCred, err := database.ReadPGCredentials(a.ID)
|
|
if err != nil {
|
|
t.Fatalf("read pg creds after rekey: %v", err)
|
|
}
|
|
if string(updatedCred.PGPasswordEnc) != "newpgenc" || string(updatedCred.PGPasswordNonce) != "newpgnonce" {
|
|
t.Errorf("pg enc/nonce mismatch after rekey: enc=%q nonce=%q",
|
|
updatedCred.PGPasswordEnc, updatedCred.PGPasswordNonce)
|
|
}
|
|
}
|
|
|
|
func TestRekeyEmptyDatabase(t *testing.T) {
|
|
database := openTestDB(t)
|
|
|
|
// Minimal setup: salt and signing key only; no TOTP, no pg creds.
|
|
salt := []byte("saltsaltsaltsaltsaltsaltsaltsalt") // 32 bytes
|
|
if err := database.WriteMasterKeySalt(salt); err != nil {
|
|
t.Fatalf("write salt: %v", err)
|
|
}
|
|
if err := database.WriteServerConfig([]byte("enc"), []byte("nonce")); err != nil {
|
|
t.Fatalf("write server config: %v", err)
|
|
}
|
|
|
|
newSalt := []byte("newsaltnewsaltnewsaltnewsaltnews") // 32 bytes
|
|
if err := database.Rekey(newSalt, []byte("newenc"), []byte("newnonce"), nil, nil); err != nil {
|
|
t.Fatalf("Rekey (empty): %v", err)
|
|
}
|
|
|
|
gotSalt, err := database.ReadMasterKeySalt()
|
|
if err != nil {
|
|
t.Fatalf("read salt: %v", err)
|
|
}
|
|
if string(gotSalt) != string(newSalt) {
|
|
t.Errorf("salt mismatch")
|
|
}
|
|
}
|
|
|
|
// TestRekeyOldSaltUnchangedOnQueryError verifies the salt and encrypted data
|
|
// is only present under the new values after a successful Rekey — the old
|
|
// values must be gone. Uses the same approach as TestRekey but reads the
|
|
// stored salt before and confirms it changes.
|
|
func TestRekeyReplacesSalt(t *testing.T) {
|
|
database := openTestDB(t)
|
|
|
|
oldSalt := []byte("oldsaltoldsaltoldsaltoldsaltoldt") // 32 bytes
|
|
if err := database.WriteMasterKeySalt(oldSalt); err != nil {
|
|
t.Fatalf("write salt: %v", err)
|
|
}
|
|
if err := database.WriteServerConfig([]byte("enc"), []byte("nonce")); err != nil {
|
|
t.Fatalf("write server config: %v", err)
|
|
}
|
|
|
|
newSalt := []byte("newsaltnewsaltnewsaltnewsaltnews") // 32 bytes
|
|
if err := database.Rekey(newSalt, []byte("newenc"), []byte("newnonce"), nil, nil); err != nil {
|
|
t.Fatalf("Rekey: %v", err)
|
|
}
|
|
|
|
gotSalt, err := database.ReadMasterKeySalt()
|
|
if err != nil {
|
|
t.Fatalf("read salt: %v", err)
|
|
}
|
|
if string(gotSalt) == string(oldSalt) {
|
|
t.Error("old salt still present after rekey")
|
|
}
|
|
if string(gotSalt) != string(newSalt) {
|
|
t.Errorf("expected new salt %q, got %q", newSalt, gotSalt)
|
|
}
|
|
}
|