Implement Phase 4: gRPC sync service
- Proto definitions (engpad.v1.EngPadSync) with 6 RPCs - Generated Go gRPC code - Auth interceptor: username/password from metadata - SyncNotebook: upsert with full page/stroke replacement in a tx - DeleteNotebook, ListNotebooks handlers - Share link RPCs: CreateShareLink, RevokeShareLink, ListShareLinks - Share link token management (32-byte random, optional expiry) - gRPC server with TLS 1.3 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
46
internal/grpcserver/interceptors.go
Normal file
46
internal/grpcserver/interceptors.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"git.wntrmute.dev/kyle/eng-pad-server/internal/auth"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const userIDKey contextKey = "user_id"
|
||||
|
||||
// UserIDFromContext extracts the authenticated user ID from the context.
|
||||
func UserIDFromContext(ctx context.Context) (int64, bool) {
|
||||
id, ok := ctx.Value(userIDKey).(int64)
|
||||
return id, ok
|
||||
}
|
||||
|
||||
// AuthInterceptor verifies username/password from gRPC metadata.
|
||||
func AuthInterceptor(database *sql.DB) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Unauthenticated, "missing metadata")
|
||||
}
|
||||
|
||||
usernames := md.Get("username")
|
||||
passwords := md.Get("password")
|
||||
if len(usernames) == 0 || len(passwords) == 0 {
|
||||
return nil, status.Error(codes.Unauthenticated, "missing credentials")
|
||||
}
|
||||
|
||||
userID, err := auth.AuthenticateUser(database, usernames[0], passwords[0])
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, userIDKey, userID)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
48
internal/grpcserver/server.go
Normal file
48
internal/grpcserver/server.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/eng-pad-server/gen/engpad/v1"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Addr string
|
||||
TLSCert string
|
||||
TLSKey string
|
||||
DB *sql.DB
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
func Start(cfg Config) error {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load TLS cert: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", cfg.Addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen %s: %w", cfg.Addr, err)
|
||||
}
|
||||
|
||||
srv := grpc.NewServer(
|
||||
grpc.Creds(credentials.NewTLS(tlsConfig)),
|
||||
grpc.UnaryInterceptor(AuthInterceptor(cfg.DB)),
|
||||
)
|
||||
|
||||
syncSvc := &SyncService{DB: cfg.DB, BaseURL: cfg.BaseURL}
|
||||
pb.RegisterEngPadSyncServer(srv, syncSvc)
|
||||
|
||||
fmt.Printf("gRPC listening on %s\n", cfg.Addr)
|
||||
return srv.Serve(lis)
|
||||
}
|
||||
94
internal/grpcserver/share.go
Normal file
94
internal/grpcserver/share.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/eng-pad-server/gen/engpad/v1"
|
||||
"git.wntrmute.dev/kyle/eng-pad-server/internal/share"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
func (s *SyncService) CreateShareLink(ctx context.Context, req *pb.CreateShareLinkRequest) (*pb.CreateShareLinkResponse, error) {
|
||||
userID, ok := UserIDFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Internal, "missing user context")
|
||||
}
|
||||
|
||||
// Verify notebook belongs to user
|
||||
var notebookID int64
|
||||
err := s.DB.QueryRowContext(ctx,
|
||||
"SELECT id FROM notebooks WHERE user_id = ? AND remote_id = ?",
|
||||
userID, req.NotebookId,
|
||||
).Scan(¬ebookID)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, status.Error(codes.NotFound, "notebook not found")
|
||||
} else if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "query: %v", err)
|
||||
}
|
||||
|
||||
var expiry time.Duration
|
||||
if req.ExpiresInSeconds > 0 {
|
||||
expiry = time.Duration(req.ExpiresInSeconds) * time.Second
|
||||
}
|
||||
|
||||
token, expiresAt, err := share.CreateLink(s.DB, notebookID, expiry, s.BaseURL)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "create link: %v", err)
|
||||
}
|
||||
|
||||
resp := &pb.CreateShareLinkResponse{
|
||||
Token: token,
|
||||
Url: s.BaseURL + "/s/" + token,
|
||||
}
|
||||
if expiresAt != nil {
|
||||
resp.ExpiresAt = timestamppb.New(*expiresAt)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *SyncService) RevokeShareLink(ctx context.Context, req *pb.RevokeShareLinkRequest) (*pb.RevokeShareLinkResponse, error) {
|
||||
if err := share.RevokeLink(s.DB, req.Token); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "revoke: %v", err)
|
||||
}
|
||||
return &pb.RevokeShareLinkResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *SyncService) ListShareLinks(ctx context.Context, req *pb.ListShareLinksRequest) (*pb.ListShareLinksResponse, error) {
|
||||
userID, ok := UserIDFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Internal, "missing user context")
|
||||
}
|
||||
|
||||
var notebookID int64
|
||||
err := s.DB.QueryRowContext(ctx,
|
||||
"SELECT id FROM notebooks WHERE user_id = ? AND remote_id = ?",
|
||||
userID, req.NotebookId,
|
||||
).Scan(¬ebookID)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, "notebook not found")
|
||||
}
|
||||
|
||||
links, err := share.ListLinks(s.DB, notebookID, s.BaseURL)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "list: %v", err)
|
||||
}
|
||||
|
||||
var pbLinks []*pb.ShareLinkInfo
|
||||
for _, l := range links {
|
||||
pbl := &pb.ShareLinkInfo{
|
||||
Token: l.Token,
|
||||
Url: l.URL,
|
||||
CreatedAt: timestamppb.New(l.CreatedAt),
|
||||
}
|
||||
if l.ExpiresAt != nil {
|
||||
pbl.ExpiresAt = timestamppb.New(*l.ExpiresAt)
|
||||
}
|
||||
pbLinks = append(pbLinks, pbl)
|
||||
}
|
||||
|
||||
return &pb.ListShareLinksResponse{Links: pbLinks}, nil
|
||||
}
|
||||
142
internal/grpcserver/sync.go
Normal file
142
internal/grpcserver/sync.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/eng-pad-server/gen/engpad/v1"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type SyncService struct {
|
||||
pb.UnimplementedEngPadSyncServer
|
||||
DB *sql.DB
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
func (s *SyncService) SyncNotebook(ctx context.Context, req *pb.SyncNotebookRequest) (*pb.SyncNotebookResponse, error) {
|
||||
userID, ok := UserIDFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Internal, "missing user context")
|
||||
}
|
||||
|
||||
tx, err := s.DB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "begin tx: %v", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
// Upsert notebook
|
||||
var notebookID int64
|
||||
err = tx.QueryRowContext(ctx,
|
||||
"SELECT id FROM notebooks WHERE user_id = ? AND remote_id = ?",
|
||||
userID, req.NotebookId,
|
||||
).Scan(¬ebookID)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
res, err := tx.ExecContext(ctx,
|
||||
"INSERT INTO notebooks (user_id, remote_id, title, page_size, synced_at) VALUES (?, ?, ?, ?, ?)",
|
||||
userID, req.NotebookId, req.Title, req.PageSize, now,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "insert notebook: %v", err)
|
||||
}
|
||||
notebookID, _ = res.LastInsertId()
|
||||
} else if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "query notebook: %v", err)
|
||||
} else {
|
||||
// Update existing — delete all pages (cascade deletes strokes)
|
||||
if _, err := tx.ExecContext(ctx, "DELETE FROM pages WHERE notebook_id = ?", notebookID); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "delete pages: %v", err)
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx,
|
||||
"UPDATE notebooks SET title = ?, page_size = ?, synced_at = ? WHERE id = ?",
|
||||
req.Title, req.PageSize, now, notebookID,
|
||||
); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "update notebook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Insert pages and strokes
|
||||
for _, page := range req.Pages {
|
||||
res, err := tx.ExecContext(ctx,
|
||||
"INSERT INTO pages (notebook_id, remote_id, page_number) VALUES (?, ?, ?)",
|
||||
notebookID, page.PageId, page.PageNumber,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "insert page: %v", err)
|
||||
}
|
||||
pageID, _ := res.LastInsertId()
|
||||
|
||||
for _, stroke := range page.Strokes {
|
||||
if _, err := tx.ExecContext(ctx,
|
||||
"INSERT INTO strokes (page_id, pen_size, color, style, point_data, stroke_order) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
pageID, stroke.PenSize, stroke.Color, stroke.Style, stroke.PointData, stroke.StrokeOrder,
|
||||
); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "insert stroke: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "commit: %v", err)
|
||||
}
|
||||
|
||||
return &pb.SyncNotebookResponse{
|
||||
ServerNotebookId: notebookID,
|
||||
SyncedAt: timestamppb.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SyncService) DeleteNotebook(ctx context.Context, req *pb.DeleteNotebookRequest) (*pb.DeleteNotebookResponse, error) {
|
||||
userID, ok := UserIDFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Internal, "missing user context")
|
||||
}
|
||||
|
||||
_, err := s.DB.ExecContext(ctx,
|
||||
"DELETE FROM notebooks WHERE user_id = ? AND remote_id = ?",
|
||||
userID, req.NotebookId,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "delete: %v", err)
|
||||
}
|
||||
|
||||
return &pb.DeleteNotebookResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *SyncService) ListNotebooks(ctx context.Context, req *pb.ListNotebooksRequest) (*pb.ListNotebooksResponse, error) {
|
||||
userID, ok := UserIDFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.Internal, "missing user context")
|
||||
}
|
||||
|
||||
rows, err := s.DB.QueryContext(ctx,
|
||||
`SELECT n.id, n.remote_id, n.title, n.page_size, n.synced_at,
|
||||
(SELECT COUNT(*) FROM pages WHERE notebook_id = n.id) as page_count
|
||||
FROM notebooks n WHERE n.user_id = ? ORDER BY n.synced_at DESC`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "query: %v", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var notebooks []*pb.NotebookSummary
|
||||
for rows.Next() {
|
||||
var nb pb.NotebookSummary
|
||||
var syncedAt int64
|
||||
if err := rows.Scan(&nb.ServerId, &nb.RemoteId, &nb.Title, &nb.PageSize, &syncedAt, &nb.PageCount); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "scan: %v", err)
|
||||
}
|
||||
nb.SyncedAt = timestamppb.New(time.UnixMilli(syncedAt))
|
||||
notebooks = append(notebooks, &nb)
|
||||
}
|
||||
|
||||
return &pb.ListNotebooksResponse{Notebooks: notebooks}, nil
|
||||
}
|
||||
105
internal/share/share.go
Normal file
105
internal/share/share.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package share
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const tokenBytes = 32
|
||||
|
||||
type LinkInfo struct {
|
||||
Token string
|
||||
URL string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt *time.Time
|
||||
}
|
||||
|
||||
// CreateLink generates a shareable link for a notebook.
|
||||
func CreateLink(database *sql.DB, notebookID int64, expiry time.Duration, baseURL string) (string, *time.Time, error) {
|
||||
raw := make([]byte, tokenBytes)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return "", nil, fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
token := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(raw)
|
||||
now := time.Now().UnixMilli()
|
||||
|
||||
var expiresAt *int64
|
||||
var expiresTime *time.Time
|
||||
if expiry > 0 {
|
||||
ea := time.Now().Add(expiry).UnixMilli()
|
||||
expiresAt = &ea
|
||||
t := time.UnixMilli(ea)
|
||||
expiresTime = &t
|
||||
}
|
||||
|
||||
_, err := database.Exec(
|
||||
"INSERT INTO share_links (notebook_id, token, expires_at, created_at) VALUES (?, ?, ?, ?)",
|
||||
notebookID, token, expiresAt, now,
|
||||
)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("insert share link: %w", err)
|
||||
}
|
||||
|
||||
return token, expiresTime, nil
|
||||
}
|
||||
|
||||
// ValidateLink checks if a token is valid and returns the notebook ID.
|
||||
func ValidateLink(database *sql.DB, token string) (int64, error) {
|
||||
var notebookID int64
|
||||
var expiresAt *int64
|
||||
err := database.QueryRow(
|
||||
"SELECT notebook_id, expires_at FROM share_links WHERE token = ?", token,
|
||||
).Scan(¬ebookID, &expiresAt)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("link not found")
|
||||
}
|
||||
|
||||
if expiresAt != nil && time.Now().UnixMilli() > *expiresAt {
|
||||
return 0, fmt.Errorf("link expired")
|
||||
}
|
||||
|
||||
return notebookID, nil
|
||||
}
|
||||
|
||||
// RevokeLink deletes a share link.
|
||||
func RevokeLink(database *sql.DB, token string) error {
|
||||
_, err := database.Exec("DELETE FROM share_links WHERE token = ?", token)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListLinks returns all active share links for a notebook.
|
||||
func ListLinks(database *sql.DB, notebookID int64, baseURL string) ([]LinkInfo, error) {
|
||||
rows, err := database.Query(
|
||||
"SELECT token, created_at, expires_at FROM share_links WHERE notebook_id = ? ORDER BY created_at DESC",
|
||||
notebookID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var links []LinkInfo
|
||||
for rows.Next() {
|
||||
var token string
|
||||
var createdAt int64
|
||||
var expiresAt *int64
|
||||
if err := rows.Scan(&token, &createdAt, &expiresAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
li := LinkInfo{
|
||||
Token: token,
|
||||
URL: baseURL + "/s/" + token,
|
||||
CreatedAt: time.UnixMilli(createdAt),
|
||||
}
|
||||
if expiresAt != nil {
|
||||
t := time.UnixMilli(*expiresAt)
|
||||
li.ExpiresAt = &t
|
||||
}
|
||||
links = append(links, li)
|
||||
}
|
||||
return links, nil
|
||||
}
|
||||
Reference in New Issue
Block a user