package server

import (
	"context"
	"database/sql"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"git.wntrmute.dev/kyle/exo/artifacts"
	"git.wntrmute.dev/kyle/exo/core"
	"git.wntrmute.dev/kyle/exo/db"
	"git.wntrmute.dev/kyle/exo/kg"
	pb "git.wntrmute.dev/kyle/exo/proto/exo/v1"
)

// KGServer implements the KnowledgeGraphService gRPC service.
//
// Each RPC runs inside its own transaction obtained from db.StartTX and
// commits only after every step has succeeded. A rollback is deferred right
// after the transaction starts so it is released on every early return (and
// on panic). NOTE(review): this relies on db.StartTX returning a *sql.Tx, for
// which Rollback after a successful Commit is a harmless no-op that returns
// sql.ErrTxDone — confirm against the db package.
type KGServer struct {
	pb.UnimplementedKnowledgeGraphServiceServer
	database *sql.DB
}

// Compile-time check that KGServer satisfies the generated service interface.
var _ pb.KnowledgeGraphServiceServer = (*KGServer)(nil)

// NewKGServer creates a new KGServer backed by database.
func NewKGServer(database *sql.DB) *KGServer {
	return &KGServer{database: database}
}

// CreateNode stores a new node and returns its ID. Name is required; an empty
// type defaults to kg.NodeTypeNote. The node's tags and categories are created
// (idempotently) in the same transaction before the node itself is stored.
func (s *KGServer) CreateNode(ctx context.Context, req *pb.CreateNodeRequest) (*pb.CreateNodeResponse, error) {
	if req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "name is required")
	}

	nodeType := kg.NodeType(req.Type)
	if nodeType == "" {
		nodeType = kg.NodeTypeNote
	}

	node := kg.NewNode(req.Name, nodeType)
	node.ParentID = req.ParentId
	node.Tags = core.MapFromList(req.Tags)
	node.Categories = core.MapFromList(req.Categories)

	tx, err := db.StartTX(ctx, s.database)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to start transaction: %v", err)
	}
	// The rollback error is ignored on purpose: either the RPC is already
	// returning the real failure, or Commit has succeeded and this is a no-op.
	defer func() { _ = tx.Rollback() }()

	// Create tags and categories idempotently.
	for tag := range node.Tags {
		if err := artifacts.CreateTag(ctx, tx, tag); err != nil {
			return nil, status.Errorf(codes.Internal, "failed to create tag: %v", err)
		}
	}
	for cat := range node.Categories {
		if err := artifacts.CreateCategory(ctx, tx, cat); err != nil {
			return nil, status.Errorf(codes.Internal, "failed to create category: %v", err)
		}
	}

	if err := node.Store(ctx, tx); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to store node: %v", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to commit: %v", err)
	}
	return &pb.CreateNodeResponse{Id: node.ID}, nil
}

// GetNode returns the node with the requested ID together with its cells.
func (s *KGServer) GetNode(ctx context.Context, req *pb.GetNodeRequest) (*pb.GetNodeResponse, error) {
	if req.Id == "" {
		return nil, status.Error(codes.InvalidArgument, "id is required")
	}

	tx, err := db.StartTX(ctx, s.database)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to start transaction: %v", err)
	}
	// See CreateNode: ignoring the rollback error is deliberate.
	defer func() { _ = tx.Rollback() }()

	node := &kg.Node{ID: req.Id}
	if err := node.Get(ctx, tx); err != nil {
		// NOTE(review): every lookup failure, including genuine database
		// errors, is reported as NotFound. Mapping only sql.ErrNoRows to
		// NotFound would be more precise, but depends on whether kg.Node.Get
		// wraps its errors with %w — confirm before changing.
		return nil, status.Errorf(codes.NotFound, "node not found: %v", err)
	}

	cells, err := kg.GetCellsForNode(ctx, tx, req.Id)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to get cells: %v", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to commit: %v", err)
	}

	resp := &pb.GetNodeResponse{
		Node: &pb.Node{
			Id:         node.ID,
			ParentId:   node.ParentID,
			Name:       node.Name,
			Type:       string(node.Type),
			Created:    db.ToDBTime(node.Created),
			Modified:   db.ToDBTime(node.Modified),
			Children:   node.Children,
			Tags:       core.ListFromMap(node.Tags),
			Categories: core.ListFromMap(node.Categories),
		},
		Cells: make([]*pb.Cell, 0, len(cells)),
	}
	for _, c := range cells {
		resp.Cells = append(resp.Cells, &pb.Cell{
			Id:       c.ID,
			NodeId:   c.NodeID,
			Type:     string(c.Type),
			Contents: c.Contents,
			Ordinal:  int32(c.Ordinal), //nolint:gosec // ordinal values are small
			Created:  db.ToDBTime(c.Created),
			Modified: db.ToDBTime(c.Modified),
		})
	}
	return resp, nil
}

// AddCell stores a new cell on the given node and returns the cell's ID. An
// empty type defaults to kg.CellTypeMarkdown.
func (s *KGServer) AddCell(ctx context.Context, req *pb.AddCellRequest) (*pb.AddCellResponse, error) {
	if req.NodeId == "" {
		return nil, status.Error(codes.InvalidArgument, "node_id is required")
	}

	cellType := kg.CellType(req.Type)
	if cellType == "" {
		cellType = kg.CellTypeMarkdown
	}

	cell := kg.NewCell(req.NodeId, cellType, req.Contents)
	cell.Ordinal = int(req.Ordinal)

	tx, err := db.StartTX(ctx, s.database)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to start transaction: %v", err)
	}
	// See CreateNode: ignoring the rollback error is deliberate.
	defer func() { _ = tx.Rollback() }()

	if err := cell.Store(ctx, tx); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to store cell: %v", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to commit: %v", err)
	}
	return &pb.AddCellResponse{Id: cell.ID}, nil
}

// RecordFact stores a fact about an entity (or a retraction, when
// req.Retraction is set) and returns the new fact's ID. A missing value is
// recorded as the zero core.Value.
func (s *KGServer) RecordFact(ctx context.Context, req *pb.RecordFactRequest) (*pb.RecordFactResponse, error) {
	if req.EntityId == "" {
		return nil, status.Error(codes.InvalidArgument, "entity_id is required")
	}

	value := core.Value{}
	if req.Value != nil {
		value = core.Value{Contents: req.Value.Contents, Type: req.Value.Type}
	}

	fact := kg.NewFact(req.EntityId, req.EntityName, req.AttributeId, req.AttributeName, value)
	fact.Retraction = req.Retraction

	tx, err := db.StartTX(ctx, s.database)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to start transaction: %v", err)
	}
	// See CreateNode: ignoring the rollback error is deliberate.
	defer func() { _ = tx.Rollback() }()

	if err := fact.Store(ctx, tx); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to store fact: %v", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to commit: %v", err)
	}
	return &pb.RecordFactResponse{Id: fact.ID}, nil
}

// GetFacts returns the facts recorded for an entity. When req.CurrentOnly is
// set, the result comes from kg.GetCurrentFactsForEntity instead of the full
// history from kg.GetFactsForEntity.
func (s *KGServer) GetFacts(ctx context.Context, req *pb.GetFactsRequest) (*pb.GetFactsResponse, error) {
	if req.EntityId == "" {
		return nil, status.Error(codes.InvalidArgument, "entity_id is required")
	}

	tx, err := db.StartTX(ctx, s.database)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to start transaction: %v", err)
	}
	// See CreateNode: ignoring the rollback error is deliberate.
	defer func() { _ = tx.Rollback() }()

	var facts []*kg.Fact
	if req.CurrentOnly {
		facts, err = kg.GetCurrentFactsForEntity(ctx, tx, req.EntityId)
	} else {
		facts, err = kg.GetFactsForEntity(ctx, tx, req.EntityId)
	}
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to get facts: %v", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to commit: %v", err)
	}

	resp := &pb.GetFactsResponse{Facts: make([]*pb.Fact, 0, len(facts))}
	for _, f := range facts {
		resp.Facts = append(resp.Facts, &pb.Fact{
			Id:            f.ID,
			EntityId:      f.EntityID,
			EntityName:    f.EntityName,
			AttributeId:   f.AttributeID,
			AttributeName: f.AttributeName,
			Value:         &pb.Value{Contents: f.Value.Contents, Type: f.Value.Type},
			TxTimestamp:   f.TxTimestamp.Unix(),
			Retraction:    f.Retraction,
		})
	}
	return resp, nil
}

// AddEdge stores a directed edge from req.SourceId to req.TargetId with the
// given relation and returns the new edge's ID.
func (s *KGServer) AddEdge(ctx context.Context, req *pb.AddEdgeRequest) (*pb.AddEdgeResponse, error) {
	if req.SourceId == "" || req.TargetId == "" {
		return nil, status.Error(codes.InvalidArgument, "source_id and target_id are required")
	}

	edge := kg.NewEdge(req.SourceId, req.TargetId, kg.EdgeRelation(req.Relation))

	tx, err := db.StartTX(ctx, s.database)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to start transaction: %v", err)
	}
	// See CreateNode: ignoring the rollback error is deliberate.
	defer func() { _ = tx.Rollback() }()

	if err := edge.Store(ctx, tx); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to store edge: %v", err)
	}

	if err := tx.Commit(); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to commit: %v", err)
	}
	return &pb.AddEdgeResponse{Id: edge.ID}, nil
}

// GetEdges returns the edges attached to a node. req.Direction selects which
// ones: "from" (the default when empty) uses kg.GetEdgesFrom, "to" uses
// kg.GetEdgesTo, and "both" returns the "from" edges followed by the "to"
// edges. Any other direction is rejected with InvalidArgument.
func (s *KGServer) GetEdges(ctx context.Context, req *pb.GetEdgesRequest) (*pb.GetEdgesResponse, error) {
	if req.NodeId == "" {
		return nil, status.Error(codes.InvalidArgument, "node_id is required")
	}

	// Validate the direction before touching the database so a malformed
	// request never opens a transaction.
	var wantFrom, wantTo bool
	switch req.Direction {
	case "from", "":
		wantFrom = true
	case "to":
		wantTo = true
	case "both":
		wantFrom, wantTo = true, true
	default:
		return nil, status.Error(codes.InvalidArgument, "direction must be 'from', 'to', or 'both'")
	}

	tx, err := db.StartTX(ctx, s.database)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to start transaction: %v", err)
	}
	// See CreateNode: ignoring the rollback error is deliberate.
	defer func() { _ = tx.Rollback() }()

	var edges []*kg.Edge
	if wantFrom {
		from, err := kg.GetEdgesFrom(ctx, tx, req.NodeId)
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to get edges: %v", err)
		}
		edges = from
	}
	if wantTo {
		to, err := kg.GetEdgesTo(ctx, tx, req.NodeId)
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to get edges: %v", err)
		}
		if len(edges) == 0 {
			edges = to
		} else {
			// A self-loop (source and target are both this node) is returned by
			// both queries; keep only its first occurrence.
			seen := make(map[string]struct{}, len(edges))
			for _, e := range edges {
				seen[e.ID] = struct{}{}
			}
			for _, e := range to {
				if _, dup := seen[e.ID]; !dup {
					edges = append(edges, e)
				}
			}
		}
	}

	if err := tx.Commit(); err != nil {
		return nil, status.Errorf(codes.Internal, "failed to commit: %v", err)
	}

	resp := &pb.GetEdgesResponse{Edges: make([]*pb.Edge, 0, len(edges))}
	for _, e := range edges {
		resp.Edges = append(resp.Edges, &pb.Edge{
			Id:       e.ID,
			SourceId: e.SourceID,
			TargetId: e.TargetID,
			Relation: string(e.Relation),
			Created:  db.ToDBTime(e.Created),
		})
	}
	return resp, nil
}