package db

import (
	"errors"
	"testing"
	"time"

	"git.wntrmute.dev/kyle/mcias/internal/model"
)

// openTestDB opens an in-memory SQLite database for testing, applies the
// migrations, and registers a cleanup that closes the database when the test
// finishes. Any failure to open or migrate aborts the calling test.
func openTestDB(t *testing.T) *DB {
	t.Helper()

	db, err := Open(":memory:")
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	if err := Migrate(db); err != nil {
		t.Fatalf("Migrate: %v", err)
	}
	// The Close error is deliberately ignored: the test has already finished
	// by the time cleanup runs.
	t.Cleanup(func() { _ = db.Close() })
	return db
}

// TestMigrateIdempotent verifies that calling Migrate on an already-migrated
// database returns no error.
func TestMigrateIdempotent(t *testing.T) {
	db := openTestDB(t)

	// Run again — should be a no-op.
	if err := Migrate(db); err != nil {
		t.Errorf("second Migrate call returned error: %v", err)
	}
}

// TestCreateAndGetAccount creates an account and checks its UUID, username and
// status, then verifies it can be fetched both by UUID and by username.
func TestCreateAndGetAccount(t *testing.T) {
	db := openTestDB(t)

	acct, err := db.CreateAccount("alice", model.AccountTypeHuman, "$argon2id$v=19$...")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}
	if acct.UUID == "" {
		t.Error("expected non-empty UUID")
	}
	if acct.Username != "alice" {
		t.Errorf("Username = %q, want %q", acct.Username, "alice")
	}
	if acct.Status != model.AccountStatusActive {
		t.Errorf("Status = %q, want active", acct.Status)
	}

	// Retrieve by UUID.
	got, err := db.GetAccountByUUID(acct.UUID)
	if err != nil {
		t.Fatalf("GetAccountByUUID: %v", err)
	}
	if got.Username != "alice" {
		t.Errorf("fetched Username = %q, want %q", got.Username, "alice")
	}

	// Retrieve by username.
	got2, err := db.GetAccountByUsername("alice")
	if err != nil {
		t.Fatalf("GetAccountByUsername: %v", err)
	}
	if got2.UUID != acct.UUID {
		t.Errorf("UUID mismatch: got %q, want %q", got2.UUID, acct.UUID)
	}
}

// TestGetAccountNotFound verifies that both account lookups report ErrNotFound
// for accounts that do not exist.
func TestGetAccountNotFound(t *testing.T) {
	db := openTestDB(t)

	_, err := db.GetAccountByUUID("nonexistent-uuid")
	if !errors.Is(err, ErrNotFound) {
		t.Errorf("expected ErrNotFound, got %v", err)
	}

	_, err = db.GetAccountByUsername("nobody")
	if !errors.Is(err, ErrNotFound) {
		t.Errorf("expected ErrNotFound, got %v", err)
	}
}

// TestUpdateAccountStatus verifies that a status change made through
// UpdateAccountStatus is visible on a subsequent fetch.
func TestUpdateAccountStatus(t *testing.T) {
	db := openTestDB(t)

	acct, err := db.CreateAccount("bob", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	if err := db.UpdateAccountStatus(acct.ID, model.AccountStatusInactive); err != nil {
		t.Fatalf("UpdateAccountStatus: %v", err)
	}

	got, err := db.GetAccountByUUID(acct.UUID)
	if err != nil {
		t.Fatalf("GetAccountByUUID: %v", err)
	}
	if got.Status != model.AccountStatusInactive {
		t.Errorf("Status = %q, want inactive", got.Status)
	}
}

// TestListAccounts creates three accounts and verifies that ListAccounts
// returns exactly three entries.
func TestListAccounts(t *testing.T) {
	db := openTestDB(t)

	for _, name := range []string{"charlie", "delta", "eve"} {
		if _, err := db.CreateAccount(name, model.AccountTypeHuman, "hash"); err != nil {
			t.Fatalf("CreateAccount %q: %v", name, err)
		}
	}

	accts, err := db.ListAccounts()
	if err != nil {
		t.Fatalf("ListAccounts: %v", err)
	}
	if len(accts) != 3 {
		t.Errorf("ListAccounts returned %d accounts, want 3", len(accts))
	}
}

// TestRoleOperations exercises GrantRole (including a duplicate grant),
// GetRoles, HasRole, RevokeRole and SetRoles on a single account.
func TestRoleOperations(t *testing.T) {
	db := openTestDB(t)

	acct, err := db.CreateAccount("frank", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	// GrantRole
	if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
		t.Fatalf("GrantRole: %v", err)
	}
	// Grant again — should be no-op.
	if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
		t.Fatalf("GrantRole duplicate: %v", err)
	}

	roles, err := db.GetRoles(acct.ID)
	if err != nil {
		t.Fatalf("GetRoles: %v", err)
	}
	if len(roles) != 1 || roles[0] != "admin" {
		t.Errorf("GetRoles = %v, want [admin]", roles)
	}

	has, err := db.HasRole(acct.ID, "admin")
	if err != nil {
		t.Fatalf("HasRole: %v", err)
	}
	if !has {
		t.Error("expected HasRole to return true for 'admin'")
	}

	// RevokeRole
	if err := db.RevokeRole(acct.ID, "admin"); err != nil {
		t.Fatalf("RevokeRole: %v", err)
	}
	roles, err = db.GetRoles(acct.ID)
	if err != nil {
		t.Fatalf("GetRoles after revoke: %v", err)
	}
	if len(roles) != 0 {
		t.Errorf("expected no roles after revoke, got %v", roles)
	}

	// SetRoles
	if err := db.SetRoles(acct.ID, []string{"reader", "writer"}, nil); err != nil {
		t.Fatalf("SetRoles: %v", err)
	}
	roles, err = db.GetRoles(acct.ID)
	if err != nil {
		t.Fatalf("GetRoles after SetRoles: %v", err)
	}
	if len(roles) != 2 {
		t.Errorf("expected 2 roles after SetRoles, got %d", len(roles))
	}
}

// TestTokenTrackingAndRevocation tracks a token, verifies it is retrievable and
// not revoked, revokes it, and verifies that a second revocation returns an
// error.
func TestTokenTrackingAndRevocation(t *testing.T) {
	db := openTestDB(t)

	acct, err := db.CreateAccount("grace", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	jti := "test-jti-1234"
	issuedAt := time.Now().UTC()
	expiresAt := issuedAt.Add(time.Hour)

	if err := db.TrackToken(jti, acct.ID, issuedAt, expiresAt); err != nil {
		t.Fatalf("TrackToken: %v", err)
	}

	// Retrieve
	rec, err := db.GetTokenRecord(jti)
	if err != nil {
		t.Fatalf("GetTokenRecord: %v", err)
	}
	if rec.JTI != jti {
		t.Errorf("JTI = %q, want %q", rec.JTI, jti)
	}
	if rec.IsRevoked() {
		t.Error("newly tracked token should not be revoked")
	}

	// Revoke
	if err := db.RevokeToken(jti, "test revocation"); err != nil {
		t.Fatalf("RevokeToken: %v", err)
	}
	rec, err = db.GetTokenRecord(jti)
	if err != nil {
		t.Fatalf("GetTokenRecord after revoke: %v", err)
	}
	if !rec.IsRevoked() {
		t.Error("token should be revoked after RevokeToken")
	}

	// Revoking again should fail (already revoked).
	if err := db.RevokeToken(jti, "again"); err == nil {
		t.Error("expected error when revoking already-revoked token")
	}
}

// TestGetTokenRecordNotFound verifies that GetTokenRecord reports ErrNotFound
// for an unknown JTI.
func TestGetTokenRecordNotFound(t *testing.T) {
	db := openTestDB(t)

	_, err := db.GetTokenRecord("no-such-jti")
	if !errors.Is(err, ErrNotFound) {
		t.Errorf("expected ErrNotFound, got %v", err)
	}
}

// TestPruneExpiredTokens tracks one already-expired token and one still-valid
// token, then verifies that PruneExpiredTokens reports exactly one pruned row
// and that the valid token is still retrievable.
func TestPruneExpiredTokens(t *testing.T) {
	db := openTestDB(t)

	acct, err := db.CreateAccount("henry", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	past := time.Now().UTC().Add(-time.Hour)
	future := time.Now().UTC().Add(time.Hour)

	if err := db.TrackToken("expired-jti", acct.ID, past.Add(-time.Hour), past); err != nil {
		t.Fatalf("TrackToken expired: %v", err)
	}
	if err := db.TrackToken("valid-jti", acct.ID, time.Now(), future); err != nil {
		t.Fatalf("TrackToken valid: %v", err)
	}

	n, err := db.PruneExpiredTokens()
	if err != nil {
		t.Fatalf("PruneExpiredTokens: %v", err)
	}
	if n != 1 {
		t.Errorf("pruned %d rows, want 1", n)
	}

	// Valid token should still be retrievable.
	if _, err := db.GetTokenRecord("valid-jti"); err != nil {
		t.Errorf("valid token missing after prune: %v", err)
	}
}

// TestServerConfig verifies that reading the server config before anything is
// written reports ErrNotFound, that written values round-trip through
// ReadServerConfig, and that a second write succeeds.
func TestServerConfig(t *testing.T) {
	db := openTestDB(t)

	// No config initially.
	_, _, err := db.ReadServerConfig()
	if !errors.Is(err, ErrNotFound) {
		t.Errorf("expected ErrNotFound for missing config, got %v", err)
	}

	enc := []byte("encrypted-key-data")
	nonce := []byte("nonce12345678901")

	if err := db.WriteServerConfig(enc, nonce); err != nil {
		t.Fatalf("WriteServerConfig: %v", err)
	}

	gotEnc, gotNonce, err := db.ReadServerConfig()
	if err != nil {
		t.Fatalf("ReadServerConfig: %v", err)
	}
	if string(gotEnc) != string(enc) {
		t.Errorf("enc mismatch: got %q, want %q", gotEnc, enc)
	}
	if string(gotNonce) != string(nonce) {
		t.Errorf("nonce mismatch: got %q, want %q", gotNonce, nonce)
	}

	// Overwrite — should work without error.
	if err := db.WriteServerConfig([]byte("new-key"), []byte("new-nonce123456")); err != nil {
		t.Fatalf("WriteServerConfig overwrite: %v", err)
	}
}

// TestForeignKeyEnforcement verifies that tracking a token for an account ID
// that was never created returns an error.
func TestForeignKeyEnforcement(t *testing.T) {
	db := openTestDB(t)

	// Attempting to track a token for a non-existent account should fail.
	err := db.TrackToken("jti-x", 999999, time.Now(), time.Now().Add(time.Hour))
	if err == nil {
		t.Error("expected foreign key error for non-existent account_id, got nil")
	}
}

// TestPGCredentials writes Postgres credentials for a system account and
// verifies that the host and database name read back match what was written.
func TestPGCredentials(t *testing.T) {
	db := openTestDB(t)

	acct, err := db.CreateAccount("svc", model.AccountTypeSystem, "")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	enc := []byte("encrypted-pg-password")
	nonce := []byte("pg-nonce12345678")

	if err := db.WritePGCredentials(acct.ID, "localhost", 5432, "mydb", "myuser", enc, nonce); err != nil {
		t.Fatalf("WritePGCredentials: %v", err)
	}

	cred, err := db.ReadPGCredentials(acct.ID)
	if err != nil {
		t.Fatalf("ReadPGCredentials: %v", err)
	}
	if cred.PGHost != "localhost" {
		t.Errorf("PGHost = %q, want %q", cred.PGHost, "localhost")
	}
	if cred.PGDatabase != "mydb" {
		t.Errorf("PGDatabase = %q, want %q", cred.PGDatabase, "mydb")
	}
}

// TestRenewTokenAtomic verifies that RenewToken revokes the old token with the
// supplied reason and tracks the new one, and that renewing from an
// already-revoked token returns an error.
func TestRenewTokenAtomic(t *testing.T) {
	db := openTestDB(t)

	acct, err := db.CreateAccount("judy", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	now := time.Now().UTC()
	exp := now.Add(time.Hour)

	// Set up the old token.
	oldJTI := "renew-old-jti"
	newJTI := "renew-new-jti"
	if err := db.TrackToken(oldJTI, acct.ID, now, exp); err != nil {
		t.Fatalf("TrackToken old: %v", err)
	}

	// RenewToken should atomically revoke old and track new.
	if err := db.RenewToken(oldJTI, "renewed", newJTI, acct.ID, now, exp); err != nil {
		t.Fatalf("RenewToken: %v", err)
	}

	// Old token must be revoked.
	oldRec, err := db.GetTokenRecord(oldJTI)
	if err != nil {
		t.Fatalf("GetTokenRecord old: %v", err)
	}
	if !oldRec.IsRevoked() {
		t.Error("old token should be revoked after RenewToken")
	}
	if oldRec.RevokeReason != "renewed" {
		t.Errorf("old token revoke reason = %q, want %q", oldRec.RevokeReason, "renewed")
	}

	// New token must be tracked and not revoked.
	newRec, err := db.GetTokenRecord(newJTI)
	if err != nil {
		t.Fatalf("GetTokenRecord new: %v", err)
	}
	if newRec.IsRevoked() {
		t.Error("new token should not be revoked")
	}

	// RenewToken on an already-revoked old token must fail (atomicity guard).
	if err := db.RenewToken(oldJTI, "renewed", "other-jti", acct.ID, now, exp); err == nil {
		t.Error("expected error when renewing an already-revoked token")
	}
}

// TestSystemTokenRotationRevokesOld verifies the rotation pattern used by
// handleIssueSystemToken (F-16): issuing a second system token revokes the first.
func TestSystemTokenRotationRevokesOld(t *testing.T) {
	db := openTestDB(t)

	acct, err := db.CreateAccount("svc", model.AccountTypeSystem, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	now := time.Now().UTC()
	exp := now.Add(time.Hour)

	// Issue first token.
	jti1 := "sys-tok-1"
	if err := db.TrackToken(jti1, acct.ID, now, exp); err != nil {
		t.Fatalf("TrackToken jti1: %v", err)
	}
	if err := db.SetSystemToken(acct.ID, jti1, exp); err != nil {
		t.Fatalf("SetSystemToken jti1: %v", err)
	}

	// Simulate token rotation: look up existing, revoke it, issue second.
	existing, err := db.GetSystemToken(acct.ID)
	if err != nil {
		t.Fatalf("GetSystemToken: %v", err)
	}
	if existing.JTI != jti1 {
		t.Errorf("expected JTI %q, got %q", jti1, existing.JTI)
	}
	// Fail here with the underlying error rather than discarding it; otherwise
	// a failed revoke only surfaces later as an indirect "should be revoked"
	// assertion with no cause attached.
	if err := db.RevokeToken(existing.JTI, "rotated"); err != nil {
		t.Fatalf("RevokeToken jti1: %v", err)
	}

	jti2 := "sys-tok-2"
	if err := db.TrackToken(jti2, acct.ID, now, exp); err != nil {
		t.Fatalf("TrackToken jti2: %v", err)
	}
	if err := db.SetSystemToken(acct.ID, jti2, exp); err != nil {
		t.Fatalf("SetSystemToken jti2: %v", err)
	}

	// Old token must be revoked.
	old, err := db.GetTokenRecord(jti1)
	if err != nil {
		t.Fatalf("GetTokenRecord jti1: %v", err)
	}
	if !old.IsRevoked() {
		t.Error("old system token should be revoked after rotation")
	}
	if old.RevokeReason != "rotated" {
		t.Errorf("revoke reason = %q, want %q", old.RevokeReason, "rotated")
	}

	// New token must be active.
	newRec, err := db.GetTokenRecord(jti2)
	if err != nil {
		t.Fatalf("GetTokenRecord jti2: %v", err)
	}
	if newRec.IsRevoked() {
		t.Error("new system token should not be revoked after rotation")
	}

	// GetSystemToken must return the new JTI.
	cur, err := db.GetSystemToken(acct.ID)
	if err != nil {
		t.Fatalf("GetSystemToken after rotation: %v", err)
	}
	if cur.JTI != jti2 {
		t.Errorf("current system token JTI = %q, want %q", cur.JTI, jti2)
	}
}

// TestRevokeAllUserTokens tracks three tokens for one account and verifies
// that every one of them is revoked after RevokeAllUserTokens.
func TestRevokeAllUserTokens(t *testing.T) {
	db := openTestDB(t)

	acct, err := db.CreateAccount("ivan", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	future := time.Now().UTC().Add(time.Hour)
	for _, jti := range []string{"tok1", "tok2", "tok3"} {
		if err := db.TrackToken(jti, acct.ID, time.Now(), future); err != nil {
			t.Fatalf("TrackToken %q: %v", jti, err)
		}
	}

	if err := db.RevokeAllUserTokens(acct.ID, "account suspended"); err != nil {
		t.Fatalf("RevokeAllUserTokens: %v", err)
	}

	for _, jti := range []string{"tok1", "tok2", "tok3"} {
		rec, err := db.GetTokenRecord(jti)
		if err != nil {
			t.Fatalf("GetTokenRecord %q: %v", jti, err)
		}
		if !rec.IsRevoked() {
			t.Errorf("token %q should be revoked", jti)
		}
	}
}