package db

import (
	"context"
	"database/sql"
	"errors"
	"os"
	"path/filepath"
	"testing"
	"time"
)

// tempDB returns the path of a database file inside a fresh per-test
// temporary directory. The testing package removes the directory once the
// test and its subtests complete. The file itself is not created here.
func tempDB(t *testing.T) string {
	t.Helper()
	return filepath.Join(t.TempDir(), "test.db")
}

// mustOpen opens a database at a fresh temporary path and registers a
// cleanup that closes it. An Open failure aborts the test; a Close failure
// is reported as a test error.
func mustOpen(t *testing.T) *sql.DB {
	t.Helper()
	database, err := Open(tempDB(t))
	if err != nil {
		t.Fatalf("Open failed: %v", err)
	}
	// Cleanups run last-in-first-out, so this Close runs before t.TempDir
	// removes the directory holding the database file.
	t.Cleanup(func() {
		if err := database.Close(); err != nil {
			t.Errorf("Close failed: %v", err)
		}
	})
	return database
}

// mustOpenAndMigrate opens a temporary database via mustOpen and applies
// Migrate to it, aborting the test if migration fails.
func mustOpenAndMigrate(t *testing.T) *sql.DB {
	t.Helper()
	database := mustOpen(t)
	if err := Migrate(database); err != nil {
		t.Fatalf("Migrate failed: %v", err)
	}
	return database
}

// TestOpenAndPing checks that a freshly opened database answers Ping.
func TestOpenAndPing(t *testing.T) {
	database := mustOpen(t)
	if err := database.Ping(); err != nil {
		t.Fatalf("Ping failed: %v", err)
	}
}

// TestOpenCreatesFile checks that Open leaves a database file on disk at the
// requested path.
//
// NOTE(review): database/sql connects lazily, so this assumes Open forces a
// connection (e.g. by pinging) before returning — confirm against Open.
func TestOpenCreatesFile(t *testing.T) {
	path := tempDB(t)
	database, err := Open(path)
	if err != nil {
		t.Fatalf("Open failed: %v", err)
	}
	t.Cleanup(func() {
		if err := database.Close(); err != nil {
			t.Errorf("Close failed: %v", err)
		}
	})
	// Any Stat error, not only "does not exist", means the file's presence
	// cannot be confirmed, so fail and surface the underlying cause.
	if _, err := os.Stat(path); err != nil {
		t.Fatalf("database file was not created: %v", err)
	}
}

// TestMigrate checks that every expected table exists in sqlite_master after
// Migrate runs. Each table is a separate subtest so failures are reported
// (and can be re-run with -run) individually.
func TestMigrate(t *testing.T) {
	database := mustOpenAndMigrate(t)

	tables := []string{
		"metadata",
		"tags",
		"categories",
		"publishers",
		"citations",
		"authors",
		"artifacts",
		"artifact_tags",
		"artifact_categories",
		"artifacts_history",
		"artifact_snapshots",
		"blobs",
		"schema_version",
	}

	for _, table := range tables {
		t.Run(table, func(t *testing.T) {
			var name string
			err := database.QueryRow(
				`SELECT name FROM sqlite_master WHERE type='table' AND name=?`, table,
			).Scan(&name)
			if err != nil {
				t.Errorf("table %q not found after migration: %v", table, err)
			}
		})
	}
}

// TestMigrateIdempotent checks that running Migrate a second time on an
// already-migrated database does not fail.
func TestMigrateIdempotent(t *testing.T) {
	database := mustOpenAndMigrate(t)
	if err := Migrate(database); err != nil {
		t.Fatalf("second Migrate failed: %v", err)
	}
}

// TestStartTXAndEndTX checks that EndTX with a nil error commits: a row
// inserted inside the transaction is visible through the database afterwards.
func TestStartTXAndEndTX(t *testing.T) {
	database := mustOpenAndMigrate(t)
	ctx := context.Background()

	tx, err := StartTX(ctx, database)
	if err != nil {
		t.Fatalf("StartTX failed: %v", err)
	}

	_, err = tx.ExecContext(ctx, `INSERT INTO tags (id, tag) VALUES ('test-id', 'test-tag')`)
	if err != nil {
		t.Fatalf("INSERT failed: %v", err)
	}

	if err := EndTX(tx, nil); err != nil {
		t.Fatalf("EndTX (commit) failed: %v", err)
	}

	var tag string
	row := database.QueryRow(`SELECT tag FROM tags WHERE id='test-id'`)
	if err := row.Scan(&tag); err != nil {
		t.Fatalf("committed row not found: %v", err)
	}
	if tag != "test-tag" {
		t.Fatalf("expected 'test-tag', got %q", tag)
	}
}

// TestEndTXRollback checks that EndTX with a non-nil error discards the
// transaction's writes and hands back the error it was given.
func TestEndTXRollback(t *testing.T) {
	database := mustOpenAndMigrate(t)
	ctx := context.Background()

	tx, err := StartTX(ctx, database)
	if err != nil {
		t.Fatalf("StartTX failed: %v", err)
	}

	_, err = tx.ExecContext(ctx, `INSERT INTO tags (id, tag) VALUES ('rollback-id', 'rollback-tag')`)
	if err != nil {
		t.Fatalf("INSERT failed: %v", err)
	}

	// A dedicated sentinel keeps this test independent of any special
	// handling EndTX might have for context errors.
	simErr := errors.New("simulated failure")
	if err := EndTX(tx, simErr); !errors.Is(err, simErr) {
		t.Fatalf("EndTX should return the original error, got: %v", err)
	}

	// Only sql.ErrNoRows proves the row is absent; any other error means the
	// lookup itself failed and must not be mistaken for a successful rollback.
	var tag string
	err = database.QueryRow(`SELECT tag FROM tags WHERE id='rollback-id'`).Scan(&tag)
	switch {
	case err == nil:
		t.Fatalf("rolled-back row should not be found, got tag %q", tag)
	case !errors.Is(err, sql.ErrNoRows):
		t.Fatalf("querying for rolled-back row failed: %v", err)
	}
}

// TestToDBTimeAndFromDBTime checks the exact string ToDBTime produces for a
// UTC instant and that FromDBTime (with a nil location) parses it back to
// the same instant.
func TestToDBTimeAndFromDBTime(t *testing.T) {
	original := time.Date(2024, 6, 15, 14, 30, 0, 0, time.UTC)
	s := ToDBTime(original)
	if s != "2024-06-15 14:30:00" {
		t.Fatalf("unexpected time string: %q", s)
	}

	parsed, err := FromDBTime(s, nil)
	if err != nil {
		t.Fatalf("FromDBTime failed: %v", err)
	}
	if !parsed.Equal(original) {
		t.Fatalf("round-trip failed: got %v, want %v", parsed, original)
	}
}

// TestFromDBTimeWithLocation checks that FromDBTime attaches the supplied
// location to its result. It is skipped when the host lacks tz data.
func TestFromDBTimeWithLocation(t *testing.T) {
	s := "2024-06-15 14:30:00"
	loc, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Skipf("timezone not available: %v", err)
	}

	parsed, err := FromDBTime(s, loc)
	if err != nil {
		t.Fatalf("FromDBTime failed: %v", err)
	}
	if parsed.Location() != loc {
		t.Fatalf("expected location %v, got %v", loc, parsed.Location())
	}
}

// TestFromDBTimeInvalid checks that FromDBTime rejects a malformed string.
func TestFromDBTimeInvalid(t *testing.T) {
	_, err := FromDBTime("not-a-date", nil)
	if err == nil {
		t.Fatal("expected error for invalid time string")
	}
}

// TestForeignKeysEnabled checks that PRAGMA foreign_keys reports 1 on a
// freshly opened database.
//
// NOTE(review): in SQLite this pragma is per-connection, so the query only
// inspects whichever pooled connection serves it. That is sufficient only if
// Open enables it for every connection (e.g. via the DSN) — confirm.
func TestForeignKeysEnabled(t *testing.T) {
	database := mustOpen(t)
	var fk int
	row := database.QueryRow(`PRAGMA foreign_keys`)
	if err := row.Scan(&fk); err != nil {
		t.Fatalf("PRAGMA foreign_keys failed: %v", err)
	}
	if fk != 1 {
		t.Fatalf("foreign keys should be enabled, got %d", fk)
	}
}

// TestSchemaVersion checks that a migrated database records a schema
// version of at least 1.
func TestSchemaVersion(t *testing.T) {
	database := mustOpenAndMigrate(t)
	version, err := getCurrentVersion(database)
	if err != nil {
		t.Fatalf("getCurrentVersion failed: %v", err)
	}
	if version < 1 {
		t.Fatalf("expected schema version >= 1, got %d", version)
	}
}