package db import ( "path/filepath" "testing" ) func TestOpenAndMigrate(t *testing.T) { dir := t.TempDir() database, err := Open(filepath.Join(dir, "test.db")) if err != nil { t.Fatalf("open: %v", err) } defer func() { _ = database.Close() }() if err := Migrate(database); err != nil { t.Fatalf("migrate: %v", err) } // Verify tables exist tables := []string{"users", "notebooks", "pages", "strokes", "share_links", "webauthn_credentials", "schema_migrations"} for _, table := range tables { var name string err := database.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&name) if err != nil { t.Errorf("table %s not found: %v", table, err) } } } func TestMigrateIdempotent(t *testing.T) { dir := t.TempDir() database, err := Open(filepath.Join(dir, "test.db")) if err != nil { t.Fatalf("open: %v", err) } defer func() { _ = database.Close() }() if err := Migrate(database); err != nil { t.Fatalf("first migrate: %v", err) } if err := Migrate(database); err != nil { t.Fatalf("second migrate: %v", err) } } func TestForeignKeys(t *testing.T) { dir := t.TempDir() database, err := Open(filepath.Join(dir, "test.db")) if err != nil { t.Fatalf("open: %v", err) } defer func() { _ = database.Close() }() if err := Migrate(database); err != nil { t.Fatalf("migrate: %v", err) } // Inserting a notebook with non-existent user_id should fail _, err = database.Exec("INSERT INTO notebooks (user_id, remote_id, title, page_size, synced_at) VALUES (999, 1, 'test', 'REGULAR', 0)") if err == nil { t.Fatal("expected foreign key error, got nil") } } func TestCascadeDelete(t *testing.T) { dir := t.TempDir() database, err := Open(filepath.Join(dir, "test.db")) if err != nil { t.Fatalf("open: %v", err) } defer func() { _ = database.Close() }() if err := Migrate(database); err != nil { t.Fatalf("migrate: %v", err) } // Create user, notebook, page, stroke res, err := database.Exec("INSERT INTO users (username, password_hash, created_at, updated_at) VALUES ('test', 'hash', 0, 0)") if err != nil { t.Fatalf("insert user: %v", err) } userID, _ := res.LastInsertId() res, err = database.Exec("INSERT INTO notebooks (user_id, remote_id, title, page_size, synced_at) VALUES (?, 1, 'nb', 'REGULAR', 0)", userID) if err != nil { t.Fatalf("insert notebook: %v", err) } nbID, _ := res.LastInsertId() res, err = database.Exec("INSERT INTO pages (notebook_id, remote_id, page_number) VALUES (?, 1, 1)", nbID) if err != nil { t.Fatalf("insert page: %v", err) } pageID, _ := res.LastInsertId() _, err = database.Exec("INSERT INTO strokes (page_id, pen_size, color, point_data, stroke_order) VALUES (?, 1.0, 0, X'00', 1)", pageID) if err != nil { t.Fatalf("insert stroke: %v", err) } // Delete the user — everything should cascade if _, err := database.Exec("DELETE FROM users WHERE id = ?", userID); err != nil { t.Fatalf("delete user: %v", err) } var count int _ = database.QueryRow("SELECT COUNT(*) FROM notebooks").Scan(&count) if count != 0 { t.Errorf("expected 0 notebooks, got %d", count) } _ = database.QueryRow("SELECT COUNT(*) FROM pages").Scan(&count) if count != 0 { t.Errorf("expected 0 pages, got %d", count) } _ = database.QueryRow("SELECT COUNT(*) FROM strokes").Scan(&count) if count != 0 { t.Errorf("expected 0 strokes, got %d", count) } }