Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions examples/torchtrt_executorch_example/model_static_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
.. _executorch_export:

Saving a Torch-TensorRT Model in ExecuTorch Format (.pte)
=========================================================

This example demonstrates how to compile a model with Torch-TensorRT and save it
as an ExecuTorch ``.pte`` file, which can be loaded by the ExecuTorch runtime
(e.g., on embedded or mobile devices with a TensorRT-capable backend).

Prerequisites
-------------
Install ExecuTorch before running this example::

pip install executorch

See https://pytorch.org/executorch/stable/getting-started-setup.html for details.
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
import torch_tensorrt


class MyModel(torch.nn.Module):
    """Minimal demo module: an elementwise add of 1.

    Small on purpose — it compiles to a single TensorRT engine, which is all
    this example needs to demonstrate saving in ExecuTorch ``.pte`` format.
    """

    def forward(self, x):
        """Return ``x`` with 1 added to every element."""
        return torch.add(x, 1)


# %%
# Compile with Torch-TensorRT
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Export the model, compile it with TensorRT, then save as .pte

with torch.no_grad():
    model = MyModel().eval().cuda()
    # Example input fixes the static compile-time shape: (2, 3, 4, 4).
    example_input = (torch.randn((2, 3, 4, 4)).cuda(),)

    # Export to an ExportedProgram, then compile the graph with TensorRT.
    exported_program = torch.export.export(model, example_input)
    compile_settings = {
        # Static input spec matching example_input's shape and dtype.
        "arg_inputs": [
            torch_tensorrt.Input(shape=(2, 3, 4, 4), dtype=torch.float32),
        ],
        # min_block_size=1 so even this tiny graph becomes a TRT engine.
        "min_block_size": 1,
    }
    trt_gm = torch_tensorrt.dynamo.compile(exported_program, **compile_settings)

# %%
# Save as ExecuTorch .pte format (loadable by the ExecuTorch runtime)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# The TensorRT engine is serialized inside the .pte using the same blob format
# as the Torch-TensorRT runtime (vector of strings), so one engine format for
# both ExecuTorch and non-ExecuTorch deployment.
# Use retrace=False so the legacy exporter is used; the engine is then available
# when ExecuTorch's partitioner runs the graph.
# retrace=False keeps the already-compiled TRT graph module as-is, so the
# embedded engine is available when ExecuTorch's partitioner runs the graph.
torch_tensorrt.save(
    trt_gm,
    "model.pte",
    output_format="executorch",
    arg_inputs=example_input,
    retrace=False,
)

# The .pte file is written to the current working directory.
print("Saved model.pte successfully.")
51 changes: 47 additions & 4 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def save(
inputs (Union[torch.Tensor, torch_tensorrt.Input]): Torch input tensors or Input specifications
arg_inputs (Tuple[Union[torch.Tensor, torch_tensorrt.Input], ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[str, Union[torch.Tensor, torch_tensorrt.Input]]): Optional, kwarg inputs to the module forward function.
output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor.
output_format (str): Format to save the model. Options include exported_program | torchscript | aot_inductor | executorch.
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.

For TRT-compiled modules with dynamic shapes, both retrace=True and retrace=False are supported:
Expand Down Expand Up @@ -726,7 +726,7 @@ def save(
if isinstance(module, CudaGraphsTorchTensorRTModule):
module = module.compiled_module
module_type = _parse_module_type(module)
accepted_formats = {"exported_program", "torchscript", "aot_inductor"}
accepted_formats = {"exported_program", "torchscript", "aot_inductor", "executorch"}
if arg_inputs is not None and not all(
isinstance(input, (torch.Tensor, Input)) for input in arg_inputs
):
Expand Down Expand Up @@ -847,12 +847,16 @@ def _extract_tensor(obj: Any) -> Any:

if output_format not in accepted_formats:
raise ValueError(
f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript | aot_inductor | executorch"
)
if output_format == "aot_inductor" and platform.system() != "Linux":
raise ValueError(
f"The AOT Inductor format is only supported on Linux, {platform.system()} is not a supported platform for this format"
)
if output_format == "executorch" and platform.system() != "Linux":
raise ValueError(
f"The executorch format is only supported on Linux, {platform.system()} is not a supported platform for this format"
)
if not file_path:
raise ValueError("File path cannot be empty. Please provide a valid file path")

Expand Down Expand Up @@ -906,6 +910,8 @@ def _extract_tensor(obj: Any) -> Any:
inductor_configs=inductor_configs,
package_path=file_path,
)
elif output_format == "executorch":
_save_as_executorch(module, file_path)
else:
raise RuntimeError(
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
Expand Down Expand Up @@ -963,6 +969,8 @@ def _extract_tensor(obj: Any) -> Any:
inductor_configs=inductor_configs,
package_path=file_path,
)
elif output_format == "executorch":
_save_as_executorch(exp_program, file_path)
else:
raise RuntimeError(
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
Expand Down Expand Up @@ -1014,7 +1022,6 @@ def _extract_tensor(obj: Any) -> Any:
"Provided model is a torch.fx.GraphModule without existing shape metadata and retrace is True, however no inputs specs were provided. "
"Please provide valid torch.Tensors or torch_tensorrt.Input objects as inputs to retrace and save the model"
)

exp_program = torch.export.export(
module,
args=tuple(arg_tensors),
Expand Down Expand Up @@ -1042,12 +1049,48 @@ def _extract_tensor(obj: Any) -> Any:
inductor_configs=inductor_configs,
package_path=file_path,
)
elif output_format == "executorch":
_save_as_executorch(exp_program, file_path)
else:
raise RuntimeError(
"Attempted to serialize an exported program with an unsupported format. Exported programs support exported_program and aot_inductor"
)


def _save_as_executorch(exp_program: Any, file_path: str) -> None:
    """Save an ExportedProgram (with TensorRT execute_engine nodes) as an ExecuTorch .pte file.

    Partitions the graph by torch.ops.tensorrt.execute_engine, serializes each engine
    to the same blob format as the TRT runtime (vector of strings), and embeds it
    in the .pte. Requires the ``executorch`` package and torch_tensorrt_runtime. See
    https://pytorch.org/executorch/stable/getting-started-setup.html

    Args:
        exp_program: ExportedProgram containing tensorrt.execute_engine nodes.
        file_path: Destination path for the serialized .pte file.

    Raises:
        RuntimeError: If the Torch-TensorRT runtime extension is unavailable.
        ImportError: If the ``executorch`` package is not installed.
    """
    if not ENABLED_FEATURES.torch_tensorrt_runtime:
        raise RuntimeError(
            "output_format='executorch' requires the Torch-TensorRT runtime "
            "(torch_tensorrt_runtime). Reinstall torch_tensorrt with the runtime extension."
        )
    try:
        from executorch.exir import to_edge_transform_and_lower
    except ImportError as e:
        # Chain the original exception so the root cause (missing module vs.
        # broken install) is preserved in the traceback (PEP 3134, flake8 B904).
        raise ImportError(
            "ExecuTorch is not installed. Please install it to use output_format='executorch'. "
            "See https://pytorch.org/executorch/stable/getting-started-setup.html"
        ) from e
    # Ensure execute_engine fake kernel is registered so partitioner can run
    # when the engine is a CustomObjArgument (export placeholder).
    import torch_tensorrt.dynamo.runtime.meta_ops.register_meta_ops  # noqa: F401
    from torch_tensorrt.executorch import TensorRTPartitioner

    # Lower each execute_engine partition to the TensorRT ExecuTorch backend,
    # then emit the final .pte program.
    edge_program = to_edge_transform_and_lower(
        exp_program,
        partitioner=[TensorRTPartitioner()],
    )
    executorch_program = edge_program.to_executorch()
    with open(file_path, "wb") as f:
        executorch_program.write_to_file(f)


def function_overload_with_kwargs(
fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
Returns:
torch.Tensor or Tuple(torch.Tensor): Result of the engine computation
"""

if self.engine is None:
raise RuntimeError("Engine has not been setup yet.")

Expand All @@ -354,7 +355,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
(i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
for i in inputs
]

outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
list(input_tensors), self.engine
)
Expand Down
34 changes: 29 additions & 5 deletions py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,20 +189,46 @@ def fake_aten_cudnn_grid_sampler(
return torch.empty(out_shape, dtype=input.dtype, device=input.device)


def _is_placeholder_engine(engine: Any) -> bool:
"""True if engine is a placeholder (CustomObjArgument/FakeScriptObject) from export."""
if engine is None:
return True
type_name = type(engine).__name__
if type_name == "CustomObjArgument":
return True
if type_name == "FakeScriptObject":
return True
if hasattr(engine, "fake_val") and engine.fake_val is not None:
return True
if not hasattr(engine, "get_serialized_metadata"):
return True
return False


@torch.library.register_fake("tensorrt::execute_engine") # type: ignore
def fake_tensorrt_execute_engine(
inputs: List[torch.Tensor], fake_trt_engine: Any
) -> Any:
"""
Meta kernel for TensorRT engine execution.

Uses symbolic shape expressions captured at compile time to correctly infer
output shapes while preserving symbolic SymInt relationships.
When the engine is a placeholder (CustomObjArgument/FakeScriptObject from
torch.export/ExecuTorch), returns one fake output per input (same shape/dtype)
so partitioners can run without a real engine. Otherwise uses symbolic shape
expressions from metadata to infer output shapes.
"""
if _is_placeholder_engine(fake_trt_engine):
from torch._guards import detect_fake_mode

fake_mode = detect_fake_mode(inputs) if inputs else None
if not inputs:
return [torch.empty(())]
if fake_mode is not None:
return [fake_mode.from_tensor(inputs[0])]
return [torch.empty_like(inputs[0])]

metadata = None
if hasattr(fake_trt_engine, "real_obj"):
# Wrapped C++ engine with real_obj
trt_engine = fake_trt_engine.real_obj
metadata = TorchTensorRTModule.decode_metadata(
trt_engine.get_serialized_metadata()
Expand All @@ -215,8 +241,6 @@ def fake_tensorrt_execute_engine(
shape_info = metadata.get("inout_symexprs") if metadata else None

if shape_info:
# Apply the symbolic shape expressions to create output fake tensors
# shape_info now contains both 'inputs' and 'outputs' keys
return _apply_symbolic_shape_expressions(inputs, shape_info)
else:
raise RuntimeError(
Expand Down
9 changes: 9 additions & 0 deletions py/torch_tensorrt/executorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# ExecuTorch backend for Torch-TensorRT: save/load .pte with TensorRT delegate.

from torch_tensorrt.executorch.backend import TensorRTBackend
from torch_tensorrt.executorch.partitioner import TensorRTPartitioner

__all__ = [
"TensorRTBackend",
"TensorRTPartitioner",
]
85 changes: 85 additions & 0 deletions py/torch_tensorrt/executorch/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# ExecuTorch TensorRT backend: serialize engine to same blob format as TRT runtime.

import base64
from typing import Any, List, final

import torch
from executorch.exir.backend.backend_details import (
BackendDetails,
CompileSpec,
PreprocessResult,
)
from torch.export.exported_program import ExportedProgram
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (
ENGINE_IDX,
SERIALIZATION_LEN,
)
from torch_tensorrt.executorch.serialization import serialize_engine_info


def _get_engine_info_from_edge_program(edge_program: ExportedProgram) -> List[Any]:
    """Extract engine info (list of strings/bytes) from the partition's execute_engine node.

    Scans the partition graph for a ``tensorrt.execute_engine`` call whose engine
    argument is a ``get_attr`` node, and returns that attribute's serialized
    state as a list (the runtime's vector-of-strings layout).

    Raises:
        RuntimeError: If no execute_engine node is found, the engine argument
            is not a get_attr, the attribute is missing from the module, or
            the extracted state is not a list/tuple of at least
            SERIALIZATION_LEN entries.
    """
    gm = edge_program.graph_module
    execute_engine_op = torch.ops.tensorrt.execute_engine.default

    for node in gm.graph.nodes:
        # Only direct calls to the TRT engine-execution op are of interest.
        if node.op != "call_function" or node.target is not execute_engine_op:
            continue
        # execute_engine takes (inputs, engine); skip malformed nodes.
        if len(node.args) < 2:
            continue
        engine_arg = node.args[1]
        if engine_arg.op == "get_attr":
            val = getattr(gm, engine_arg.target, None)
            if val is None:
                raise RuntimeError(
                    f"Engine get_attr({engine_arg.target}) not found on partition module."
                )
            # Prefer __getstate__, which yields the engine's serialized state.
            # NOTE(review): on Python 3.11+ every object defines __getstate__,
            # so the engine_info-attribute fallback below is effectively dead
            # code there — confirm which engine types this must handle.
            if hasattr(val, "__getstate__"):
                engine_info = val.__getstate__()
            else:
                engine_info = getattr(val, "engine_info", val)
            if (
                isinstance(engine_info, (list, tuple))
                and len(engine_info) >= SERIALIZATION_LEN
            ):
                return list(engine_info)
            raise RuntimeError(
                f"Engine argument get_attr({engine_arg.target}) did not yield engine info list (len >= {SERIALIZATION_LEN})."
            )
        raise RuntimeError(
            "TensorRT ExecuTorch backend expects execute_engine(inputs, engine) "
            "where engine is a get_attr; cannot find engine."
        )
    raise RuntimeError(
        "TensorRT ExecuTorch backend: no execute_engine node found in partition."
    )


@final
class TensorRTBackend(BackendDetails):  # type: ignore[misc]
    """Backend that serializes TensorRT engine to the same blob format as the TRT runtime.

    The partition contains a single execute_engine node; we extract the engine
    and metadata and encode them as a vector of strings (same layout as
    core/runtime/runtime.h SerializedInfoIndex) so the same blob works for
    both ExecuTorch and non-ExecuTorch TRT runtime.
    """

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Lower one TRT partition to a serialized engine blob.

        Args:
            edge_program: Partitioned program containing one execute_engine node.
            compile_specs: ExecuTorch compile specs (unused here).

        Returns:
            PreprocessResult whose ``processed_bytes`` is the serialized blob.
        """
        engine_info = _get_engine_info_from_edge_program(edge_program)
        # Copy before mutating the entries below.
        engine_info = list(engine_info)
        # The engine payload must be raw bytes; a str here is assumed to be
        # base64 (matching how the engine is stringified for export).
        serialized_engine = engine_info[ENGINE_IDX]
        if isinstance(serialized_engine, str):
            engine_info[ENGINE_IDX] = base64.b64decode(
                serialized_engine.encode("utf-8")
            )
        elif not isinstance(serialized_engine, (bytes, bytearray)):
            engine_info[ENGINE_IDX] = bytes(serialized_engine)
        # NOTE(review): index 7 is presumably a text slot of SerializedInfoIndex
        # that must be str for serialization — confirm against
        # core/runtime/runtime.h and replace the magic number with the named
        # constant.
        if len(engine_info) > 7 and isinstance(engine_info[7], bytes):
            engine_info[7] = engine_info[7].decode("utf-8", errors="replace")
        blob = serialize_engine_info(engine_info)
        return PreprocessResult(processed_bytes=blob)
26 changes: 26 additions & 0 deletions py/torch_tensorrt/executorch/operator_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Operator support for ExecuTorch TensorRT partitioner: only execute_engine is supported.

from typing import Dict

import torch
from torch.fx.passes.operator_support import OperatorSupportBase


class TensorRTOperatorSupport(OperatorSupportBase):  # type: ignore[misc]
    """Partitioning support that admits only ``torch.ops.tensorrt.execute_engine``.

    TRT-compiled graphs already contain execute_engine nodes; restricting
    support to that single op makes the partitioner carve out one partition
    per engine, each of which is then lowered to TensorRTBackend and
    serialized to the same blob format as the TRT runtime.
    """

    def __init__(self) -> None:
        super().__init__()
        # Resolve the op overload once; node targets compare by identity.
        self._execute_engine_op = torch.ops.tensorrt.execute_engine.default

    def is_node_supported(
        self, submodules: Dict[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        """Return True only for call_function nodes targeting execute_engine."""
        return (
            node.op == "call_function"
            and node.target is self._execute_engine_op
        )
Loading
Loading