package agent import ( "context" "fmt" "io" "os" "path/filepath" "strings" mcpv1 "git.wntrmute.dev/mc/mcp/gen/mcp/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) // validatePath validates and resolves a relative path within a service's // /srv// directory. It rejects path traversal, absolute paths, // and symlink escapes. func validatePath(service, relPath string) (string, error) { if service == "" { return "", fmt.Errorf("empty service name") } if relPath == "" { return "", fmt.Errorf("empty path") } if filepath.IsAbs(relPath) { return "", fmt.Errorf("absolute path not allowed: %s", relPath) } cleaned := filepath.Clean(relPath) if strings.Contains(cleaned, "..") { return "", fmt.Errorf("path traversal not allowed: %s", relPath) } serviceDir := filepath.Join("/srv", service) fullPath := filepath.Join(serviceDir, cleaned) if !strings.HasPrefix(fullPath, serviceDir+"/") { return "", fmt.Errorf("path escapes service directory: %s", relPath) } parentDir := filepath.Dir(fullPath) if _, err := os.Stat(parentDir); err == nil { resolved, err := filepath.EvalSymlinks(parentDir) if err != nil { return "", fmt.Errorf("resolve symlinks: %w", err) } if !strings.HasPrefix(resolved, serviceDir) { return "", fmt.Errorf("symlink escapes service directory: %s", relPath) } } return fullPath, nil } // PushFile writes a file to the node's filesystem under /srv//. func (a *Agent) PushFile(ctx context.Context, req *mcpv1.PushFileRequest) (*mcpv1.PushFileResponse, error) { if req.Service == "" { return nil, status.Errorf(codes.InvalidArgument, "service name required") } if req.Path == "" { return nil, status.Errorf(codes.InvalidArgument, "path required") } fullPath, err := validatePath(req.Service, req.Path) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid path: %v", err) } a.Logger.Info("push file", "service", req.Service, "path", req.Path) dir := filepath.Dir(fullPath) if err := os.MkdirAll(dir, 0750); err != nil { return nil, status.Errorf(codes.Internal, "create directories: %v", err) } // Atomic write: temp file in the same directory, then rename. tmp, err := os.CreateTemp(dir, ".mcp-push-*") if err != nil { return nil, status.Errorf(codes.Internal, "create temp file: %v", err) } tmpName := tmp.Name() cleanup := func() { _ = os.Remove(tmpName) } if _, err := tmp.Write(req.Content); err != nil { _ = tmp.Close() cleanup() return nil, status.Errorf(codes.Internal, "write temp file: %v", err) } if err := tmp.Close(); err != nil { cleanup() return nil, status.Errorf(codes.Internal, "close temp file: %v", err) } mode := os.FileMode(req.Mode) if mode == 0 { mode = 0600 } if err := os.Chmod(tmpName, mode); err != nil { cleanup() return nil, status.Errorf(codes.Internal, "set permissions: %v", err) } if err := os.Rename(tmpName, fullPath); err != nil { cleanup() return nil, status.Errorf(codes.Internal, "rename to target: %v", err) } return &mcpv1.PushFileResponse{Success: true}, nil } // PullFile reads a file from the node's filesystem under /srv//. func (a *Agent) PullFile(ctx context.Context, req *mcpv1.PullFileRequest) (*mcpv1.PullFileResponse, error) { if req.Service == "" { return nil, status.Errorf(codes.InvalidArgument, "service name required") } if req.Path == "" { return nil, status.Errorf(codes.InvalidArgument, "path required") } fullPath, err := validatePath(req.Service, req.Path) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid path: %v", err) } a.Logger.Info("pull file", "service", req.Service, "path", req.Path) f, err := os.Open(fullPath) //nolint:gosec // path validated by validatePath if err != nil { if os.IsNotExist(err) { return nil, status.Errorf(codes.NotFound, "file not found: %s", req.Path) } return nil, status.Errorf(codes.Internal, "open file: %v", err) } defer f.Close() //nolint:errcheck info, err := f.Stat() if err != nil { return nil, status.Errorf(codes.Internal, "stat file: %v", err) } content, err := io.ReadAll(f) if err != nil { return nil, status.Errorf(codes.Internal, "read file: %v", err) } return &mcpv1.PullFileResponse{ Content: content, Mode: uint32(info.Mode().Perm()), }, nil }