356 lines
9.0 KiB
Go
356 lines
9.0 KiB
Go
package db
|
|
|
|
import (
	"errors"
	"testing"
	"time"

	"git.wntrmute.dev/kyle/mcias/internal/model"
)
|
|
|
|
// openTestDB opens an in-memory SQLite database for testing.
|
|
func openTestDB(t *testing.T) *DB {
|
|
t.Helper()
|
|
db, err := Open(":memory:")
|
|
if err != nil {
|
|
t.Fatalf("Open: %v", err)
|
|
}
|
|
if err := Migrate(db); err != nil {
|
|
t.Fatalf("Migrate: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = db.Close() })
|
|
return db
|
|
}
|
|
|
|
func TestMigrateIdempotent(t *testing.T) {
|
|
db := openTestDB(t)
|
|
// Run again — should be a no-op.
|
|
if err := Migrate(db); err != nil {
|
|
t.Errorf("second Migrate call returned error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestCreateAndGetAccount(t *testing.T) {
|
|
db := openTestDB(t)
|
|
|
|
acct, err := db.CreateAccount("alice", model.AccountTypeHuman, "$argon2id$v=19$...")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
if acct.UUID == "" {
|
|
t.Error("expected non-empty UUID")
|
|
}
|
|
if acct.Username != "alice" {
|
|
t.Errorf("Username = %q, want %q", acct.Username, "alice")
|
|
}
|
|
if acct.Status != model.AccountStatusActive {
|
|
t.Errorf("Status = %q, want active", acct.Status)
|
|
}
|
|
|
|
// Retrieve by UUID.
|
|
got, err := db.GetAccountByUUID(acct.UUID)
|
|
if err != nil {
|
|
t.Fatalf("GetAccountByUUID: %v", err)
|
|
}
|
|
if got.Username != "alice" {
|
|
t.Errorf("fetched Username = %q, want %q", got.Username, "alice")
|
|
}
|
|
|
|
// Retrieve by username.
|
|
got2, err := db.GetAccountByUsername("alice")
|
|
if err != nil {
|
|
t.Fatalf("GetAccountByUsername: %v", err)
|
|
}
|
|
if got2.UUID != acct.UUID {
|
|
t.Errorf("UUID mismatch: got %q, want %q", got2.UUID, acct.UUID)
|
|
}
|
|
}
|
|
|
|
func TestGetAccountNotFound(t *testing.T) {
|
|
db := openTestDB(t)
|
|
|
|
_, err := db.GetAccountByUUID("nonexistent-uuid")
|
|
if err != ErrNotFound {
|
|
t.Errorf("expected ErrNotFound, got %v", err)
|
|
}
|
|
|
|
_, err = db.GetAccountByUsername("nobody")
|
|
if err != ErrNotFound {
|
|
t.Errorf("expected ErrNotFound, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestUpdateAccountStatus(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("bob", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
if err := db.UpdateAccountStatus(acct.ID, model.AccountStatusInactive); err != nil {
|
|
t.Fatalf("UpdateAccountStatus: %v", err)
|
|
}
|
|
|
|
got, err := db.GetAccountByUUID(acct.UUID)
|
|
if err != nil {
|
|
t.Fatalf("GetAccountByUUID: %v", err)
|
|
}
|
|
if got.Status != model.AccountStatusInactive {
|
|
t.Errorf("Status = %q, want inactive", got.Status)
|
|
}
|
|
}
|
|
|
|
func TestListAccounts(t *testing.T) {
|
|
db := openTestDB(t)
|
|
for _, name := range []string{"charlie", "delta", "eve"} {
|
|
if _, err := db.CreateAccount(name, model.AccountTypeHuman, "hash"); err != nil {
|
|
t.Fatalf("CreateAccount %q: %v", name, err)
|
|
}
|
|
}
|
|
|
|
accts, err := db.ListAccounts()
|
|
if err != nil {
|
|
t.Fatalf("ListAccounts: %v", err)
|
|
}
|
|
if len(accts) != 3 {
|
|
t.Errorf("ListAccounts returned %d accounts, want 3", len(accts))
|
|
}
|
|
}
|
|
|
|
func TestRoleOperations(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("frank", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
// GrantRole
|
|
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
|
|
t.Fatalf("GrantRole: %v", err)
|
|
}
|
|
// Grant again — should be no-op.
|
|
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
|
|
t.Fatalf("GrantRole duplicate: %v", err)
|
|
}
|
|
|
|
roles, err := db.GetRoles(acct.ID)
|
|
if err != nil {
|
|
t.Fatalf("GetRoles: %v", err)
|
|
}
|
|
if len(roles) != 1 || roles[0] != "admin" {
|
|
t.Errorf("GetRoles = %v, want [admin]", roles)
|
|
}
|
|
|
|
has, err := db.HasRole(acct.ID, "admin")
|
|
if err != nil {
|
|
t.Fatalf("HasRole: %v", err)
|
|
}
|
|
if !has {
|
|
t.Error("expected HasRole to return true for 'admin'")
|
|
}
|
|
|
|
// RevokeRole
|
|
if err := db.RevokeRole(acct.ID, "admin"); err != nil {
|
|
t.Fatalf("RevokeRole: %v", err)
|
|
}
|
|
roles, err = db.GetRoles(acct.ID)
|
|
if err != nil {
|
|
t.Fatalf("GetRoles after revoke: %v", err)
|
|
}
|
|
if len(roles) != 0 {
|
|
t.Errorf("expected no roles after revoke, got %v", roles)
|
|
}
|
|
|
|
// SetRoles
|
|
if err := db.SetRoles(acct.ID, []string{"reader", "writer"}, nil); err != nil {
|
|
t.Fatalf("SetRoles: %v", err)
|
|
}
|
|
roles, err = db.GetRoles(acct.ID)
|
|
if err != nil {
|
|
t.Fatalf("GetRoles after SetRoles: %v", err)
|
|
}
|
|
if len(roles) != 2 {
|
|
t.Errorf("expected 2 roles after SetRoles, got %d", len(roles))
|
|
}
|
|
}
|
|
|
|
func TestTokenTrackingAndRevocation(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("grace", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
jti := "test-jti-1234"
|
|
issuedAt := time.Now().UTC()
|
|
expiresAt := issuedAt.Add(time.Hour)
|
|
|
|
if err := db.TrackToken(jti, acct.ID, issuedAt, expiresAt); err != nil {
|
|
t.Fatalf("TrackToken: %v", err)
|
|
}
|
|
|
|
// Retrieve
|
|
rec, err := db.GetTokenRecord(jti)
|
|
if err != nil {
|
|
t.Fatalf("GetTokenRecord: %v", err)
|
|
}
|
|
if rec.JTI != jti {
|
|
t.Errorf("JTI = %q, want %q", rec.JTI, jti)
|
|
}
|
|
if rec.IsRevoked() {
|
|
t.Error("newly tracked token should not be revoked")
|
|
}
|
|
|
|
// Revoke
|
|
if err := db.RevokeToken(jti, "test revocation"); err != nil {
|
|
t.Fatalf("RevokeToken: %v", err)
|
|
}
|
|
rec, err = db.GetTokenRecord(jti)
|
|
if err != nil {
|
|
t.Fatalf("GetTokenRecord after revoke: %v", err)
|
|
}
|
|
if !rec.IsRevoked() {
|
|
t.Error("token should be revoked after RevokeToken")
|
|
}
|
|
|
|
// Revoking again should fail (already revoked).
|
|
if err := db.RevokeToken(jti, "again"); err == nil {
|
|
t.Error("expected error when revoking already-revoked token")
|
|
}
|
|
}
|
|
|
|
func TestGetTokenRecordNotFound(t *testing.T) {
|
|
db := openTestDB(t)
|
|
_, err := db.GetTokenRecord("no-such-jti")
|
|
if err != ErrNotFound {
|
|
t.Errorf("expected ErrNotFound, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestPruneExpiredTokens(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("henry", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
past := time.Now().UTC().Add(-time.Hour)
|
|
future := time.Now().UTC().Add(time.Hour)
|
|
|
|
if err := db.TrackToken("expired-jti", acct.ID, past.Add(-time.Hour), past); err != nil {
|
|
t.Fatalf("TrackToken expired: %v", err)
|
|
}
|
|
if err := db.TrackToken("valid-jti", acct.ID, time.Now(), future); err != nil {
|
|
t.Fatalf("TrackToken valid: %v", err)
|
|
}
|
|
|
|
n, err := db.PruneExpiredTokens()
|
|
if err != nil {
|
|
t.Fatalf("PruneExpiredTokens: %v", err)
|
|
}
|
|
if n != 1 {
|
|
t.Errorf("pruned %d rows, want 1", n)
|
|
}
|
|
|
|
// Valid token should still be retrievable.
|
|
if _, err := db.GetTokenRecord("valid-jti"); err != nil {
|
|
t.Errorf("valid token missing after prune: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestServerConfig(t *testing.T) {
|
|
db := openTestDB(t)
|
|
|
|
// No config initially.
|
|
_, _, err := db.ReadServerConfig()
|
|
if err != ErrNotFound {
|
|
t.Errorf("expected ErrNotFound for missing config, got %v", err)
|
|
}
|
|
|
|
enc := []byte("encrypted-key-data")
|
|
nonce := []byte("nonce12345678901")
|
|
|
|
if err := db.WriteServerConfig(enc, nonce); err != nil {
|
|
t.Fatalf("WriteServerConfig: %v", err)
|
|
}
|
|
|
|
gotEnc, gotNonce, err := db.ReadServerConfig()
|
|
if err != nil {
|
|
t.Fatalf("ReadServerConfig: %v", err)
|
|
}
|
|
if string(gotEnc) != string(enc) {
|
|
t.Errorf("enc mismatch: got %q, want %q", gotEnc, enc)
|
|
}
|
|
if string(gotNonce) != string(nonce) {
|
|
t.Errorf("nonce mismatch: got %q, want %q", gotNonce, nonce)
|
|
}
|
|
|
|
// Overwrite — should work without error.
|
|
if err := db.WriteServerConfig([]byte("new-key"), []byte("new-nonce123456")); err != nil {
|
|
t.Fatalf("WriteServerConfig overwrite: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestForeignKeyEnforcement(t *testing.T) {
|
|
db := openTestDB(t)
|
|
// Attempting to track a token for a non-existent account should fail.
|
|
err := db.TrackToken("jti-x", 999999, time.Now(), time.Now().Add(time.Hour))
|
|
if err == nil {
|
|
t.Error("expected foreign key error for non-existent account_id, got nil")
|
|
}
|
|
}
|
|
|
|
func TestPGCredentials(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("svc", model.AccountTypeSystem, "")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
enc := []byte("encrypted-pg-password")
|
|
nonce := []byte("pg-nonce12345678")
|
|
|
|
if err := db.WritePGCredentials(acct.ID, "localhost", 5432, "mydb", "myuser", enc, nonce); err != nil {
|
|
t.Fatalf("WritePGCredentials: %v", err)
|
|
}
|
|
|
|
cred, err := db.ReadPGCredentials(acct.ID)
|
|
if err != nil {
|
|
t.Fatalf("ReadPGCredentials: %v", err)
|
|
}
|
|
if cred.PGHost != "localhost" {
|
|
t.Errorf("PGHost = %q, want %q", cred.PGHost, "localhost")
|
|
}
|
|
if cred.PGDatabase != "mydb" {
|
|
t.Errorf("PGDatabase = %q, want %q", cred.PGDatabase, "mydb")
|
|
}
|
|
}
|
|
|
|
func TestRevokeAllUserTokens(t *testing.T) {
|
|
db := openTestDB(t)
|
|
acct, err := db.CreateAccount("ivan", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
future := time.Now().UTC().Add(time.Hour)
|
|
for _, jti := range []string{"tok1", "tok2", "tok3"} {
|
|
if err := db.TrackToken(jti, acct.ID, time.Now(), future); err != nil {
|
|
t.Fatalf("TrackToken %q: %v", jti, err)
|
|
}
|
|
}
|
|
|
|
if err := db.RevokeAllUserTokens(acct.ID, "account suspended"); err != nil {
|
|
t.Fatalf("RevokeAllUserTokens: %v", err)
|
|
}
|
|
|
|
for _, jti := range []string{"tok1", "tok2", "tok3"} {
|
|
rec, err := db.GetTokenRecord(jti)
|
|
if err != nil {
|
|
t.Fatalf("GetTokenRecord %q: %v", jti, err)
|
|
}
|
|
if !rec.IsRevoked() {
|
|
t.Errorf("token %q should be revoked", jti)
|
|
}
|
|
}
|
|
}
|