diff --git a/internal/share/share.go b/internal/share/share.go index 7cd3e7e..f93ca20 100644 --- a/internal/share/share.go +++ b/internal/share/share.go @@ -29,7 +29,7 @@ func CreateLink(database *sql.DB, notebookID int64, expiry time.Duration, baseUR var expiresAt *int64 var expiresTime *time.Time - if expiry > 0 { + if expiry != 0 { ea := time.Now().Add(expiry).UnixMilli() expiresAt = &ea t := time.UnixMilli(ea) diff --git a/internal/share/share_test.go b/internal/share/share_test.go new file mode 100644 index 0000000..99f335a --- /dev/null +++ b/internal/share/share_test.go @@ -0,0 +1,105 @@ +package share + +import ( + "path/filepath" + "testing" + "time" + + "git.wntrmute.dev/kyle/eng-pad-server/internal/auth" + "git.wntrmute.dev/kyle/eng-pad-server/internal/db" +) + +func setupTestDB(t *testing.T) (*db.TestDB, int64) { + t.Helper() + dir := t.TempDir() + database, err := db.Open(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatalf("open: %v", err) + } + if err := db.Migrate(database); err != nil { + t.Fatalf("migrate: %v", err) + } + t.Cleanup(func() { _ = database.Close() }) + + userID, err := auth.CreateUser(database, "test", "pass", auth.DefaultParams) + if err != nil { + t.Fatalf("create user: %v", err) + } + + _, err = database.Exec( + "INSERT INTO notebooks (user_id, remote_id, title, page_size, synced_at) VALUES (?, 1, 'Test', 'REGULAR', ?)", + userID, time.Now().UnixMilli(), + ) + if err != nil { + t.Fatalf("insert notebook: %v", err) + } + + return &db.TestDB{DB: database}, 1 // notebook ID +} + +func TestCreateAndValidateLink(t *testing.T) { + tdb, notebookID := setupTestDB(t) + + token, _, err := CreateLink(tdb.DB, notebookID, 0, "https://example.com") + if err != nil { + t.Fatalf("create: %v", err) + } + if token == "" { + t.Fatal("expected non-empty token") + } + + nbID, err := ValidateLink(tdb.DB, token) + if err != nil { + t.Fatalf("validate: %v", err) + } + if nbID != notebookID { + t.Fatalf("expected notebook %d, got %d", notebookID, nbID) + } +} + +func TestExpiredLink(t *testing.T) { + tdb, notebookID := setupTestDB(t) + + token, _, err := CreateLink(tdb.DB, notebookID, -time.Hour, "https://example.com") + if err != nil { + t.Fatalf("create: %v", err) + } + + _, err = ValidateLink(tdb.DB, token) + if err == nil { + t.Fatal("expected error for expired link") + } +} + +func TestRevokeLink(t *testing.T) { + tdb, notebookID := setupTestDB(t) + + token, _, err := CreateLink(tdb.DB, notebookID, 0, "https://example.com") + if err != nil { + t.Fatalf("create: %v", err) + } + + if err := RevokeLink(tdb.DB, token); err != nil { + t.Fatalf("revoke: %v", err) + } + + _, err = ValidateLink(tdb.DB, token) + if err == nil { + t.Fatal("expected error for revoked link") + } +} + +func TestListLinks(t *testing.T) { + tdb, notebookID := setupTestDB(t) + + _, _, _ = CreateLink(tdb.DB, notebookID, 0, "https://example.com") + _, _, _ = CreateLink(tdb.DB, notebookID, time.Hour, "https://example.com") + + links, err := ListLinks(tdb.DB, notebookID, "https://example.com") + if err != nil { + t.Fatalf("list: %v", err) + } + if len(links) != 2 { + t.Fatalf("expected 2 links, got %d", len(links)) + } +}