Files
mcias/internal/db/db_test.go
Kyle Isom 005e734842 Fix F-16: revoke old system token before issuing new one
- ui/handlers_accounts.go (handleIssueSystemToken): call
  GetSystemToken before issuing; if one exists, call
  RevokeToken(existing.JTI, "rotated") before TrackToken
  and SetSystemToken for the new token; mirrors the pattern
  in REST handleTokenIssue and gRPC IssueServiceToken
- db/db_test.go: TestSystemTokenRotationRevokesOld verifies
  the full rotation flow: old JTI revoked with reason
  "rotated", new JTI tracked and active, GetSystemToken
  returns the new JTI
- AUDIT.md: mark F-16 as fixed
Security: without this fix an old system token remained valid
  after rotation until its natural expiry, giving a leaked or
  stolen old token extra lifetime. With the revocation the old
  JTI is immediately marked in token_revocation so any validator
  checking revocation status rejects it.
2026-03-11 20:34:57 -07:00

476 lines
12 KiB
Go

package db
import (
"errors"
"testing"
"time"
"git.wntrmute.dev/kyle/mcias/internal/model"
)
// openTestDB opens an in-memory SQLite database for testing and applies all
// schema migrations. Any failure aborts the calling test via t.Fatalf.
//
// The cleanup that closes the database is registered immediately after a
// successful Open, so the handle is released even when Migrate fails and the
// test is aborted before returning.
func openTestDB(t *testing.T) *DB {
	t.Helper()
	db, err := Open(":memory:")
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	// Best-effort close: a Close error during cleanup does not affect the
	// outcome of the test that used the database.
	t.Cleanup(func() { _ = db.Close() })
	if err := Migrate(db); err != nil {
		t.Fatalf("Migrate: %v", err)
	}
	return db
}
// TestMigrateIdempotent checks that running Migrate on a database that has
// already been migrated (openTestDB migrates once) succeeds without error.
func TestMigrateIdempotent(t *testing.T) {
	database := openTestDB(t)
	// Second pass over an up-to-date schema: must be a no-op.
	err := Migrate(database)
	if err != nil {
		t.Errorf("second Migrate call returned error: %v", err)
	}
}
// TestCreateAndGetAccount creates a human account and verifies that it gets a
// non-empty UUID and active status, and that it can be read back both by UUID
// and by username.
func TestCreateAndGetAccount(t *testing.T) {
	store := openTestDB(t)
	const username = "alice"

	created, err := store.CreateAccount(username, model.AccountTypeHuman, "$argon2id$v=19$...")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}
	if created.UUID == "" {
		t.Error("expected non-empty UUID")
	}
	if created.Username != username {
		t.Errorf("Username = %q, want %q", created.Username, username)
	}
	if created.Status != model.AccountStatusActive {
		t.Errorf("Status = %q, want active", created.Status)
	}

	// Lookup by UUID must return the same username.
	byUUID, err := store.GetAccountByUUID(created.UUID)
	if err != nil {
		t.Fatalf("GetAccountByUUID: %v", err)
	}
	if byUUID.Username != username {
		t.Errorf("fetched Username = %q, want %q", byUUID.Username, username)
	}

	// Lookup by username must resolve to the same UUID.
	byName, err := store.GetAccountByUsername(username)
	if err != nil {
		t.Fatalf("GetAccountByUsername: %v", err)
	}
	if byName.UUID != created.UUID {
		t.Errorf("UUID mismatch: got %q, want %q", byName.UUID, created.UUID)
	}
}
// TestGetAccountNotFound verifies that both account lookups (by UUID and by
// username) report ErrNotFound for keys that were never created.
func TestGetAccountNotFound(t *testing.T) {
	store := openTestDB(t)
	lookups := []func() error{
		func() error {
			_, err := store.GetAccountByUUID("nonexistent-uuid")
			return err
		},
		func() error {
			_, err := store.GetAccountByUsername("nobody")
			return err
		},
	}
	for _, lookup := range lookups {
		if err := lookup(); !errors.Is(err, ErrNotFound) {
			t.Errorf("expected ErrNotFound, got %v", err)
		}
	}
}
// TestUpdateAccountStatus marks an account inactive and confirms the new
// status is visible on a fresh lookup by UUID.
func TestUpdateAccountStatus(t *testing.T) {
	store := openTestDB(t)
	bob, err := store.CreateAccount("bob", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	err = store.UpdateAccountStatus(bob.ID, model.AccountStatusInactive)
	if err != nil {
		t.Fatalf("UpdateAccountStatus: %v", err)
	}

	reloaded, err := store.GetAccountByUUID(bob.UUID)
	if err != nil {
		t.Fatalf("GetAccountByUUID: %v", err)
	}
	if reloaded.Status != model.AccountStatusInactive {
		t.Errorf("Status = %q, want inactive", reloaded.Status)
	}
}
// TestListAccounts creates three accounts and checks that ListAccounts
// returns exactly that many.
func TestListAccounts(t *testing.T) {
	store := openTestDB(t)
	names := []string{"charlie", "delta", "eve"}
	for i := range names {
		_, err := store.CreateAccount(names[i], model.AccountTypeHuman, "hash")
		if err != nil {
			t.Fatalf("CreateAccount %q: %v", names[i], err)
		}
	}

	listed, err := store.ListAccounts()
	if err != nil {
		t.Fatalf("ListAccounts: %v", err)
	}
	if count := len(listed); count != 3 {
		t.Errorf("ListAccounts returned %d accounts, want 3", count)
	}
}
// TestRoleOperations exercises the role API on a single account. It covers:
// granting a role twice (the duplicate grant must not error), querying via
// GetRoles and HasRole, revoking the role, and replacing the whole set with
// SetRoles.
func TestRoleOperations(t *testing.T) {
	db := openTestDB(t)
	acct, err := db.CreateAccount("frank", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	// GrantRole
	if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
		t.Fatalf("GrantRole: %v", err)
	}
	// Grant again — should be no-op.
	if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
		t.Fatalf("GrantRole duplicate: %v", err)
	}
	roles, err := db.GetRoles(acct.ID)
	if err != nil {
		t.Fatalf("GetRoles: %v", err)
	}
	if len(roles) != 1 || roles[0] != "admin" {
		t.Errorf("GetRoles = %v, want [admin]", roles)
	}
	has, err := db.HasRole(acct.ID, "admin")
	if err != nil {
		t.Fatalf("HasRole: %v", err)
	}
	if !has {
		t.Error("expected HasRole to return true for 'admin'")
	}

	// RevokeRole
	if err := db.RevokeRole(acct.ID, "admin"); err != nil {
		t.Fatalf("RevokeRole: %v", err)
	}
	roles, err = db.GetRoles(acct.ID)
	if err != nil {
		t.Fatalf("GetRoles after revoke: %v", err)
	}
	if len(roles) != 0 {
		t.Errorf("expected no roles after revoke, got %v", roles)
	}

	// SetRoles
	if err := db.SetRoles(acct.ID, []string{"reader", "writer"}, nil); err != nil {
		t.Fatalf("SetRoles: %v", err)
	}
	roles, err = db.GetRoles(acct.ID)
	if err != nil {
		t.Fatalf("GetRoles after SetRoles: %v", err)
	}
	if len(roles) != 2 {
		t.Errorf("expected 2 roles after SetRoles, got %d", len(roles))
	}
	// A count check alone would accept wrong role names, so verify membership
	// too. The order returned by GetRoles is not assumed.
	var hasReader, hasWriter bool
	for _, r := range roles {
		switch r {
		case "reader":
			hasReader = true
		case "writer":
			hasWriter = true
		}
	}
	if !hasReader || !hasWriter {
		t.Errorf("GetRoles after SetRoles = %v, want reader and writer", roles)
	}
}
// TestTokenTrackingAndRevocation tracks a token and checks that it starts out
// unrevoked. It then revokes the token and checks both the revoked flag and
// the stored reason. Finally, it confirms that a second revocation of the same
// JTI is rejected.
func TestTokenTrackingAndRevocation(t *testing.T) {
	db := openTestDB(t)
	acct, err := db.CreateAccount("grace", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}
	jti := "test-jti-1234"
	issuedAt := time.Now().UTC()
	expiresAt := issuedAt.Add(time.Hour)
	if err := db.TrackToken(jti, acct.ID, issuedAt, expiresAt); err != nil {
		t.Fatalf("TrackToken: %v", err)
	}

	// Retrieve
	rec, err := db.GetTokenRecord(jti)
	if err != nil {
		t.Fatalf("GetTokenRecord: %v", err)
	}
	if rec.JTI != jti {
		t.Errorf("JTI = %q, want %q", rec.JTI, jti)
	}
	if rec.IsRevoked() {
		t.Error("newly tracked token should not be revoked")
	}

	// Revoke
	const reason = "test revocation"
	if err := db.RevokeToken(jti, reason); err != nil {
		t.Fatalf("RevokeToken: %v", err)
	}
	rec, err = db.GetTokenRecord(jti)
	if err != nil {
		t.Fatalf("GetTokenRecord after revoke: %v", err)
	}
	if !rec.IsRevoked() {
		t.Error("token should be revoked after RevokeToken")
	}
	// The reason is persisted alongside the revocation and must round-trip.
	if rec.RevokeReason != reason {
		t.Errorf("revoke reason = %q, want %q", rec.RevokeReason, reason)
	}

	// Revoking again should fail (already revoked).
	if err := db.RevokeToken(jti, "again"); err == nil {
		t.Error("expected error when revoking already-revoked token")
	}
}
// TestGetTokenRecordNotFound checks that looking up a JTI that was never
// tracked reports ErrNotFound.
func TestGetTokenRecordNotFound(t *testing.T) {
	store := openTestDB(t)
	if _, err := store.GetTokenRecord("no-such-jti"); !errors.Is(err, ErrNotFound) {
		t.Errorf("expected ErrNotFound, got %v", err)
	}
}
// TestPruneExpiredTokens tracks one expired and one unexpired token, runs
// PruneExpiredTokens, and verifies three things: exactly one row was pruned,
// the expired token is gone, and the unexpired token is still retrievable.
func TestPruneExpiredTokens(t *testing.T) {
	db := openTestDB(t)
	acct, err := db.CreateAccount("henry", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	// Derive all timestamps from one UTC reading so the fixtures are
	// consistent with each other and with the other tests in this file.
	now := time.Now().UTC()
	past := now.Add(-time.Hour)
	future := now.Add(time.Hour)
	if err := db.TrackToken("expired-jti", acct.ID, past.Add(-time.Hour), past); err != nil {
		t.Fatalf("TrackToken expired: %v", err)
	}
	if err := db.TrackToken("valid-jti", acct.ID, now, future); err != nil {
		t.Fatalf("TrackToken valid: %v", err)
	}

	n, err := db.PruneExpiredTokens()
	if err != nil {
		t.Fatalf("PruneExpiredTokens: %v", err)
	}
	if n != 1 {
		t.Errorf("pruned %d rows, want 1", n)
	}

	// The expired token must actually have been removed.
	if _, err := db.GetTokenRecord("expired-jti"); !errors.Is(err, ErrNotFound) {
		t.Errorf("expired token still present after prune: got %v, want ErrNotFound", err)
	}
	// Valid token should still be retrievable.
	if _, err := db.GetTokenRecord("valid-jti"); err != nil {
		t.Errorf("valid token missing after prune: %v", err)
	}
}
// TestServerConfig verifies the server-config round trip. ReadServerConfig
// reports ErrNotFound before anything is written, returns exactly what
// WriteServerConfig stored, and reflects the new values after an overwrite.
func TestServerConfig(t *testing.T) {
	db := openTestDB(t)

	// No config initially.
	_, _, err := db.ReadServerConfig()
	if !errors.Is(err, ErrNotFound) {
		t.Errorf("expected ErrNotFound for missing config, got %v", err)
	}

	enc := []byte("encrypted-key-data")
	nonce := []byte("nonce12345678901")
	if err := db.WriteServerConfig(enc, nonce); err != nil {
		t.Fatalf("WriteServerConfig: %v", err)
	}
	gotEnc, gotNonce, err := db.ReadServerConfig()
	if err != nil {
		t.Fatalf("ReadServerConfig: %v", err)
	}
	if string(gotEnc) != string(enc) {
		t.Errorf("enc mismatch: got %q, want %q", gotEnc, enc)
	}
	if string(gotNonce) != string(nonce) {
		t.Errorf("nonce mismatch: got %q, want %q", gotNonce, nonce)
	}

	// Overwrite — should work without error and replace the stored values.
	newEnc := []byte("new-key")
	newNonce := []byte("new-nonce123456")
	if err := db.WriteServerConfig(newEnc, newNonce); err != nil {
		t.Fatalf("WriteServerConfig overwrite: %v", err)
	}
	gotEnc, gotNonce, err = db.ReadServerConfig()
	if err != nil {
		t.Fatalf("ReadServerConfig after overwrite: %v", err)
	}
	if string(gotEnc) != string(newEnc) {
		t.Errorf("enc after overwrite: got %q, want %q", gotEnc, newEnc)
	}
	if string(gotNonce) != string(newNonce) {
		t.Errorf("nonce after overwrite: got %q, want %q", gotNonce, newNonce)
	}
}
// TestForeignKeyEnforcement confirms that tracking a token for an account ID
// that does not exist is rejected by the database.
func TestForeignKeyEnforcement(t *testing.T) {
	store := openTestDB(t)
	const missingAccountID = 999999
	err := store.TrackToken("jti-x", missingAccountID, time.Now(), time.Now().Add(time.Hour))
	if err != nil {
		return
	}
	t.Error("expected foreign key error for non-existent account_id, got nil")
}
// TestPGCredentials stores Postgres connection details for a system account.
// It then checks that the host and database name read back unchanged through
// ReadPGCredentials.
func TestPGCredentials(t *testing.T) {
	store := openTestDB(t)
	svc, err := store.CreateAccount("svc", model.AccountTypeSystem, "")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	var (
		ciphertext = []byte("encrypted-pg-password")
		pgNonce    = []byte("pg-nonce12345678")
	)
	err = store.WritePGCredentials(svc.ID, "localhost", 5432, "mydb", "myuser", ciphertext, pgNonce)
	if err != nil {
		t.Fatalf("WritePGCredentials: %v", err)
	}

	stored, err := store.ReadPGCredentials(svc.ID)
	if err != nil {
		t.Fatalf("ReadPGCredentials: %v", err)
	}
	if stored.PGHost != "localhost" {
		t.Errorf("PGHost = %q, want %q", stored.PGHost, "localhost")
	}
	if stored.PGDatabase != "mydb" {
		t.Errorf("PGDatabase = %q, want %q", stored.PGDatabase, "mydb")
	}
}
// TestRenewTokenAtomic checks that RenewToken revokes the old JTI (recording
// the reason) and tracks the new one. It also checks that a renewal attempt
// on an already-revoked JTI fails without leaving the replacement token
// tracked, i.e. that no partial effect remains.
func TestRenewTokenAtomic(t *testing.T) {
	db := openTestDB(t)
	acct, err := db.CreateAccount("judy", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}
	now := time.Now().UTC()
	exp := now.Add(time.Hour)

	// Set up the old token.
	oldJTI := "renew-old-jti"
	newJTI := "renew-new-jti"
	if err := db.TrackToken(oldJTI, acct.ID, now, exp); err != nil {
		t.Fatalf("TrackToken old: %v", err)
	}

	// RenewToken should atomically revoke old and track new.
	if err := db.RenewToken(oldJTI, "renewed", newJTI, acct.ID, now, exp); err != nil {
		t.Fatalf("RenewToken: %v", err)
	}

	// Old token must be revoked.
	oldRec, err := db.GetTokenRecord(oldJTI)
	if err != nil {
		t.Fatalf("GetTokenRecord old: %v", err)
	}
	if !oldRec.IsRevoked() {
		t.Error("old token should be revoked after RenewToken")
	}
	if oldRec.RevokeReason != "renewed" {
		t.Errorf("old token revoke reason = %q, want %q", oldRec.RevokeReason, "renewed")
	}

	// New token must be tracked and not revoked.
	newRec, err := db.GetTokenRecord(newJTI)
	if err != nil {
		t.Fatalf("GetTokenRecord new: %v", err)
	}
	if newRec.IsRevoked() {
		t.Error("new token should not be revoked")
	}

	// RenewToken on an already-revoked old token must fail (atomicity guard).
	const rejectedJTI = "other-jti"
	if err := db.RenewToken(oldJTI, "renewed", rejectedJTI, acct.ID, now, exp); err == nil {
		t.Error("expected error when renewing an already-revoked token")
	}
	// Returning an error is not enough. If the renewal is atomic, the failed
	// call must not have left the replacement token tracked either.
	if _, err := db.GetTokenRecord(rejectedJTI); !errors.Is(err, ErrNotFound) {
		t.Errorf("failed RenewToken left %q tracked: got %v, want ErrNotFound", rejectedJTI, err)
	}
}
// TestSystemTokenRotationRevokesOld verifies the rotation pattern used by
// handleIssueSystemToken (F-16): issuing a second system token revokes the first.
//
// The rotation steps are performed explicitly here: look up the current
// system token, revoke it with reason "rotated", then track and register the
// new one. Each step's error is checked, because a silently failed revocation
// is exactly the defect F-16 describes.
func TestSystemTokenRotationRevokesOld(t *testing.T) {
	db := openTestDB(t)
	acct, err := db.CreateAccount("svc", model.AccountTypeSystem, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}
	now := time.Now().UTC()
	exp := now.Add(time.Hour)

	// Issue first token.
	jti1 := "sys-tok-1"
	if err := db.TrackToken(jti1, acct.ID, now, exp); err != nil {
		t.Fatalf("TrackToken jti1: %v", err)
	}
	if err := db.SetSystemToken(acct.ID, jti1, exp); err != nil {
		t.Fatalf("SetSystemToken jti1: %v", err)
	}

	// Simulate token rotation: look up existing, revoke it, issue second.
	existing, err := db.GetSystemToken(acct.ID)
	if err != nil {
		t.Fatalf("GetSystemToken: %v", err)
	}
	// Fatal rather than Errorf: continuing would revoke an unexpected JTI and
	// make every later assertion meaningless.
	if existing.JTI != jti1 {
		t.Fatalf("expected JTI %q, got %q", jti1, existing.JTI)
	}
	if err := db.RevokeToken(existing.JTI, "rotated"); err != nil {
		t.Fatalf("RevokeToken %q: %v", existing.JTI, err)
	}
	jti2 := "sys-tok-2"
	if err := db.TrackToken(jti2, acct.ID, now, exp); err != nil {
		t.Fatalf("TrackToken jti2: %v", err)
	}
	if err := db.SetSystemToken(acct.ID, jti2, exp); err != nil {
		t.Fatalf("SetSystemToken jti2: %v", err)
	}

	// Old token must be revoked.
	old, err := db.GetTokenRecord(jti1)
	if err != nil {
		t.Fatalf("GetTokenRecord jti1: %v", err)
	}
	if !old.IsRevoked() {
		t.Error("old system token should be revoked after rotation")
	}
	if old.RevokeReason != "rotated" {
		t.Errorf("revoke reason = %q, want %q", old.RevokeReason, "rotated")
	}

	// New token must be active.
	newRec, err := db.GetTokenRecord(jti2)
	if err != nil {
		t.Fatalf("GetTokenRecord jti2: %v", err)
	}
	if newRec.IsRevoked() {
		t.Error("new system token should not be revoked after rotation")
	}

	// GetSystemToken must return the new JTI.
	cur, err := db.GetSystemToken(acct.ID)
	if err != nil {
		t.Fatalf("GetSystemToken after rotation: %v", err)
	}
	if cur.JTI != jti2 {
		t.Errorf("current system token JTI = %q, want %q", cur.JTI, jti2)
	}
}
// TestRevokeAllUserTokens tracks several tokens for one account, revokes them
// in a single RevokeAllUserTokens call, and confirms that each one now reports
// as revoked.
func TestRevokeAllUserTokens(t *testing.T) {
	store := openTestDB(t)
	user, err := store.CreateAccount("ivan", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	tokenIDs := []string{"tok1", "tok2", "tok3"}
	expiry := time.Now().UTC().Add(time.Hour)
	for _, id := range tokenIDs {
		if err := store.TrackToken(id, user.ID, time.Now(), expiry); err != nil {
			t.Fatalf("TrackToken %q: %v", id, err)
		}
	}

	if err := store.RevokeAllUserTokens(user.ID, "account suspended"); err != nil {
		t.Fatalf("RevokeAllUserTokens: %v", err)
	}

	for _, id := range tokenIDs {
		record, err := store.GetTokenRecord(id)
		if err != nil {
			t.Fatalf("GetTokenRecord %q: %v", id, err)
		}
		if !record.IsRevoked() {
			t.Errorf("token %q should be revoked", id)
		}
	}
}