- db/accounts.go: add RenewToken(oldJTI, reason, newJTI, accountID, issuedAt, expiresAt), which wraps RevokeToken + TrackToken in a single BEGIN/COMMIT transaction; if either step fails the whole tx rolls back, so the user is never left with neither the old nor the new token valid
- server.go (handleRenewToken): replace the separate RevokeToken + TrackToken calls with a single RenewToken call; failure now returns 500 instead of silently losing the revocation
- grpcserver/auth.go (RenewToken): same replacement
- db/db_test.go: TestRenewTokenAtomic verifies that the old token is revoked with the correct reason, the new token is tracked and not revoked, and a second renewal of the already-revoked old token returns an error
- AUDIT.md: mark F-03 as fixed

Security: without atomicity, a crash or error between revoke and track could leave the old token active alongside the new one (two live tokens) or revoke the old token without tracking the new one (user locked out). The transaction ensures exactly one of the two tokens is valid at all times.
406 lines
10 KiB
Go
406 lines
10 KiB
Go
package db
|
|
|
|
import (
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.wntrmute.dev/kyle/mcias/internal/model"
|
|
)
|
|
|
|
// openTestDB opens an in-memory SQLite database for testing.
|
|
func openTestDB(t *testing.T) *DB {
|
|
t.Helper()
|
|
db, err := Open(":memory:")
|
|
if err != nil {
|
|
t.Fatalf("Open: %v", err)
|
|
}
|
|
if err := Migrate(db); err != nil {
|
|
t.Fatalf("Migrate: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = db.Close() })
|
|
return db
|
|
}
|
|
|
|
func TestMigrateIdempotent(t *testing.T) {
|
|
db := openTestDB(t)
|
|
// Run again — should be a no-op.
|
|
if err := Migrate(db); err != nil {
|
|
t.Errorf("second Migrate call returned error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestCreateAndGetAccount(t *testing.T) {
|
|
db := openTestDB(t)
|
|
|
|
acct, err := db.CreateAccount("alice", model.AccountTypeHuman, "$argon2id$v=19$...")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
if acct.UUID == "" {
|
|
t.Error("expected non-empty UUID")
|
|
}
|
|
if acct.Username != "alice" {
|
|
t.Errorf("Username = %q, want %q", acct.Username, "alice")
|
|
}
|
|
if acct.Status != model.AccountStatusActive {
|
|
t.Errorf("Status = %q, want active", acct.Status)
|
|
}
|
|
|
|
// Retrieve by UUID.
|
|
got, err := db.GetAccountByUUID(acct.UUID)
|
|
if err != nil {
|
|
t.Fatalf("GetAccountByUUID: %v", err)
|
|
}
|
|
if got.Username != "alice" {
|
|
t.Errorf("fetched Username = %q, want %q", got.Username, "alice")
|
|
}
|
|
|
|
// Retrieve by username.
|
|
got2, err := db.GetAccountByUsername("alice")
|
|
if err != nil {
|
|
t.Fatalf("GetAccountByUsername: %v", err)
|
|
}
|
|
if got2.UUID != acct.UUID {
|
|
t.Errorf("UUID mismatch: got %q, want %q", got2.UUID, acct.UUID)
|
|
}
|
|
}
|
|
|
|
func TestGetAccountNotFound(t *testing.T) {
|
|
db := openTestDB(t)
|
|
|
|
_, err := db.GetAccountByUUID("nonexistent-uuid")
|
|
if !errors.Is(err, ErrNotFound) {
|
|
t.Errorf("expected ErrNotFound, got %v", err)
|
|
}
|
|
|
|
_, err = db.GetAccountByUsername("nobody")
|
|
if !errors.Is(err, ErrNotFound) {
|
|
t.Errorf("expected ErrNotFound, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestUpdateAccountStatus(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("bob", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
if err := db.UpdateAccountStatus(acct.ID, model.AccountStatusInactive); err != nil {
|
|
t.Fatalf("UpdateAccountStatus: %v", err)
|
|
}
|
|
|
|
got, err := db.GetAccountByUUID(acct.UUID)
|
|
if err != nil {
|
|
t.Fatalf("GetAccountByUUID: %v", err)
|
|
}
|
|
if got.Status != model.AccountStatusInactive {
|
|
t.Errorf("Status = %q, want inactive", got.Status)
|
|
}
|
|
}
|
|
|
|
func TestListAccounts(t *testing.T) {
|
|
db := openTestDB(t)
|
|
for _, name := range []string{"charlie", "delta", "eve"} {
|
|
if _, err := db.CreateAccount(name, model.AccountTypeHuman, "hash"); err != nil {
|
|
t.Fatalf("CreateAccount %q: %v", name, err)
|
|
}
|
|
}
|
|
|
|
accts, err := db.ListAccounts()
|
|
if err != nil {
|
|
t.Fatalf("ListAccounts: %v", err)
|
|
}
|
|
if len(accts) != 3 {
|
|
t.Errorf("ListAccounts returned %d accounts, want 3", len(accts))
|
|
}
|
|
}
|
|
|
|
func TestRoleOperations(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("frank", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
// GrantRole
|
|
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
|
|
t.Fatalf("GrantRole: %v", err)
|
|
}
|
|
// Grant again — should be no-op.
|
|
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
|
|
t.Fatalf("GrantRole duplicate: %v", err)
|
|
}
|
|
|
|
roles, err := db.GetRoles(acct.ID)
|
|
if err != nil {
|
|
t.Fatalf("GetRoles: %v", err)
|
|
}
|
|
if len(roles) != 1 || roles[0] != "admin" {
|
|
t.Errorf("GetRoles = %v, want [admin]", roles)
|
|
}
|
|
|
|
has, err := db.HasRole(acct.ID, "admin")
|
|
if err != nil {
|
|
t.Fatalf("HasRole: %v", err)
|
|
}
|
|
if !has {
|
|
t.Error("expected HasRole to return true for 'admin'")
|
|
}
|
|
|
|
// RevokeRole
|
|
if err := db.RevokeRole(acct.ID, "admin"); err != nil {
|
|
t.Fatalf("RevokeRole: %v", err)
|
|
}
|
|
roles, err = db.GetRoles(acct.ID)
|
|
if err != nil {
|
|
t.Fatalf("GetRoles after revoke: %v", err)
|
|
}
|
|
if len(roles) != 0 {
|
|
t.Errorf("expected no roles after revoke, got %v", roles)
|
|
}
|
|
|
|
// SetRoles
|
|
if err := db.SetRoles(acct.ID, []string{"reader", "writer"}, nil); err != nil {
|
|
t.Fatalf("SetRoles: %v", err)
|
|
}
|
|
roles, err = db.GetRoles(acct.ID)
|
|
if err != nil {
|
|
t.Fatalf("GetRoles after SetRoles: %v", err)
|
|
}
|
|
if len(roles) != 2 {
|
|
t.Errorf("expected 2 roles after SetRoles, got %d", len(roles))
|
|
}
|
|
}
|
|
|
|
func TestTokenTrackingAndRevocation(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("grace", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
jti := "test-jti-1234"
|
|
issuedAt := time.Now().UTC()
|
|
expiresAt := issuedAt.Add(time.Hour)
|
|
|
|
if err := db.TrackToken(jti, acct.ID, issuedAt, expiresAt); err != nil {
|
|
t.Fatalf("TrackToken: %v", err)
|
|
}
|
|
|
|
// Retrieve
|
|
rec, err := db.GetTokenRecord(jti)
|
|
if err != nil {
|
|
t.Fatalf("GetTokenRecord: %v", err)
|
|
}
|
|
if rec.JTI != jti {
|
|
t.Errorf("JTI = %q, want %q", rec.JTI, jti)
|
|
}
|
|
if rec.IsRevoked() {
|
|
t.Error("newly tracked token should not be revoked")
|
|
}
|
|
|
|
// Revoke
|
|
if err := db.RevokeToken(jti, "test revocation"); err != nil {
|
|
t.Fatalf("RevokeToken: %v", err)
|
|
}
|
|
rec, err = db.GetTokenRecord(jti)
|
|
if err != nil {
|
|
t.Fatalf("GetTokenRecord after revoke: %v", err)
|
|
}
|
|
if !rec.IsRevoked() {
|
|
t.Error("token should be revoked after RevokeToken")
|
|
}
|
|
|
|
// Revoking again should fail (already revoked).
|
|
if err := db.RevokeToken(jti, "again"); err == nil {
|
|
t.Error("expected error when revoking already-revoked token")
|
|
}
|
|
}
|
|
|
|
func TestGetTokenRecordNotFound(t *testing.T) {
|
|
db := openTestDB(t)
|
|
_, err := db.GetTokenRecord("no-such-jti")
|
|
if !errors.Is(err, ErrNotFound) {
|
|
t.Errorf("expected ErrNotFound, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestPruneExpiredTokens(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("henry", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
past := time.Now().UTC().Add(-time.Hour)
|
|
future := time.Now().UTC().Add(time.Hour)
|
|
|
|
if err := db.TrackToken("expired-jti", acct.ID, past.Add(-time.Hour), past); err != nil {
|
|
t.Fatalf("TrackToken expired: %v", err)
|
|
}
|
|
if err := db.TrackToken("valid-jti", acct.ID, time.Now(), future); err != nil {
|
|
t.Fatalf("TrackToken valid: %v", err)
|
|
}
|
|
|
|
n, err := db.PruneExpiredTokens()
|
|
if err != nil {
|
|
t.Fatalf("PruneExpiredTokens: %v", err)
|
|
}
|
|
if n != 1 {
|
|
t.Errorf("pruned %d rows, want 1", n)
|
|
}
|
|
|
|
// Valid token should still be retrievable.
|
|
if _, err := db.GetTokenRecord("valid-jti"); err != nil {
|
|
t.Errorf("valid token missing after prune: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestServerConfig(t *testing.T) {
|
|
db := openTestDB(t)
|
|
|
|
// No config initially.
|
|
_, _, err := db.ReadServerConfig()
|
|
if !errors.Is(err, ErrNotFound) {
|
|
t.Errorf("expected ErrNotFound for missing config, got %v", err)
|
|
}
|
|
|
|
enc := []byte("encrypted-key-data")
|
|
nonce := []byte("nonce12345678901")
|
|
|
|
if err := db.WriteServerConfig(enc, nonce); err != nil {
|
|
t.Fatalf("WriteServerConfig: %v", err)
|
|
}
|
|
|
|
gotEnc, gotNonce, err := db.ReadServerConfig()
|
|
if err != nil {
|
|
t.Fatalf("ReadServerConfig: %v", err)
|
|
}
|
|
if string(gotEnc) != string(enc) {
|
|
t.Errorf("enc mismatch: got %q, want %q", gotEnc, enc)
|
|
}
|
|
if string(gotNonce) != string(nonce) {
|
|
t.Errorf("nonce mismatch: got %q, want %q", gotNonce, nonce)
|
|
}
|
|
|
|
// Overwrite — should work without error.
|
|
if err := db.WriteServerConfig([]byte("new-key"), []byte("new-nonce123456")); err != nil {
|
|
t.Fatalf("WriteServerConfig overwrite: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestForeignKeyEnforcement(t *testing.T) {
|
|
db := openTestDB(t)
|
|
// Attempting to track a token for a non-existent account should fail.
|
|
err := db.TrackToken("jti-x", 999999, time.Now(), time.Now().Add(time.Hour))
|
|
if err == nil {
|
|
t.Error("expected foreign key error for non-existent account_id, got nil")
|
|
}
|
|
}
|
|
|
|
func TestPGCredentials(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("svc", model.AccountTypeSystem, "")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
enc := []byte("encrypted-pg-password")
|
|
nonce := []byte("pg-nonce12345678")
|
|
|
|
if err := db.WritePGCredentials(acct.ID, "localhost", 5432, "mydb", "myuser", enc, nonce); err != nil {
|
|
t.Fatalf("WritePGCredentials: %v", err)
|
|
}
|
|
|
|
cred, err := db.ReadPGCredentials(acct.ID)
|
|
if err != nil {
|
|
t.Fatalf("ReadPGCredentials: %v", err)
|
|
}
|
|
if cred.PGHost != "localhost" {
|
|
t.Errorf("PGHost = %q, want %q", cred.PGHost, "localhost")
|
|
}
|
|
if cred.PGDatabase != "mydb" {
|
|
t.Errorf("PGDatabase = %q, want %q", cred.PGDatabase, "mydb")
|
|
}
|
|
}
|
|
|
|
func TestRenewTokenAtomic(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("judy", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
exp := now.Add(time.Hour)
|
|
|
|
// Set up the old token.
|
|
oldJTI := "renew-old-jti"
|
|
newJTI := "renew-new-jti"
|
|
if err := db.TrackToken(oldJTI, acct.ID, now, exp); err != nil {
|
|
t.Fatalf("TrackToken old: %v", err)
|
|
}
|
|
|
|
// RenewToken should atomically revoke old and track new.
|
|
if err := db.RenewToken(oldJTI, "renewed", newJTI, acct.ID, now, exp); err != nil {
|
|
t.Fatalf("RenewToken: %v", err)
|
|
}
|
|
|
|
// Old token must be revoked.
|
|
oldRec, err := db.GetTokenRecord(oldJTI)
|
|
if err != nil {
|
|
t.Fatalf("GetTokenRecord old: %v", err)
|
|
}
|
|
if !oldRec.IsRevoked() {
|
|
t.Error("old token should be revoked after RenewToken")
|
|
}
|
|
if oldRec.RevokeReason != "renewed" {
|
|
t.Errorf("old token revoke reason = %q, want %q", oldRec.RevokeReason, "renewed")
|
|
}
|
|
|
|
// New token must be tracked and not revoked.
|
|
newRec, err := db.GetTokenRecord(newJTI)
|
|
if err != nil {
|
|
t.Fatalf("GetTokenRecord new: %v", err)
|
|
}
|
|
if newRec.IsRevoked() {
|
|
t.Error("new token should not be revoked")
|
|
}
|
|
|
|
// RenewToken on an already-revoked old token must fail (atomicity guard).
|
|
if err := db.RenewToken(oldJTI, "renewed", "other-jti", acct.ID, now, exp); err == nil {
|
|
t.Error("expected error when renewing an already-revoked token")
|
|
}
|
|
}
|
|
|
|
func TestRevokeAllUserTokens(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("ivan", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
future := time.Now().UTC().Add(time.Hour)
|
|
for _, jti := range []string{"tok1", "tok2", "tok3"} {
|
|
if err := db.TrackToken(jti, acct.ID, time.Now(), future); err != nil {
|
|
t.Fatalf("TrackToken %q: %v", jti, err)
|
|
}
|
|
}
|
|
|
|
if err := db.RevokeAllUserTokens(acct.ID, "account suspended"); err != nil {
|
|
t.Fatalf("RevokeAllUserTokens: %v", err)
|
|
}
|
|
|
|
for _, jti := range []string{"tok1", "tok2", "tok3"} {
|
|
rec, err := db.GetTokenRecord(jti)
|
|
if err != nil {
|
|
t.Fatalf("GetTokenRecord %q: %v", jti, err)
|
|
}
|
|
if !rec.IsRevoked() {
|
|
t.Errorf("token %q should be revoked", jti)
|
|
}
|
|
}
|
|
}
|