From 7f328f7f56843f96b7e740c9f51704d6514ffb37 Mon Sep 17 00:00:00 2001 From: Lan Luo Date: Wed, 18 Mar 2026 11:21:11 -0700 Subject: [PATCH 1/2] initial checkin --- .../model_static_shape.py | 61 +++++++++++++ py/torch_tensorrt/_compile.py | 50 ++++++++++- .../dynamo/runtime/_TorchTensorRTModule.py | 2 +- py/torch_tensorrt/executorch/__init__.py | 9 ++ py/torch_tensorrt/executorch/backend.py | 85 +++++++++++++++++++ .../executorch/operator_support.py | 26 ++++++ py/torch_tensorrt/executorch/partitioner.py | 63 ++++++++++++++ py/torch_tensorrt/executorch/serialization.py | 32 +++++++ 8 files changed, 323 insertions(+), 5 deletions(-) create mode 100644 examples/torchtrt_executorch_example/model_static_shape.py create mode 100644 py/torch_tensorrt/executorch/__init__.py create mode 100644 py/torch_tensorrt/executorch/backend.py create mode 100644 py/torch_tensorrt/executorch/operator_support.py create mode 100644 py/torch_tensorrt/executorch/partitioner.py create mode 100644 py/torch_tensorrt/executorch/serialization.py diff --git a/examples/torchtrt_executorch_example/model_static_shape.py b/examples/torchtrt_executorch_example/model_static_shape.py new file mode 100644 index 0000000000..f94a7d3c0a --- /dev/null +++ b/examples/torchtrt_executorch_example/model_static_shape.py @@ -0,0 +1,61 @@ +""" +.. _executorch_export: + +Saving a Torch-TensorRT Model in ExecuTorch Format (.pte) +========================================================= + +This example demonstrates how to compile a model with Torch-TensorRT and save it +as an ExecuTorch ``.pte`` file, which can be loaded by the ExecuTorch runtime +(e.g., on embedded or mobile devices with a TensorRT-capable backend). + +Prerequisites +------------- +Install ExecuTorch before running this example:: + + pip install executorch + +See https://pytorch.org/executorch/stable/getting-started-setup.html for details. 
+""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +import torch_tensorrt + + +class MyModel(torch.nn.Module): + def forward(self, x): + return x + 1 + + +# %% +# Compile with Torch-TensorRT +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Export the model, compile it with TensorRT, then save as .pte + +with torch.no_grad(): + model = MyModel().eval().cuda() + example_input = (torch.randn((2, 3, 4, 4)).cuda(),) + + exported_program = torch.export.export(model, example_input) + compile_settings = { + "arg_inputs": [ + torch_tensorrt.Input(shape=(2, 3, 4, 4), dtype=torch.float32), + ], + "min_block_size": 1, + } + trt_gm = torch_tensorrt.dynamo.compile(exported_program, **compile_settings) + + # %% + # Save as ExecuTorch .pte format (loadable by the ExecuTorch runtime) + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + # The TensorRT engine is serialized inside the .pte using the same blob format + # as the Torch-TensorRT runtime (vector of strings), so one engine format for + # both ExecuTorch and non-ExecuTorch deployment. + torch_tensorrt.save( + trt_gm, "model.pte", output_format="executorch", arg_inputs=example_input + ) + + print("Saved model.pte successfully.") diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index c4dbb1c148..3e25cce239 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -653,7 +653,7 @@ def save( inputs (Union[torch.Tensor, torch_tensorrt.Input]): Torch input tensors or Input specifications arg_inputs (Tuple[Union[torch.Tensor, torch_tensorrt.Input], ...]): Same as inputs. Alias for better understanding with kwarg_inputs. kwarg_inputs (dict[str, Union[torch.Tensor, torch_tensorrt.Input]]): Optional, kwarg inputs to the module forward function. - output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor. + output_format (str): Format to save the model. 
Options include exported_program | torchscript | aot_inductor | executorch. retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it. For TRT-compiled modules with dynamic shapes, both retrace=True and retrace=False are supported: @@ -726,7 +726,7 @@ def save( if isinstance(module, CudaGraphsTorchTensorRTModule): module = module.compiled_module module_type = _parse_module_type(module) - accepted_formats = {"exported_program", "torchscript", "aot_inductor"} + accepted_formats = {"exported_program", "torchscript", "aot_inductor", "executorch"} if arg_inputs is not None and not all( isinstance(input, (torch.Tensor, Input)) for input in arg_inputs ): @@ -847,12 +847,16 @@ def _extract_tensor(obj: Any) -> Any: if output_format not in accepted_formats: raise ValueError( - f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript" + f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript | aot_inductor | executorch" ) if output_format == "aot_inductor" and platform.system() != "Linux": raise ValueError( f"The AOT Inductor format is only supported on Linux, {platform.system()} is not a supported platform for this format" ) + if output_format == "executorch" and platform.system() != "Linux": + raise ValueError( + f"The executorch format is only supported on Linux, {platform.system()} is not a supported platform for this format" + ) if not file_path: raise ValueError("File path cannot be empty. Please provide a valid file path") @@ -906,6 +910,8 @@ def _extract_tensor(obj: Any) -> Any: inductor_configs=inductor_configs, package_path=file_path, ) + elif output_format == "executorch": + _save_as_executorch(module, file_path) else: raise RuntimeError( "Attempted to serialize an exported program with an unsupported format. 
Exported programs support exported_program and aot_inductor" @@ -963,6 +969,8 @@ def _extract_tensor(obj: Any) -> Any: inductor_configs=inductor_configs, package_path=file_path, ) + elif output_format == "executorch": + _save_as_executorch(exp_program, file_path) else: raise RuntimeError( "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor" @@ -1014,7 +1022,7 @@ def _extract_tensor(obj: Any) -> Any: "Provided model is a torch.fx.GraphModule without existing shape metadata and retrace is True, however no inputs specs were provided. " "Please provide valid torch.Tensors or torch_tensorrt.Input objects as inputs to retrace and save the model" ) - + breakpoint() exp_program = torch.export.export( module, args=tuple(arg_tensors), @@ -1042,12 +1050,46 @@ def _extract_tensor(obj: Any) -> Any: inductor_configs=inductor_configs, package_path=file_path, ) + elif output_format == "executorch": + _save_as_executorch(exp_program, file_path) else: raise RuntimeError( "Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor" ) +def _save_as_executorch(exp_program: Any, file_path: str) -> None: + """Save an ExportedProgram (with TensorRT execute_engine nodes) as an ExecuTorch .pte file. + + Partitions the graph by torch.ops.tensorrt.execute_engine, serializes each engine + to the same blob format as the TRT runtime (vector of strings), and embeds it + in the .pte. Requires the ``executorch`` package and torch_tensorrt_runtime. See + https://pytorch.org/executorch/stable/getting-started-setup.html + """ + if not ENABLED_FEATURES.torch_tensorrt_runtime: + raise RuntimeError( + "output_format='executorch' requires the Torch-TensorRT runtime " + "(torch_tensorrt_runtime). Reinstall torch_tensorrt with the runtime extension." 
+ ) + try: + from executorch.exir import to_edge_transform_and_lower + except ImportError: + raise ImportError( + "ExecuTorch is not installed. Please install it to use output_format='executorch'. " + "See https://pytorch.org/executorch/stable/getting-started-setup.html" + ) + from torch_tensorrt.executorch import TensorRTPartitioner + + breakpoint() + edge_program = to_edge_transform_and_lower( + exp_program, + partitioner=[TensorRTPartitioner()], + ) + executorch_program = edge_program.to_executorch() + with open(file_path, "wb") as f: + executorch_program.write_to_file(f) + + def function_overload_with_kwargs( fn: Callable[..., Any], *args: Any, **kwargs: Any ) -> Any: diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index d77c0bf39f..91994ca49b 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -338,6 +338,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: Returns: torch.Tensor or Tuple(torch.Tensor): Result of the engine computation """ + breakpoint() if self.engine is None: raise RuntimeError("Engine has not been setup yet.") @@ -354,7 +355,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: (i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) for i in inputs ] - outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine( list(input_tensors), self.engine ) diff --git a/py/torch_tensorrt/executorch/__init__.py b/py/torch_tensorrt/executorch/__init__.py new file mode 100644 index 0000000000..81aa088610 --- /dev/null +++ b/py/torch_tensorrt/executorch/__init__.py @@ -0,0 +1,9 @@ +# ExecuTorch backend for Torch-TensorRT: save/load .pte with TensorRT delegate. 
+ +from torch_tensorrt.executorch.backend import TensorRTBackend +from torch_tensorrt.executorch.partitioner import TensorRTPartitioner + +__all__ = [ + "TensorRTBackend", + "TensorRTPartitioner", +] diff --git a/py/torch_tensorrt/executorch/backend.py b/py/torch_tensorrt/executorch/backend.py new file mode 100644 index 0000000000..1c9ba4c615 --- /dev/null +++ b/py/torch_tensorrt/executorch/backend.py @@ -0,0 +1,85 @@ +# ExecuTorch TensorRT backend: serialize engine to same blob format as TRT runtime. + +import base64 +from typing import Any, List, final + +import torch +from executorch.exir.backend.backend_details import ( + BackendDetails, + CompileSpec, + PreprocessResult, +) +from torch.export.exported_program import ExportedProgram +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + ENGINE_IDX, + SERIALIZATION_LEN, +) +from torch_tensorrt.executorch.serialization import serialize_engine_info + + +def _get_engine_info_from_edge_program(edge_program: ExportedProgram) -> List[Any]: + """Extract engine info (list of strings/bytes) from the partition's execute_engine node.""" + gm = edge_program.graph_module + execute_engine_op = torch.ops.tensorrt.execute_engine.default + + for node in gm.graph.nodes: + if node.op != "call_function" or node.target is not execute_engine_op: + continue + if len(node.args) < 2: + continue + engine_arg = node.args[1] + if engine_arg.op == "get_attr": + val = getattr(gm, engine_arg.target, None) + if val is None: + raise RuntimeError( + f"Engine get_attr({engine_arg.target}) not found on partition module." + ) + if hasattr(val, "__getstate__"): + engine_info = val.__getstate__() + else: + engine_info = getattr(val, "engine_info", val) + if ( + isinstance(engine_info, (list, tuple)) + and len(engine_info) >= SERIALIZATION_LEN + ): + return list(engine_info) + raise RuntimeError( + f"Engine argument get_attr({engine_arg.target}) did not yield engine info list (len >= {SERIALIZATION_LEN})." 
+ ) + raise RuntimeError( + "TensorRT ExecuTorch backend expects execute_engine(inputs, engine) " + "where engine is a get_attr; cannot find engine." + ) + raise RuntimeError( + "TensorRT ExecuTorch backend: no execute_engine node found in partition." + ) + + +@final +class TensorRTBackend(BackendDetails): # type: ignore[misc] + """Backend that serializes TensorRT engine to the same blob format as the TRT runtime. + + The partition contains a single execute_engine node; we extract the engine + and metadata and encode them as a vector of strings (same layout as + core/runtime/runtime.h SerializedInfoIndex) so the same blob works for + both ExecuTorch and non-ExecuTorch TRT runtime. + """ + + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + engine_info = _get_engine_info_from_edge_program(edge_program) + engine_info = list(engine_info) + serialized_engine = engine_info[ENGINE_IDX] + if isinstance(serialized_engine, str): + engine_info[ENGINE_IDX] = base64.b64decode( + serialized_engine.encode("utf-8") + ) + elif not isinstance(serialized_engine, (bytes, bytearray)): + engine_info[ENGINE_IDX] = bytes(serialized_engine) + if len(engine_info) > 7 and isinstance(engine_info[7], bytes): + engine_info[7] = engine_info[7].decode("utf-8", errors="replace") + blob = serialize_engine_info(engine_info) + return PreprocessResult(processed_bytes=blob) diff --git a/py/torch_tensorrt/executorch/operator_support.py b/py/torch_tensorrt/executorch/operator_support.py new file mode 100644 index 0000000000..32763665c2 --- /dev/null +++ b/py/torch_tensorrt/executorch/operator_support.py @@ -0,0 +1,26 @@ +# Operator support for ExecuTorch TensorRT partitioner: only execute_engine is supported. 
+ +from typing import Dict + +import torch +from torch.fx.passes.operator_support import OperatorSupportBase + + +class TensorRTOperatorSupport(OperatorSupportBase): # type: ignore[misc] + """Supports only torch.ops.tensorrt.execute_engine for partitioning. + + Used so that TRT-compiled graphs (which already contain execute_engine nodes) + are partitioned per engine; each partition is then lowered to TensorRTBackend + which serializes the engine to the same blob format as the TRT runtime. + """ + + def __init__(self) -> None: + super().__init__() + self._execute_engine_op = torch.ops.tensorrt.execute_engine.default + + def is_node_supported( + self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + if node.op != "call_function": + return False + return node.target is self._execute_engine_op diff --git a/py/torch_tensorrt/executorch/partitioner.py b/py/torch_tensorrt/executorch/partitioner.py new file mode 100644 index 0000000000..9fcab9f709 --- /dev/null +++ b/py/torch_tensorrt/executorch/partitioner.py @@ -0,0 +1,63 @@ +# ExecuTorch partitioner: partition by execute_engine nodes. + +from typing import Callable, Dict, List, Optional, Tuple + +import torch +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.backend.utils import tag_constant_data +from torch.export import ExportedProgram +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch_tensorrt.executorch.backend import TensorRTBackend +from torch_tensorrt.executorch.operator_support import TensorRTOperatorSupport + + +class TensorRTPartitioner(Partitioner): # type: ignore[misc] + """Partitions the graph for TensorRT delegation. 
+ + Only nodes that are torch.ops.tensorrt.execute_engine are supported; + each such node becomes its own partition so the backend can serialize + the engine to the same format as the TRT runtime. + """ + + def __init__( + self, + compile_specs: Optional[List[CompileSpec]] = None, + ) -> None: + super().__init__() + self.compile_specs = compile_specs or [] + self.delegation_spec = DelegationSpec( + backend_id=TensorRTBackend.__name__, + compile_specs=self.compile_specs, + ) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + capability_partitioner = CapabilityBasedPartitioner( + exported_program.graph_module, + TensorRTOperatorSupport(), + allows_single_node_partition=True, + ) + partition_list = capability_partitioner.propose_partitions() + + partition_tags: Dict[str, DelegationSpec] = {} + for partition in partition_list: + tag = f"tensorrt_{partition.id}" + for node in partition.nodes: + node.meta["delegation_tag"] = tag + partition_tags[tag] = self.delegation_spec + + tag_constant_data(exported_program) + + return PartitionResult( + tagged_exported_program=exported_program, + partition_tags=partition_tags, + ) + + def ops_to_not_decompose( + self, ep: ExportedProgram + ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + return ([], None) diff --git a/py/torch_tensorrt/executorch/serialization.py b/py/torch_tensorrt/executorch/serialization.py new file mode 100644 index 0000000000..742269973d --- /dev/null +++ b/py/torch_tensorrt/executorch/serialization.py @@ -0,0 +1,32 @@ +# Serialization for ExecuTorch TensorRT blob: same format as TRT runtime (vector of strings). +# Uses the same list format as TorchTensorRTModule._pack_engine_info, then encodes to bytes. +# Only valid when ENABLED_FEATURES.torch_tensorrt_runtime is True. 
+ +import struct +from typing import List, Union + +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import SERIALIZATION_LEN + + +def serialize_engine_info(engine_info: List[Union[str, bytes]]) -> bytes: + """Encode engine info list (same format as TorchTensorRTModule._pack_engine_info) to bytes. + + Takes the list produced by _pack_engine_info (or equivalent) and writes it in the + TRT runtime vector format: 4-byte count (SERIALIZATION_LEN), then for each + entry 4-byte length (LE) + raw bytes. C++ can deserialize to a std::vector of + std::string and pass that vector to TRTEngine(serialized_info). + """ + if len(engine_info) < SERIALIZATION_LEN: + engine_info = list(engine_info) + [""] * (SERIALIZATION_LEN - len(engine_info)) + parts: List[bytes] = [] + for i in range(SERIALIZATION_LEN): + raw = engine_info[i] + if isinstance(raw, str): + raw = raw.encode("utf-8") + elif raw is None: + raw = b"" + else: + raw = bytes(raw) + parts.append(struct.pack(" Date: Wed, 18 Mar 2026 11:54:13 -0700 Subject: [PATCH 2/2] test1 --- .../model_static_shape.py | 8 ++++- py/torch_tensorrt/_compile.py | 5 +-- .../dynamo/runtime/_TorchTensorRTModule.py | 2 +- .../runtime/meta_ops/register_meta_ops.py | 34 ++++++++++++++++--- 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/examples/torchtrt_executorch_example/model_static_shape.py b/examples/torchtrt_executorch_example/model_static_shape.py index f94a7d3c0a..f6f9e9006d 100644 --- a/examples/torchtrt_executorch_example/model_static_shape.py +++ b/examples/torchtrt_executorch_example/model_static_shape.py @@ -54,8 +54,14 @@ def forward(self, x): # The TensorRT engine is serialized inside the .pte using the same blob format # as the Torch-TensorRT runtime (vector of strings), so one engine format for # both ExecuTorch and non-ExecuTorch deployment. + # Use retrace=False so the legacy exporter is used; the engine is then available + # when ExecuTorch's partitioner runs the graph.
torch_tensorrt.save( - trt_gm, "model.pte", output_format="executorch", arg_inputs=example_input + trt_gm, + "model.pte", + output_format="executorch", + arg_inputs=example_input, + retrace=False, ) print("Saved model.pte successfully.") diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 3e25cce239..6491d91c1f 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -1022,7 +1022,6 @@ def _extract_tensor(obj: Any) -> Any: "Provided model is a torch.fx.GraphModule without existing shape metadata and retrace is True, however no inputs specs were provided. " "Please provide valid torch.Tensors or torch_tensorrt.Input objects as inputs to retrace and save the model" ) - breakpoint() exp_program = torch.export.export( module, args=tuple(arg_tensors), @@ -1078,9 +1077,11 @@ def _save_as_executorch(exp_program: Any, file_path: str) -> None: "ExecuTorch is not installed. Please install it to use output_format='executorch'. " "See https://pytorch.org/executorch/stable/getting-started-setup.html" ) + # Ensure execute_engine fake kernel is registered so partitioner can run + # when the engine is a CustomObjArgument (export placeholder). 
+ import torch_tensorrt.dynamo.runtime.meta_ops.register_meta_ops # noqa: F401 from torch_tensorrt.executorch import TensorRTPartitioner - breakpoint() edge_program = to_edge_transform_and_lower( exp_program, partitioner=[TensorRTPartitioner()], diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 91994ca49b..2de6b10810 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -338,7 +338,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: Returns: torch.Tensor or Tuple(torch.Tensor): Result of the engine computation """ - breakpoint() + if self.engine is None: raise RuntimeError("Engine has not been setup yet.") diff --git a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py index e03c88153c..83ac0644b0 100644 --- a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py +++ b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py @@ -189,6 +189,22 @@ def fake_aten_cudnn_grid_sampler( return torch.empty(out_shape, dtype=input.dtype, device=input.device) +def _is_placeholder_engine(engine: Any) -> bool: + """True if engine is a placeholder (CustomObjArgument/FakeScriptObject) from export.""" + if engine is None: + return True + type_name = type(engine).__name__ + if type_name == "CustomObjArgument": + return True + if type_name == "FakeScriptObject": + return True + if hasattr(engine, "fake_val") and engine.fake_val is not None: + return True + if not hasattr(engine, "get_serialized_metadata"): + return True + return False + + @torch.library.register_fake("tensorrt::execute_engine") # type: ignore def fake_tensorrt_execute_engine( inputs: List[torch.Tensor], fake_trt_engine: Any @@ -196,13 +212,23 @@ def fake_tensorrt_execute_engine( """ Meta kernel for TensorRT engine execution. 
- Uses symbolic shape expressions captured at compile time to correctly infer - output shapes while preserving symbolic SymInt relationships. + When the engine is a placeholder (CustomObjArgument/FakeScriptObject from + torch.export/ExecuTorch), returns one fake output like inputs[0] (same shape/dtype) + so partitioners can run without a real engine. Otherwise uses symbolic shape + expressions from metadata to infer output shapes. """ + if _is_placeholder_engine(fake_trt_engine): + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(inputs) if inputs else None + if not inputs: + return [torch.empty(())] + if fake_mode is not None: + return [fake_mode.from_tensor(inputs[0])] + return [torch.empty_like(inputs[0])] metadata = None if hasattr(fake_trt_engine, "real_obj"): - # Wrapped C++ engine with real_obj trt_engine = fake_trt_engine.real_obj metadata = TorchTensorRTModule.decode_metadata( trt_engine.get_serialized_metadata() @@ -215,8 +241,6 @@ def fake_tensorrt_execute_engine( shape_info = metadata.get("inout_symexprs") if metadata else None if shape_info: - # Apply the symbolic shape expressions to create output fake tensors - # shape_info now contains both 'inputs' and 'outputs' keys return _apply_symbolic_shape_expressions(inputs, shape_info) else: raise RuntimeError(