Files
mcp/internal/agent/files.go
Kyle Isom 08b3e2a472 Migrate module path from kyle/ to mc/ org
All import paths updated to git.wntrmute.dev/mc/. Bumps mcdsl to v1.2.0,
mc-proxy to v1.1.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 02:07:42 -07:00

153 lines
4.2 KiB
Go

package agent
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
mcpv1 "git.wntrmute.dev/mc/mcp/gen/mcp/v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// validatePath validates and resolves a relative path within a service's
// /srv/<service>/ directory and returns the absolute path to operate on.
//
// It rejects:
//   - service names that are not a single ordinary path component ("",
//     ".", "..", or anything containing a slash or NUL), which would
//     otherwise let the caller choose a directory outside /srv;
//   - empty and absolute paths, paths that traverse upward with "..", and
//     paths that resolve to the service directory itself;
//   - symlink escapes: the deepest already-existing component of the target
//     (the target itself, its parent, or a further ancestor) must resolve to
//     /srv/<service> or a location inside it.
//
// NOTE: the symlink check compares against the literal /srv/<service> path,
// so a service directory that is itself a symlink is rejected. The check is
// point-in-time; links created concurrently after it runs are not detected.
func validatePath(service, relPath string) (string, error) {
	return validatePathInRoot("/srv", service, relPath)
}

// validatePathInRoot implements validatePath with root in place of /srv so
// the checks can be exercised against a scratch directory.
func validatePathInRoot(root, service, relPath string) (string, error) {
	if service == "" {
		return "", fmt.Errorf("empty service name")
	}
	// The service name is joined directly under root, so it must be exactly
	// one ordinary path component.
	if service == "." || service == ".." || strings.ContainsAny(service, "/\x00") {
		return "", fmt.Errorf("invalid service name: %q", service)
	}
	if relPath == "" {
		return "", fmt.Errorf("empty path")
	}
	if filepath.IsAbs(relPath) {
		return "", fmt.Errorf("absolute path not allowed: %s", relPath)
	}
	cleaned := filepath.Clean(relPath)
	// After Clean, ".." can survive only as leading components, so this
	// catches every upward traversal without rejecting names like "a..b".
	if cleaned == ".." || strings.HasPrefix(cleaned, "../") {
		return "", fmt.Errorf("path traversal not allowed: %s", relPath)
	}

	serviceDir := filepath.Join(root, service)
	// inService reports whether p lies strictly below serviceDir. The
	// trailing slash keeps /srv/foo from matching /srv/foobar.
	inService := func(p string) bool {
		return strings.HasPrefix(p, serviceDir+"/")
	}

	fullPath := filepath.Join(serviceDir, cleaned)
	if !inService(fullPath) {
		return "", fmt.Errorf("path escapes service directory: %s", relPath)
	}

	// Walk up from fullPath to the deepest component that already exists.
	// Components below it do not exist and so cannot currently be symlinks.
	// Resolving that one component therefore catches a link anywhere along
	// the path, including fullPath itself, which os.Open would follow.
	existing := fullPath
	for {
		_, err := os.Lstat(existing)
		if err == nil {
			break
		}
		if !os.IsNotExist(err) {
			return "", fmt.Errorf("stat %s: %w", existing, err)
		}
		if existing == serviceDir {
			// Nothing under the service directory exists yet, so there is
			// no symlink to follow.
			return fullPath, nil
		}
		existing = filepath.Dir(existing)
	}

	resolved, err := filepath.EvalSymlinks(existing)
	if err != nil {
		return "", fmt.Errorf("resolve symlinks: %w", err)
	}
	if resolved != serviceDir && !inService(resolved) {
		return "", fmt.Errorf("symlink escapes service directory: %s", relPath)
	}
	return fullPath, nil
}
// PushFile writes req.Content to req.Path under /srv/<req.Service>/ on the
// node, creating missing intermediate directories with mode 0750.
//
// The content is written to a temporary file in the target directory. That
// file is flushed to stable storage, given its final permissions, and then
// renamed over the target, so readers see either the old file or the
// complete new one. req.Mode supplies the permission bits; only the 0777
// bits are honored, and zero means 0600.
//
// Errors are InvalidArgument for a missing or unsafe service/path and
// Internal for filesystem failures.
func (a *Agent) PushFile(ctx context.Context, req *mcpv1.PushFileRequest) (*mcpv1.PushFileResponse, error) {
	if req.Service == "" {
		return nil, status.Errorf(codes.InvalidArgument, "service name required")
	}
	if req.Path == "" {
		return nil, status.Errorf(codes.InvalidArgument, "path required")
	}
	fullPath, err := validatePath(req.Service, req.Path)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "invalid path: %v", err)
	}

	// Honor only plain permission bits. os.FileMode's setuid/setgid/sticky
	// flags are high bits that os.Chmod would apply, and a remote caller
	// must not be able to set them.
	mode := os.FileMode(req.Mode).Perm()
	if mode == 0 {
		mode = 0600
	}

	a.Logger.Info("push file", "service", req.Service, "path", req.Path)

	dir := filepath.Dir(fullPath)
	if err := os.MkdirAll(dir, 0750); err != nil {
		return nil, status.Errorf(codes.Internal, "create directories: %v", err)
	}

	// Atomic write: temp file in the same directory (same filesystem), then
	// rename over the target.
	tmp, err := os.CreateTemp(dir, ".mcp-push-*")
	if err != nil {
		return nil, status.Errorf(codes.Internal, "create temp file: %v", err)
	}
	tmpName := tmp.Name()
	committed := false
	defer func() {
		if committed {
			return
		}
		// On failure, discard the temp file. Close may already have
		// happened, so its error is irrelevant here.
		_ = tmp.Close()
		_ = os.Remove(tmpName)
	}()

	if _, err := tmp.Write(req.Content); err != nil {
		return nil, status.Errorf(codes.Internal, "write temp file: %v", err)
	}
	// Flush the data before the rename so a crash right after it does not
	// leave a partially written file under the target name.
	if err := tmp.Sync(); err != nil {
		return nil, status.Errorf(codes.Internal, "sync temp file: %v", err)
	}
	if err := tmp.Close(); err != nil {
		return nil, status.Errorf(codes.Internal, "close temp file: %v", err)
	}
	if err := os.Chmod(tmpName, mode); err != nil {
		return nil, status.Errorf(codes.Internal, "set permissions: %v", err)
	}
	if err := os.Rename(tmpName, fullPath); err != nil {
		return nil, status.Errorf(codes.Internal, "rename to target: %v", err)
	}
	committed = true
	return &mcpv1.PushFileResponse{Success: true}, nil
}
// PullFile returns the contents and permission bits of req.Path under
// /srv/<req.Service>/ on the node. The whole file is read into memory.
//
// Only regular files can be pulled. Errors are:
//   - InvalidArgument for a missing or unsafe service/path;
//   - NotFound when the file does not exist;
//   - FailedPrecondition when the path names a directory or other
//     non-regular file;
//   - Internal for other filesystem failures.
func (a *Agent) PullFile(ctx context.Context, req *mcpv1.PullFileRequest) (*mcpv1.PullFileResponse, error) {
	if req.Service == "" {
		return nil, status.Errorf(codes.InvalidArgument, "service name required")
	}
	if req.Path == "" {
		return nil, status.Errorf(codes.InvalidArgument, "path required")
	}
	fullPath, err := validatePath(req.Service, req.Path)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "invalid path: %v", err)
	}
	a.Logger.Info("pull file", "service", req.Service, "path", req.Path)

	f, err := os.Open(fullPath) //nolint:gosec // path validated by validatePath
	if err != nil {
		if os.IsNotExist(err) {
			return nil, status.Errorf(codes.NotFound, "file not found: %s", req.Path)
		}
		return nil, status.Errorf(codes.Internal, "open file: %v", err)
	}
	defer f.Close() //nolint:errcheck // read-only handle; close error carries nothing useful

	info, err := f.Stat()
	if err != nil {
		return nil, status.Errorf(codes.Internal, "stat file: %v", err)
	}
	// Check the opened descriptor, not the path, so the check applies to
	// the file actually being read. A directory would otherwise fail inside
	// ReadAll as an Internal error, and a character device may never reach
	// EOF.
	if !info.Mode().IsRegular() {
		return nil, status.Errorf(codes.FailedPrecondition, "not a regular file: %s", req.Path)
	}

	content, err := io.ReadAll(f)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "read file: %v", err)
	}
	return &mcpv1.PullFileResponse{
		Content: content,
		Mode:    uint32(info.Mode().Perm()),
	}, nil
}