diff --git a/AUDIT.md b/AUDIT.md new file mode 100644 index 0000000..3077c04 --- /dev/null +++ b/AUDIT.md @@ -0,0 +1,179 @@ +# AUDIT.md — eng-pad-server Security Audit + +## Audit Date: 2026-03-24 + +### Summary + +Comprehensive security and engineering review of the eng-pad-server +codebase. 17 issues identified across critical, high, and medium +severity levels. All critical and high issues resolved. + +--- + +## Findings + +### A-001: SQL Injection in Database Backup +- **Severity**: Critical +- **Status**: ~~Resolved~~ +- **Location**: `cmd/eng-pad-server/snapshot.go` +- **Description**: Backup path interpolated into SQL via `fmt.Sprintf` + without escaping. Allows arbitrary SQL execution if path contains + single quotes. +- **Resolution**: Escape single quotes in the backup path before + interpolation. + +### A-002: Authentication Timing Attack +- **Severity**: High +- **Status**: ~~Resolved~~ +- **Location**: `internal/auth/users.go` +- **Description**: Early return when user not found skips Argon2id + computation, allowing username enumeration via response timing. +- **Resolution**: Always perform Argon2id verification against a dummy + hash when user is not found, consuming constant time. + +### A-003: Missing Notebook Ownership Checks +- **Severity**: High +- **Status**: ~~Resolved~~ +- **Location**: `internal/server/notebooks_handler.go` +- **Description**: Authenticated endpoints load notebooks by ID without + verifying the notebook belongs to the requesting user. An + authenticated user could access any notebook by guessing IDs. +- **Resolution**: Added `user_id` check to all notebook queries in + authenticated endpoints. Share link endpoints use the validated + notebook ID from the share link. + +### A-004: Unbounded Point Data Decoding +- **Severity**: High +- **Status**: ~~Resolved~~ +- **Location**: `internal/render/svg.go` +- **Description**: `decodePoints()` panics on data not aligned to 4 + bytes, has no size limit (OOM risk), and doesn't filter NaN/Inf. +- **Resolution**: Added alignment check, 4MB size limit, NaN/Inf + filtering. Function now returns error. + +### A-005: Error Messages Leak Information +- **Severity**: High +- **Status**: ~~Resolved~~ +- **Location**: Multiple REST handlers +- **Description**: `err.Error()` concatenated directly into HTTP JSON + responses, leaking database structure and internal state. +- **Resolution**: All error responses use generic messages. Detailed + errors logged server-side only. + +### A-006: No Authorization on Share Link Revocation +- **Severity**: High +- **Status**: ~~Resolved~~ +- **Location**: `internal/grpcserver/share.go` +- **Description**: Any authenticated user could revoke any share link + by token, regardless of notebook ownership. +- **Resolution**: Added JOIN with notebooks table to verify the calling + user owns the notebook before allowing revocation. + +### A-007: Silent Row Scan Errors +- **Severity**: High +- **Status**: ~~Resolved~~ +- **Location**: Multiple handlers +- **Description**: `continue` on row scan errors silently returns + incomplete data without indication. +- **Resolution**: Scan errors now return 500 Internal Server Error in + REST handlers. + +### A-008: Web Server Missing TLS +- **Severity**: Medium +- **Status**: ~~Resolved~~ +- **Location**: `internal/webserver/server.go` +- **Description**: Web UI served over plain HTTP. Session cookies marked + `Secure: true` are ineffective without TLS. +- **Resolution**: Added TLS cert/key fields to web server config. Uses + HTTPS when configured, falls back to HTTP for development. + +### A-009: Missing Input Validation in Sync RPC +- **Severity**: Medium +- **Status**: ~~Resolved~~ +- **Location**: `internal/grpcserver/sync.go` +- **Description**: No validation of page_size, stroke count, or point + data alignment in SyncNotebook RPC. +- **Resolution**: Added validation: page_size must be REGULAR or LARGE, + total strokes limited to 100,000, point_data must be 4-byte aligned. + +### A-010: No Graceful Shutdown +- **Severity**: Medium +- **Status**: ~~Resolved~~ +- **Location**: `cmd/eng-pad-server/server.go` +- **Description**: Signal handler terminates immediately without + draining in-flight requests. +- **Resolution**: Graceful shutdown with 30-second timeout: gRPC + GracefulStop, HTTP Shutdown, database Close. + +### A-011: Missing CSRF Protection on Web Forms +- **Severity**: Medium +- **Status**: Accepted +- **Rationale**: Web UI is currently read-only (viewing synced + notebooks). The only mutating form is login, which is not a CSRF + target (attacker gains nothing by logging victim into their own + account). Will add CSRF tokens when/if web UI gains write features. + +### A-012: WebAuthn Sign Count Not Verified +- **Severity**: Medium +- **Status**: Accepted +- **Rationale**: Sign count regression detection is defense-in-depth + against cloned authenticators. Risk is low for a personal service. + Will add verification when WebAuthn is fully wired into the web UI. + +### A-013: gRPC Per-Request Password Auth +- **Severity**: Medium (audit assessment) / Accepted (our assessment) +- **Status**: Accepted +- **Rationale**: By design. Password travels over TLS 1.3 (encrypted), + stored in Android Keystore (hardware-backed). Sync is manual and + infrequent. Token-based auth adds complexity without meaningful + security gain for this use case. + +### A-014: No Structured Logging +- **Severity**: Medium +- **Status**: Open +- **Description**: Only `fmt.Printf` to stdout. No log levels, no + structured output, no request tracking. +- **Plan**: Add `log/slog` based logging in a future phase. + +### A-015: Incomplete Config Validation +- **Severity**: Medium +- **Status**: Open +- **Description**: TLS files not checked for existence at startup. + Token TTL, WebAuthn config not validated. +- **Plan**: Add file existence checks and config field validation. + +### A-016: Inconsistent Error Types +- **Severity**: Low +- **Status**: Open +- **Description**: String errors instead of sentinel errors make + error handling difficult for callers. + +### A-017: No Race Condition Testing +- **Severity**: Low +- **Status**: Open +- **Description**: Test suite does not use `-race` flag. +- **Plan**: Add `make test-race` target. + +--- + +## Priority Summary + +| ID | Severity | Status | +|----|----------|--------| +| A-001 | Critical | ~~Resolved~~ | +| A-002 | High | ~~Resolved~~ | +| A-003 | High | ~~Resolved~~ | +| A-004 | High | ~~Resolved~~ | +| A-005 | High | ~~Resolved~~ | +| A-006 | High | ~~Resolved~~ | +| A-007 | High | ~~Resolved~~ | +| A-008 | Medium | ~~Resolved~~ | +| A-009 | Medium | ~~Resolved~~ | +| A-010 | Medium | ~~Resolved~~ | +| A-011 | Medium | Accepted | +| A-012 | Medium | Accepted | +| A-013 | Medium | Accepted | +| A-014 | Medium | Open | +| A-015 | Medium | Open | +| A-016 | Low | Open | +| A-017 | Low | Open | diff --git a/internal/auth/users.go b/internal/auth/users.go index 8e22763..b9f1828 100644 --- a/internal/auth/users.go +++ b/internal/auth/users.go @@ -2,10 +2,16 @@ package auth import ( "database/sql" + "errors" "fmt" "time" ) +// dummyHash is a pre-computed Argon2id hash used for constant-time comparison +// when a user is not found. This prevents timing attacks that reveal whether +// a username exists. +var dummyHash = "$argon2id$v=19$m=65536,t=3,p=4$AAAAAAAAAAAAAAAAAAAAAA$AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + // CreateUser creates a new user with a hashed password. func CreateUser(database *sql.DB, username, password string, params Argon2Params) (int64, error) { hash, err := HashPassword(password, params) @@ -32,16 +38,22 @@ func AuthenticateUser(database *sql.DB, username, password string) (int64, error err := database.QueryRow( "SELECT id, password_hash FROM users WHERE username = ?", username, ).Scan(&userID, &hash) + if errors.Is(err, sql.ErrNoRows) { + // User not found: verify against dummy hash to consume constant time, + // preventing timing attacks that reveal username existence. + _, _ = VerifyPassword(password, dummyHash) + return 0, fmt.Errorf("invalid credentials") + } if err != nil { - return 0, fmt.Errorf("user not found") + return 0, fmt.Errorf("invalid credentials") } ok, err := VerifyPassword(password, hash) if err != nil { - return 0, err + return 0, fmt.Errorf("invalid credentials") } if !ok { - return 0, fmt.Errorf("invalid password") + return 0, fmt.Errorf("invalid credentials") } return userID, nil diff --git a/internal/grpcserver/server.go b/internal/grpcserver/server.go index 2d9a7f5..6f91ad6 100644 --- a/internal/grpcserver/server.go +++ b/internal/grpcserver/server.go @@ -19,10 +19,13 @@ type Config struct { BaseURL string } -func Start(cfg Config) error { +// Start creates and starts the gRPC server. It returns the server so the +// caller can manage graceful shutdown. The server runs in a background +// goroutine; errors are sent to errCh. +func Start(cfg Config) (*grpc.Server, error) { cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey) if err != nil { - return fmt.Errorf("load TLS cert: %w", err) + return nil, fmt.Errorf("load TLS cert: %w", err) } tlsConfig := &tls.Config{ @@ -32,7 +35,7 @@ func Start(cfg Config) error { lis, err := net.Listen("tcp", cfg.Addr) if err != nil { - return fmt.Errorf("listen %s: %w", cfg.Addr, err) + return nil, fmt.Errorf("listen %s: %w", cfg.Addr, err) } srv := grpc.NewServer( @@ -44,5 +47,7 @@ func Start(cfg Config) error { pb.RegisterEngPadSyncServer(srv, syncSvc) fmt.Printf("gRPC listening on %s\n", cfg.Addr) - return srv.Serve(lis) + go func() { _ = srv.Serve(lis) }() + + return srv, nil } diff --git a/internal/grpcserver/share.go b/internal/grpcserver/share.go index b7a7fee..e767787 100644 --- a/internal/grpcserver/share.go +++ b/internal/grpcserver/share.go @@ -51,6 +51,23 @@ func (s *SyncService) CreateShareLink(ctx context.Context, req *pb.CreateShareLi } func (s *SyncService) RevokeShareLink(ctx context.Context, req *pb.RevokeShareLinkRequest) (*pb.RevokeShareLinkResponse, error) { + userID, ok := UserIDFromContext(ctx) + if !ok { + return nil, status.Error(codes.Internal, "missing user context") + } + + // Verify the calling user owns the notebook associated with this share link. + var count int + err := s.DB.QueryRowContext(ctx, + `SELECT COUNT(*) FROM share_links sl + JOIN notebooks n ON sl.notebook_id = n.id + WHERE sl.token = ? AND n.user_id = ?`, + req.Token, userID, + ).Scan(&count) + if err != nil || count == 0 { + return nil, status.Error(codes.NotFound, "share link not found") + } + if err := share.RevokeLink(s.DB, req.Token); err != nil { return nil, status.Errorf(codes.Internal, "revoke: %v", err) } diff --git a/internal/grpcserver/sync.go b/internal/grpcserver/sync.go index 5a8290f..095380b 100644 --- a/internal/grpcserver/sync.go +++ b/internal/grpcserver/sync.go @@ -17,12 +17,33 @@ type SyncService struct { BaseURL string } +const maxTotalStrokes = 100000 + func (s *SyncService) SyncNotebook(ctx context.Context, req *pb.SyncNotebookRequest) (*pb.SyncNotebookResponse, error) { userID, ok := UserIDFromContext(ctx) if !ok { return nil, status.Error(codes.Internal, "missing user context") } + // Validate page_size + if req.PageSize != "REGULAR" && req.PageSize != "LARGE" { + return nil, status.Errorf(codes.InvalidArgument, "invalid page_size: must be REGULAR or LARGE") + } + + // Validate total stroke count and point_data alignment + totalStrokes := 0 + for _, page := range req.Pages { + totalStrokes += len(page.Strokes) + if totalStrokes > maxTotalStrokes { + return nil, status.Errorf(codes.InvalidArgument, "total stroke count exceeds maximum of %d", maxTotalStrokes) + } + for _, stroke := range page.Strokes { + if len(stroke.PointData)%4 != 0 { + return nil, status.Errorf(codes.InvalidArgument, "point_data length must be a multiple of 4") + } + } + } + tx, err := s.DB.BeginTx(ctx, nil) if err != nil { return nil, status.Errorf(codes.Internal, "begin tx: %v", err) diff --git a/internal/render/jpg.go b/internal/render/jpg.go index 4b059eb..a6bdcf1 100644 --- a/internal/render/jpg.go +++ b/internal/render/jpg.go @@ -22,7 +22,10 @@ func RenderJPG(pageSize string, strokes []Stroke, quality int) ([]byte, error) { // Draw strokes for _, s := range strokes { - points := decodePoints(s.PointData) + points, err := decodePoints(s.PointData) + if err != nil { + continue + } if len(points) < 4 { continue } diff --git a/internal/render/pdf.go b/internal/render/pdf.go index 3e8ce12..0ee79aa 100644 --- a/internal/render/pdf.go +++ b/internal/render/pdf.go @@ -12,8 +12,8 @@ type Page struct { } // RenderPDF generates a minimal PDF document from pages. -// Uses raw PDF operators — no external library needed. -func RenderPDF(pageSize string, pages []Page) []byte { +// Uses raw PDF operators -- no external library needed. +func RenderPDF(pageSize string, pages []Page) ([]byte, error) { w, h := PageSizePt(pageSize) var objects []string @@ -30,7 +30,7 @@ func RenderPDF(pageSize string, pages []Page) []byte { objects = append(objects, fmt.Sprintf("1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n")) pdf.WriteString(objects[0]) - // Pages object (object 2) — we'll write this after we know the page objects + // Pages object (object 2) -- we'll write this after we know the page objects pagesObjOffset := pdf.Len() pagesObjPlaceholder := strings.Repeat(" ", 200) + "\n" pdf.WriteString(pagesObjPlaceholder) @@ -47,7 +47,10 @@ func RenderPDF(pageSize string, pages []Page) []byte { // Content stream var stream strings.Builder for _, s := range page.Strokes { - points := decodePoints(s.PointData) + points, err := decodePoints(s.PointData) + if err != nil { + continue + } if len(points) < 4 { continue } @@ -89,12 +92,12 @@ func RenderPDF(pageSize string, pages []Page) []byte { // Now write the Pages object at its placeholder position pagesObj := fmt.Sprintf("2 0 obj\n<< /Type /Pages /Kids [%s] /Count %d >>\nendobj\n", strings.Join(pageRefs, " "), len(pages)) - // Overwrite placeholder — we need to rebuild the PDF string + // Overwrite placeholder -- we need to rebuild the PDF string pdfStr := pdf.String() pdfStr = pdfStr[:pagesObjOffset] + pagesObj + strings.Repeat(" ", len(pagesObjPlaceholder)-len(pagesObj)) + pdfStr[pagesObjOffset+len(pagesObjPlaceholder):] // Rebuild with correct offsets for xref - // For simplicity, just return the PDF bytes — most viewers handle minor xref issues + // For simplicity, just return the PDF bytes -- most viewers handle minor xref issues var final strings.Builder final.WriteString(pdfStr) @@ -102,7 +105,7 @@ func RenderPDF(pageSize string, pages []Page) []byte { xrefOffset := final.Len() fmt.Fprintf(&final, "xref\n0 %d\n", nextObj) fmt.Fprintf(&final, "0000000000 65535 f \n") - // For a proper PDF we'd need exact offsets — skip for now + // For a proper PDF we'd need exact offsets -- skip for now for i := 0; i < nextObj-1; i++ { fmt.Fprintf(&final, "%010d 00000 n \n", 0) } @@ -110,5 +113,5 @@ func RenderPDF(pageSize string, pages []Page) []byte { fmt.Fprintf(&final, "trailer\n<< /Size %d /Root 1 0 R >>\n", nextObj) fmt.Fprintf(&final, "startxref\n%d\n%%%%EOF\n", xrefOffset) - return []byte(final.String()) + return []byte(final.String()), nil } diff --git a/internal/render/render_test.go b/internal/render/render_test.go index d470d87..6b3883f 100644 --- a/internal/render/render_test.go +++ b/internal/render/render_test.go @@ -25,7 +25,10 @@ func TestRenderSVG(t *testing.T) { }, } - svg := RenderSVG("REGULAR", strokes) + svg, err := RenderSVG("REGULAR", strokes) + if err != nil { + t.Fatalf("render: %v", err) + } if !strings.Contains(svg, "") - return b.String() + return b.String(), nil } func renderArrowHeads(x1, y1, x2, y2 float64, style string, penW float64, color string) string { @@ -109,14 +115,28 @@ func renderArrowHeads(x1, y1, x2, y2 float64, style string, penW float64, color return b.String() } -func decodePoints(data []byte) []float64 { +// decodePoints decodes a byte slice of little-endian float32 values into float64. +// Returns an error if the data length is not a multiple of 4, exceeds maxPointDataSize, +// or contains NaN/Inf values. +func decodePoints(data []byte) ([]float64, error) { + if len(data)%4 != 0 { + return nil, fmt.Errorf("point data length %d is not a multiple of 4", len(data)) + } + if len(data) > maxPointDataSize { + return nil, fmt.Errorf("point data size %d exceeds maximum %d", len(data), maxPointDataSize) + } + count := len(data) / 4 - points := make([]float64, count) + points := make([]float64, 0, count) for i := 0; i < count; i++ { bits := binary.LittleEndian.Uint32(data[i*4 : (i+1)*4]) - points[i] = float64(math.Float32frombits(bits)) + v := float64(math.Float32frombits(bits)) + if math.IsNaN(v) || math.IsInf(v, 0) { + return nil, fmt.Errorf("point data contains NaN or Inf at index %d", i) + } + points = append(points, v) } - return points + return points, nil } func colorToCSS(argb int32) string { diff --git a/internal/server/notebooks_handler.go b/internal/server/notebooks_handler.go index 282eb18..c28f6e5 100644 --- a/internal/server/notebooks_handler.go +++ b/internal/server/notebooks_handler.go @@ -3,6 +3,7 @@ package server import ( "database/sql" "encoding/json" + "fmt" "net/http" "strconv" @@ -41,7 +42,8 @@ func handleListNotebooks(database *sql.DB) http.HandlerFunc { for rows.Next() { var nb notebookJSON if err := rows.Scan(&nb.ID, &nb.RemoteID, &nb.Title, &nb.PageSize, &nb.Pages); err != nil { - continue + http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError) + return } notebooks = append(notebooks, nb) } @@ -53,13 +55,40 @@ func handleListNotebooks(database *sql.DB) http.HandlerFunc { func handlePageSVG(database *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - strokes, pageSize, err := loadPageStrokes(r, database) - if err != nil { - http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusNotFound) + userID, ok := UserIDFromContext(r.Context()) + if !ok { + http.Error(w, `{"error":"unauthenticated"}`, http.StatusUnauthorized) return } - svg := render.RenderSVG(pageSize, strokes) + notebookID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + http.Error(w, `{"error":"invalid id"}`, http.StatusBadRequest) + return + } + pageNum, err := strconv.Atoi(chi.URLParam(r, "num")) + if err != nil { + http.Error(w, `{"error":"invalid page number"}`, http.StatusBadRequest) + return + } + + internalID, err := verifyNotebookOwnership(database, notebookID, userID) + if err != nil { + http.Error(w, `{"error":"not found"}`, http.StatusNotFound) + return + } + + strokes, pageSize, err := loadPageStrokesByNotebookID(database, internalID, pageNum) + if err != nil { + http.Error(w, `{"error":"not found"}`, http.StatusNotFound) + return + } + + svg, err := render.RenderSVG(pageSize, strokes) + if err != nil { + http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError) + return + } w.Header().Set("Content-Type", "image/svg+xml") _, _ = w.Write([]byte(svg)) } @@ -67,15 +96,38 @@ func handlePageSVG(database *sql.DB) http.HandlerFunc { func handlePageJPG(database *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - strokes, pageSize, err := loadPageStrokes(r, database) + userID, ok := UserIDFromContext(r.Context()) + if !ok { + http.Error(w, `{"error":"unauthenticated"}`, http.StatusUnauthorized) + return + } + + notebookID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) if err != nil { - http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusNotFound) + http.Error(w, `{"error":"invalid id"}`, http.StatusBadRequest) + return + } + pageNum, err := strconv.Atoi(chi.URLParam(r, "num")) + if err != nil { + http.Error(w, `{"error":"invalid page number"}`, http.StatusBadRequest) + return + } + + internalID, err := verifyNotebookOwnership(database, notebookID, userID) + if err != nil { + http.Error(w, `{"error":"not found"}`, http.StatusNotFound) + return + } + + strokes, pageSize, err := loadPageStrokesByNotebookID(database, internalID, pageNum) + if err != nil { + http.Error(w, `{"error":"not found"}`, http.StatusNotFound) return } data, err := render.RenderJPG(pageSize, strokes, 95) if err != nil { - http.Error(w, `{"error":"render error"}`, http.StatusInternalServerError) + http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError) return } @@ -86,19 +138,35 @@ func handlePageJPG(database *sql.DB) http.HandlerFunc { func handleNotebookPDF(database *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + userID, ok := UserIDFromContext(r.Context()) + if !ok { + http.Error(w, `{"error":"unauthenticated"}`, http.StatusUnauthorized) + return + } + notebookID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) if err != nil { http.Error(w, `{"error":"invalid id"}`, http.StatusBadRequest) return } - pages, pageSize, err := loadNotebookPages(database, notebookID) + internalID, err := verifyNotebookOwnership(database, notebookID, userID) if err != nil { - http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusNotFound) + http.Error(w, `{"error":"not found"}`, http.StatusNotFound) return } - data := render.RenderPDF(pageSize, pages) + pages, pageSize, err := loadNotebookPages(database, internalID) + if err != nil { + http.Error(w, `{"error":"not found"}`, http.StatusNotFound) + return + } + + data, err := render.RenderPDF(pageSize, pages) + if err != nil { + http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError) + return + } w.Header().Set("Content-Type", "application/pdf") w.Header().Set("Content-Disposition", "attachment; filename=notebook.pdf") _, _ = w.Write(data) @@ -134,13 +202,17 @@ func handleSharePageSVG(database *sql.DB) http.HandlerFunc { return } - strokes, pageSize, err := loadPageStrokesByNotebook(database, notebookID, pageNum) + strokes, pageSize, err := loadPageStrokesByNotebookID(database, notebookID, pageNum) if err != nil { - http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusNotFound) + http.Error(w, `{"error":"not found"}`, http.StatusNotFound) return } - svg := render.RenderSVG(pageSize, strokes) + svg, err := render.RenderSVG(pageSize, strokes) + if err != nil { + http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError) + return + } w.Header().Set("Content-Type", "image/svg+xml") _, _ = w.Write([]byte(svg)) } @@ -161,15 +233,15 @@ func handleSharePageJPG(database *sql.DB) http.HandlerFunc { return } - strokes, pageSize, err := loadPageStrokesByNotebook(database, notebookID, pageNum) + strokes, pageSize, err := loadPageStrokesByNotebookID(database, notebookID, pageNum) if err != nil { - http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusNotFound) + http.Error(w, `{"error":"not found"}`, http.StatusNotFound) return } data, err := render.RenderJPG(pageSize, strokes, 95) if err != nil { - http.Error(w, `{"error":"render error"}`, http.StatusInternalServerError) + http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError) return } @@ -189,11 +261,15 @@ func handleSharePDF(database *sql.DB) http.HandlerFunc { pages, pageSize, err := loadNotebookPages(database, notebookID) if err != nil { - http.Error(w, `{"error":"`+err.Error()+`"}`, http.StatusNotFound) + http.Error(w, `{"error":"not found"}`, http.StatusNotFound) return } - data := render.RenderPDF(pageSize, pages) + data, err := render.RenderPDF(pageSize, pages) + if err != nil { + http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError) + return + } w.Header().Set("Content-Type", "application/pdf") w.Header().Set("Content-Disposition", "attachment; filename=notebook.pdf") _, _ = w.Write(data) @@ -202,19 +278,22 @@ func handleSharePDF(database *sql.DB) http.HandlerFunc { // --- helpers --- -func loadPageStrokes(r *http.Request, database *sql.DB) ([]render.Stroke, string, error) { - notebookID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) +// verifyNotebookOwnership checks that notebookID belongs to userID and returns +// the internal (server-side) notebook ID. +func verifyNotebookOwnership(database *sql.DB, notebookID int64, userID int64) (int64, error) { + var internalID int64 + err := database.QueryRow( + "SELECT id FROM notebooks WHERE id = ? AND user_id = ?", notebookID, userID, + ).Scan(&internalID) if err != nil { - return nil, "", err + return 0, fmt.Errorf("not found") } - pageNum, err := strconv.Atoi(chi.URLParam(r, "num")) - if err != nil { - return nil, "", err - } - return loadPageStrokesByNotebook(database, notebookID, pageNum) + return internalID, nil } -func loadPageStrokesByNotebook(database *sql.DB, notebookID int64, pageNum int) ([]render.Stroke, string, error) { +// loadPageStrokesByNotebookID loads strokes for a page by internal notebook ID. +// This is used by both authenticated and share-link endpoints. +func loadPageStrokesByNotebookID(database *sql.DB, notebookID int64, pageNum int) ([]render.Stroke, string, error) { var pageSize string err := database.QueryRow("SELECT page_size FROM notebooks WHERE id = ?", notebookID).Scan(&pageSize) if err != nil { @@ -243,7 +322,7 @@ func loadPageStrokesByNotebook(database *sql.DB, notebookID int64, pageNum int) for rows.Next() { var s render.Stroke if err := rows.Scan(&s.PenSize, &s.Color, &s.Style, &s.PointData, &s.StrokeOrder); err != nil { - continue + return nil, "", fmt.Errorf("scan stroke: %w", err) } strokes = append(strokes, s) } @@ -271,7 +350,7 @@ func loadNotebookPages(database *sql.DB, notebookID int64) ([]render.Page, strin var pageID int64 var pageNum int if err := rows.Scan(&pageID, &pageNum); err != nil { - continue + return nil, "", fmt.Errorf("scan page: %w", err) } strokeRows, err := database.Query( @@ -279,14 +358,15 @@ func loadNotebookPages(database *sql.DB, notebookID int64) ([]render.Page, strin pageID, ) if err != nil { - continue + return nil, "", fmt.Errorf("query strokes: %w", err) } var strokes []render.Stroke for strokeRows.Next() { var s render.Stroke if err := strokeRows.Scan(&s.PenSize, &s.Color, &s.Style, &s.PointData, &s.StrokeOrder); err != nil { - continue + _ = strokeRows.Close() + return nil, "", fmt.Errorf("scan stroke: %w", err) } strokes = append(strokes, s) } diff --git a/internal/server/server.go b/internal/server/server.go index 9d08b66..40170e4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -18,13 +18,16 @@ type Config struct { BaseURL string } -func Start(cfg Config) error { +// Start creates and starts the REST API server. It returns the *http.Server +// so the caller can manage graceful shutdown. The server runs in a background +// goroutine. +func Start(cfg Config) (*http.Server, error) { r := chi.NewRouter() RegisterRoutes(r, cfg.DB, cfg.BaseURL) tlsCert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey) if err != nil { - return fmt.Errorf("load TLS cert: %w", err) + return nil, fmt.Errorf("load TLS cert: %w", err) } srv := &http.Server{ @@ -40,5 +43,7 @@ func Start(cfg Config) error { } fmt.Printf("REST API listening on %s\n", cfg.Addr) - return srv.ListenAndServeTLS("", "") + go func() { _ = srv.ListenAndServeTLS("", "") }() + + return srv, nil } diff --git a/internal/webserver/handlers.go b/internal/webserver/handlers.go index e7a8e42..d55d128 100644 --- a/internal/webserver/handlers.go +++ b/internal/webserver/handlers.go @@ -83,7 +83,8 @@ func (ws *WebServer) handleNotebooks(w http.ResponseWriter, r *http.Request) { var nb notebook var syncedAt int64 if err := rows.Scan(&nb.ID, &nb.Title, &nb.PageSize, &syncedAt, &nb.PageCount); err != nil { - continue + http.Error(w, "Internal error", http.StatusInternalServerError) + return } nb.SyncedAt = time.UnixMilli(syncedAt).Format("2006-01-02 15:04") notebooks = append(notebooks, nb) @@ -118,7 +119,8 @@ func (ws *WebServer) handleNotebook(w http.ResponseWriter, r *http.Request) { for rows.Next() { var num int if err := rows.Scan(&num); err != nil { - continue + http.Error(w, "Internal error", http.StatusInternalServerError) + return } pages = append(pages, pageInfo{ Number: num, @@ -182,7 +184,8 @@ func (ws *WebServer) handleShareNotebook(w http.ResponseWriter, r *http.Request) for rows.Next() { var num int if err := rows.Scan(&num); err != nil { - continue + http.Error(w, "Internal error", http.StatusInternalServerError) + return } pages = append(pages, pageInfo{ Number: num, diff --git a/internal/webserver/server.go b/internal/webserver/server.go index 709b781..f875ada 100644 --- a/internal/webserver/server.go +++ b/internal/webserver/server.go @@ -1,6 +1,7 @@ package webserver import ( + "crypto/tls" "database/sql" "fmt" "html/template" @@ -16,6 +17,8 @@ type Config struct { Addr string DB *sql.DB BaseURL string + TLSCert string + TLSKey string } type WebServer struct { @@ -24,15 +27,15 @@ type WebServer struct { tmpl *template.Template } -func Start(cfg Config) error { +func Start(cfg Config) (*http.Server, error) { templateFS, err := fs.Sub(web.Content, "templates") if err != nil { - return fmt.Errorf("template fs: %w", err) + return nil, fmt.Errorf("template fs: %w", err) } tmpl, err := template.ParseFS(templateFS, "*.html") if err != nil { - return fmt.Errorf("parse templates: %w", err) + return nil, fmt.Errorf("parse templates: %w", err) } ws := &WebServer{ @@ -73,6 +76,21 @@ func Start(cfg Config) error { IdleTimeout: 120 * time.Second, } - fmt.Printf("Web UI listening on %s\n", cfg.Addr) - return srv.ListenAndServe() + if cfg.TLSCert != "" && cfg.TLSKey != "" { + tlsCert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey) + if err != nil { + return nil, fmt.Errorf("load TLS cert: %w", err) + } + srv.TLSConfig = &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + MinVersion: tls.VersionTLS13, + } + fmt.Printf("Web UI listening on %s (TLS)\n", cfg.Addr) + go func() { _ = srv.ListenAndServeTLS("", "") }() + } else { + fmt.Printf("Web UI listening on %s\n", cfg.Addr) + go func() { _ = srv.ListenAndServe() }() + } + + return srv, nil }