package db

import (
	"errors"
	"path/filepath"
	"testing"
)

// openTestDB opens a fresh database file inside a per-test temporary
// directory, applies migrations, and registers a cleanup that closes it.
// Any setup failure aborts the calling test.
func openTestDB(t *testing.T) *DB {
	t.Helper()
	dir := t.TempDir()
	database, err := Open(filepath.Join(dir, "test.db"))
	if err != nil {
		t.Fatalf("open db: %v", err)
	}
	// Register the close before anything else can call t.Fatalf, so a failed
	// migration does not leak the open handle. Cleanups run LIFO, so this runs
	// before t.TempDir removes the directory.
	t.Cleanup(func() {
		if err := database.Close(); err != nil {
			t.Errorf("close db: %v", err)
		}
	})
	if err := database.Migrate(); err != nil {
		t.Fatalf("migrate: %v", err)
	}
	return database
}

// TestCreateZone checks that a newly created zone echoes back its name and
// primary NS and is assigned a non-zero serial.
func TestCreateZone(t *testing.T) {
	db := openTestDB(t)
	zone, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
	if err != nil {
		t.Fatalf("create zone: %v", err)
	}
	if zone.Name != "example.com" {
		t.Fatalf("got name %q, want %q", zone.Name, "example.com")
	}
	if zone.Serial == 0 {
		t.Fatal("serial should not be zero")
	}
	if zone.PrimaryNS != "ns1.example.com." {
		t.Fatalf("got primary_ns %q, want %q", zone.PrimaryNS, "ns1.example.com.")
	}
}

// TestCreateZoneDuplicate checks that creating the same zone twice fails.
func TestCreateZoneDuplicate(t *testing.T) {
	db := openTestDB(t)
	_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
	if err != nil {
		t.Fatalf("create zone: %v", err)
	}
	_, err = db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
	if err == nil {
		t.Fatal("expected error for duplicate zone")
	}
}

// TestCreateZoneNormalization checks that zone names are stored lower-cased
// and without a trailing dot.
func TestCreateZoneNormalization(t *testing.T) {
	db := openTestDB(t)
	zone, err := db.CreateZone("Example.COM.", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
	if err != nil {
		t.Fatalf("create zone: %v", err)
	}
	if zone.Name != "example.com" {
		t.Fatalf("got name %q, want %q", zone.Name, "example.com")
	}
}

// TestListZones checks that ListZones returns every zone, ordered by name
// regardless of insertion order.
func TestListZones(t *testing.T) {
	db := openTestDB(t)
	_, err := db.CreateZone("b.example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
	if err != nil {
		t.Fatalf("create zone b: %v", err)
	}
	_, err = db.CreateZone("a.example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
	if err != nil {
		t.Fatalf("create zone a: %v", err)
	}
	zones, err := db.ListZones()
	if err != nil {
		t.Fatalf("list zones: %v", err)
	}
	if len(zones) != 2 {
		t.Fatalf("got %d zones, want 2", len(zones))
	}
	// Check the whole ordering, not just the first element.
	want := []string{"a.example.com", "b.example.com"}
	for i, name := range want {
		if zones[i].Name != name {
			t.Fatalf("zones should be ordered by name: zones[%d] is %q, want %q", i, zones[i].Name, name)
		}
	}
}

// TestGetZone checks lookup of an existing zone and that a missing zone
// yields ErrNotFound.
func TestGetZone(t *testing.T) {
	db := openTestDB(t)
	_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
	if err != nil {
		t.Fatalf("create zone: %v", err)
	}
	zone, err := db.GetZone("example.com")
	if err != nil {
		t.Fatalf("get zone: %v", err)
	}
	if zone.Name != "example.com" {
		t.Fatalf("got name %q, want %q", zone.Name, "example.com")
	}
	_, err = db.GetZone("nonexistent.com")
	if !errors.Is(err, ErrNotFound) {
		t.Fatalf("expected ErrNotFound, got %v", err)
	}
}

// TestUpdateZone checks that UpdateZone changes the SOA fields, bumps the
// serial, and that the new values are what GetZone subsequently reads back.
func TestUpdateZone(t *testing.T) {
	db := openTestDB(t)
	original, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
	if err != nil {
		t.Fatalf("create zone: %v", err)
	}
	updated, err := db.UpdateZone("example.com", "ns2.example.com.", "newadmin.example.com.", 7200, 1200, 172800, 600)
	if err != nil {
		t.Fatalf("update zone: %v", err)
	}
	if updated.PrimaryNS != "ns2.example.com." {
		t.Fatalf("got primary_ns %q, want %q", updated.PrimaryNS, "ns2.example.com.")
	}
	if updated.Serial <= original.Serial {
		t.Fatalf("serial should have incremented: %d <= %d", updated.Serial, original.Serial)
	}

	// The returned value alone does not prove the change was persisted.
	stored, err := db.GetZone("example.com")
	if err != nil {
		t.Fatalf("get zone after update: %v", err)
	}
	if stored.PrimaryNS != "ns2.example.com." {
		t.Fatalf("stored primary_ns %q, want %q", stored.PrimaryNS, "ns2.example.com.")
	}
	if stored.Serial != updated.Serial {
		t.Fatalf("stored serial %d, want %d", stored.Serial, updated.Serial)
	}
}

// TestDeleteZone checks that a deleted zone can no longer be fetched and that
// deleting a missing zone yields ErrNotFound.
func TestDeleteZone(t *testing.T) {
	db := openTestDB(t)
	_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
	if err != nil {
		t.Fatalf("create zone: %v", err)
	}
	if err := db.DeleteZone("example.com"); err != nil {
		t.Fatalf("delete zone: %v", err)
	}
	_, err = db.GetZone("example.com")
	if !errors.Is(err, ErrNotFound) {
		t.Fatalf("expected ErrNotFound after delete, got %v", err)
	}
	if err := db.DeleteZone("nonexistent.com"); !errors.Is(err, ErrNotFound) {
		t.Fatalf("expected ErrNotFound for nonexistent zone, got %v", err)
	}
}

// TestNextSerial checks that a zero serial is replaced by a date-based one and
// that a subsequent call increments it by one.
func TestNextSerial(t *testing.T) {
	// A zero serial should produce a date-based serial.
	// NOTE(review): 2026032600 is a fixed lower bound in YYYYMMDDnn form taken
	// from when this test was written; it catches a non-date serial but not a
	// stale date.
	s1 := nextSerial(0)
	if s1 < 2026032600 {
		t.Fatalf("serial %d seems too low", s1)
	}
	// Incrementing should increase. This assumes both calls see the same
	// calendar day; a run straddling midnight could legitimately differ.
	s2 := nextSerial(s1)
	if s2 != s1+1 {
		t.Fatalf("expected %d, got %d", s1+1, s2)
	}
}