diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml index a437c284c0..8bff2f8805 100644 --- a/.github/workflows/build-test-linux-x86_64.yml +++ b/.github/workflows/build-test-linux-x86_64.yml @@ -107,7 +107,7 @@ jobs: set -euo pipefail pushd . cd tests/py/dynamo - python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --dist=loadscope --maxfail=20 conversion/ + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --maxfail=20 conversion/ popd L0-dynamo-core-tests: @@ -236,6 +236,7 @@ jobs: cd tests/py/dynamo python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_* python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_* + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/ popd diff --git a/.github/workflows/build-test-linux-x86_64_rtx.yml b/.github/workflows/build-test-linux-x86_64_rtx.yml index 5315cdd762..de407803e8 100644 --- a/.github/workflows/build-test-linux-x86_64_rtx.yml +++ b/.github/workflows/build-test-linux-x86_64_rtx.yml @@ -107,7 +107,7 @@ jobs: set -euo pipefail pushd . cd tests/py/dynamo - python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --dist=loadscope --maxfail=20 conversion/ + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_converter_tests_results.xml --maxfail=20 conversion/ popd L0-dynamo-core-tests: @@ -204,6 +204,7 @@ jobs: pushd . 
cd tests/py/dynamo python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_* + python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/ popd L1-dynamo-compile-tests: diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index 4106b65046..f25d2a2a3b 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -226,6 +226,7 @@ jobs: cd tests/py/dynamo ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_* ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_* + ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/ popd L1-dynamo-compile-tests: diff --git a/.github/workflows/build-test-windows_rtx.yml b/.github/workflows/build-test-windows_rtx.yml index 6fdbc1eab3..d25ed8b770 100644 --- a/.github/workflows/build-test-windows_rtx.yml +++ b/.github/workflows/build-test-windows_rtx.yml @@ -200,6 +200,7 @@ jobs: pushd . 
cd tests/py/dynamo ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_* + ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_hlo_tests_results.xml hlo/ popd L1-dynamo-compile-tests: diff --git a/docsrc/contributors/complex_number_support.rst b/docsrc/contributors/complex_number_support.rst index c224c76d6d..f4fbe96f70 100644 --- a/docsrc/contributors/complex_number_support.rst +++ b/docsrc/contributors/complex_number_support.rst @@ -135,10 +135,15 @@ runtime modules handle the conversion: Key Implementation Invariants ------------------------------- -* **``originally_complex`` set** — the set of nodes that were complex-dtype - *before* any rewrites. After ``replace_input_node``, complex placeholders become - ``float32`` so ``is_complex_dtype()`` returns ``False``. The ``originally_complex`` - set is used to decide which ``mul.Tensor`` nodes need the complex mul rewrite. +* **``node.meta["is_complex_layout"]``** — every node that represents a complex + quantity (either originally complex-dtype, or a real ``(..., 2)`` tensor produced + by the rewriter) is annotated with ``node.meta["is_complex_layout"] = True``. + This annotation is set during the detection phase (before any rewrites begin) and + propagated by every rewrite handler as it emits new nodes. It survives dtype + changes: after ``replace_input_node`` converts a ``placeholder`` from complex to + ``float32``, the dtype-based check ``is_complex_dtype()`` would return ``False``, + but the metadata flag remains. ``_is_complex_layout_node(n)`` is simply + ``n.meta.get("is_complex_layout", False)`` — no shape heuristics or recursion. 
* **FakeTensorMode reuse** — ``propagate_metadata`` must use the ``FakeTensorMode`` from existing placeholder fake tensors (not a fresh mode) to avoid mode-mismatch errors under ``torch.compile`` and to preserve SymInt for dynamic shapes. @@ -146,8 +151,228 @@ Key Implementation Invariants Nested submodule parameter names (e.g. ``layers.0.weight``) must have ``.`` replaced with ``__`` before registration. +The Decomposition System — How It Is Built +------------------------------------------- + +The rewriter is split across two classes and wired together by a lightweight +dispatch mechanism. This section walks through each piece and explains the +design decisions. + +ComplexOpDetector — Subgraph Discovery +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``ComplexOpDetector`` walks the graph to find the set of nodes that participate +in complex arithmetic. + +``node_include_in_subgraph`` +"""""""""""""""""""""""""""" + +A node is included in a complex subgraph if: + +1. Its output dtype is ``complex64`` or ``complex128`` (``is_complex_dtype``), **or** +2. Any of its inputs are complex (``has_complex_input``). + +The second condition is necessary to catch real-output ops — ``abs``, ``angle``, +``real``, ``imag`` — whose inputs are complex. These must be rewritten alongside +the rest of the subgraph even though their outputs are real. + +``subgraph_from_anchor`` +"""""""""""""""""""""""" + +For ``view_as_real``-bounded subgraphs, detection starts at a ``view_as_real`` +*anchor* node and performs a backward BFS: + +.. code-block:: text + + view_as_real ← mul (complex) ← reshape ← placeholder (complex) + ↑ anchor ↑ subgraph ↑ subgraph ↑ input + +At each step, if an upstream node satisfies ``node_include_in_subgraph`` it is +added to the subgraph; otherwise it becomes an *input node* (the boundary). The +result is a ``ComplexSubGraphInfo`` containing anchor nodes, subgraph nodes, and +input nodes. 
+ +After collection the subgraph is **sorted in topological order** (by position in +the graph's node list). This is critical: without it a ``mul`` node could be +processed before its ``sin`` or ``cos`` operands, causing the rewriter to see the +original complex node instead of the already-rewritten real node. + +``find_complex_op_subgraphs`` and subgraph merging +""""""""""""""""""""""""""""""""""""""""""""""""""" + +When a model has multiple ``view_as_real`` anchors that share upstream nodes +(e.g. ``xq_out`` and ``xk_out`` in a RoPE layer both descend from the same +``freqs_cis`` placeholder), their subgraphs would otherwise be detected +separately. ``find_complex_op_subgraphs`` merges overlapping subgraphs by +set intersection so each node is rewritten exactly once. + +``find_all_complex_subgraphs`` — unbounded complex ops +""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +Some models produce a complex tensor as a graph *output* without passing it +through ``view_as_real``. ``find_all_complex_subgraphs`` is a forward scan that +collects every ``call_function`` node with a complex output, regardless of +anchoring. The resulting subgraph is processed the same way as an +anchor-bounded one. + +ComplexGraphRewriter — Dispatch-Based Rewriting +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``ComplexGraphRewriter`` is decorated with ``@_register_unpackers``, which at +class-definition time scans every method for the ``@_complex_unpacker(op, ...)`` +decorator and builds a ``cls._DISPATCH`` dictionary mapping aten ops to rewrite +methods. + +.. code-block:: python + + @_complex_unpacker(torch.ops.aten.mul.Tensor) + def _rewrite_mul(self, node: Node, b: SubgraphBuilder, ...): + ... + +The entry point ``rewrite_subgraph_nodes`` iterates over the (topologically +ordered) subgraph nodes and for each node: + +1. Looks up ``node.target`` in ``_DISPATCH``. +2. If found, calls the corresponding rewrite method. +3. 
If not found but the op is in ``_ELEMENTWISE_SAFE``, skips it (the op applies + independently to every scalar, so the ``(..., 2)`` real layout is already + correct). +4. Otherwise logs a warning and leaves the node unchanged. + +``_ELEMENTWISE_SAFE`` +""""""""""""""""""""" + +The ``_ELEMENTWISE_SAFE`` set contains ops that apply to every element of the +tensor independently — ``add.Tensor``, ``sub.Tensor``, ``neg``, ``mul.Scalar``, +``clone``, ``where``, etc. On the ``(..., 2)`` real layout these are already +correct: adding two complex tensors element-wise is the same as adding their +real and imaginary parts independently. + +Notably **excluded** from this set: + +* ``permute.default`` — must append the trailing real/imag dim index. +* ``add.Scalar`` / ``sub.Scalar`` — a scalar added to a complex number only + shifts the real part; on the ``(..., 2)`` layout both parts would be shifted. +* ``reshape`` / ``view`` — shape arguments need updating for the extra ``2`` dim. + +Complex Multiply Decomposition +""""""""""""""""""""""""""""""" + +The most important rewrite is ``mul.Tensor`` between two complex operands. +The rewriter calls ``complex_mul_replacement``: + +.. code-block:: python + + # inputs a, b have shape (..., 2) — last dim is [real, imag] + re_a = select(a, -1, 0); im_a = select(a, -1, 1) + re_b = select(b, -1, 0); im_b = select(b, -1, 1) + real_out = re_a * re_b - im_a * im_b # ac - bd + imag_out = re_a * im_b + im_a * re_b # ad + bc + result = stack([real_out, imag_out], dim=-1) + +Each step is inserted via a ``SubgraphBuilder`` anchored at the ``mul`` node, +so all six new nodes appear immediately after it in topological order. The +original ``mul`` node is then replaced and erased. + +See :ref:`subgraph_builder` for more on how ``SubgraphBuilder`` manages +cursor-based insertion. 
+ +The ``is_complex_layout`` Metadata Invariant +""""""""""""""""""""""""""""""""""""""""""""" + +Input replacement (Stage 2) converts complex ``placeholder`` nodes to +``float32``. After that, ``is_complex_dtype(node)`` returns ``False`` for those +nodes even though they logically represent complex quantities. + +To avoid missed rewrites, every node that represents a complex quantity is +annotated with ``node.meta["is_complex_layout"] = True`` during the detection +phase (lines in ``rewrite_subgraph_nodes`` before any rewrites begin). The +annotation is then propagated forward by every rewrite handler: + +* ``replace_input_node`` stamps it on the new placeholder and ``get_attr`` nodes. +* ``_inline_cat_re_im`` stamps it on every ``[re_u, im_u]`` concatenation node, + covering all math handlers (``exp``, ``log``, ``sin``, ``mul``, etc.) at once. +* Each shape-manipulation handler (``reshape``, ``permute``, ``unsqueeze``, + ``cat``, ``stack``, etc.) stamps it on its output node explicitly. + +``_is_complex_layout_node(n)`` is therefore a direct metadata lookup — no shape +heuristics (``val.shape[-1] == 2``), no recursive ``_SHAPE_TRANSPARENT_OPS`` +propagation. This also eliminates false-positives on real parameters that +coincidentally have a trailing dimension of size 2. + +FakeTensorMode Reuse for Dynamic Shapes +""""""""""""""""""""""""""""""""""""""""" + +When inserting a new ``placeholder`` for a complex input, the pass must populate +``meta["val"]`` with a ``FakeTensor`` of the new real shape. Using a fresh +``FakeTensorMode()`` would create a *new* ``ShapeEnv``, which is incompatible +with the one that ``torch.export`` used to encode dynamic shape constraints +(SymInt ranges). + +The fix is to extract the ``FakeTensorMode`` from the *original* placeholder's +``meta["val"].fake_mode`` and reuse it. The new fake tensor is then constructed +by appending a concrete ``2`` to the symbolic shape list: + +.. 
code-block:: python + + orig_fake = input_node.meta["val"] + sym_shape = list(orig_fake.shape) + [2] + with orig_fake.fake_mode: + fake_tensor = torch.empty(sym_shape, dtype=new_dtype, device=device) + +This preserves all SymInt identity across the graph and keeps +dynamic-shape exports working correctly. + +Entry Point: ``complex_graph_detection`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The public entry point called by the lowering pipeline is +``complex_graph_detection(gm, settings)``. It: + +1. Instantiates ``ComplexOpDetector`` and ``ComplexGraphRewriter``. +2. Calls ``find_complex_op_subgraphs`` anchored on ``view_as_real`` to find + bounded complex subgraphs. +3. Calls ``find_all_complex_subgraphs`` for any remaining complex nodes that + are not ``view_as_real``-bounded. +4. For each subgraph: + + a. Calls ``replace_input_node`` on every boundary input node (Stage 2). + b. Calls ``rewrite_subgraph_nodes`` on the ordered subgraph (Stage 3). + c. Calls ``clean_up_graph_after_modifications`` to remove dead nodes. + +5. Returns the modified ``GraphModule``. + +Adding New Op Rewrites +^^^^^^^^^^^^^^^^^^^^^^^ + +To teach the rewriter about a new complex op, add a method to +``ComplexGraphRewriter`` tagged with ``@_complex_unpacker``: + +.. code-block:: python + + @_complex_unpacker(torch.ops.aten.my_new_op.default) + def _rewrite_my_new_op(self, node: Node) -> bool: + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + out = b(my_real_impl, re, im) + # If the output is still a complex-layout [..., 2] tensor, annotate it. + # (Not needed if using _inline_cat_re_im, which sets the flag automatically.) 
+ out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + +``@_register_unpackers`` (applied to the class) picks up the new entry +automatically at import time — no other registration is required. + +If the new op is elementwise-safe on the ``(..., 2)`` layout (i.e. it acts +independently on every scalar), add it to ``_ELEMENTWISE_SAFE`` instead. + Related ------- * :ref:`lowering` — the complex rewrite is a lowering pass. +* :ref:`subgraph_builder` — the ``SubgraphBuilder`` helper used in every rewrite method. * :ref:`lowering_passes_catalog` — pass ordering and management. diff --git a/docsrc/tutorials/advanced_usage.rst b/docsrc/tutorials/advanced_usage.rst index 11090dc5a0..28480a4a44 100644 --- a/docsrc/tutorials/advanced_usage.rst +++ b/docsrc/tutorials/advanced_usage.rst @@ -2,7 +2,7 @@ Advanced Usage ============== Step-by-step tutorials covering engine caching, quantization, custom kernels, -dynamic shapes, weight streaming, debugging, and more. +dynamic shapes, weight streaming, debugging, complex numerics, and more. .. toctree:: :maxdepth: 2 @@ -14,5 +14,6 @@ dynamic shapes, weight streaming, debugging, and more. weight_refit/index runtime_opt/index deployment/index + complex_numerics/index Example: Distributed Inference <_rendered_examples/distributed_inference/index> ../indices/supported_ops diff --git a/docsrc/tutorials/deployment/complex_tensors.rst b/docsrc/tutorials/complex_numerics/complex_tensors.rst similarity index 96% rename from docsrc/tutorials/deployment/complex_tensors.rst rename to docsrc/tutorials/complex_numerics/complex_tensors.rst index 57716f181d..c7507685a9 100644 --- a/docsrc/tutorials/deployment/complex_tensors.rst +++ b/docsrc/tutorials/complex_numerics/complex_tensors.rst @@ -11,6 +11,12 @@ compilation. This page explains what the rewriter does, which patterns are supported, and what limitations to be aware of when compiling models with complex inputs. +.. 
seealso:: + + :doc:`../_rendered_examples/dynamo/torch_export_3d_rope` — a runnable + end-to-end example compiling a video-transformer 3D RoPE attention block + (CogVideoX / Wan / HunyuanVideo style) with dynamic T×H×W shapes. + ---- How the Rewriter Works diff --git a/docsrc/tutorials/complex_numerics/index.rst b/docsrc/tutorials/complex_numerics/index.rst new file mode 100644 index 0000000000..0494d84dad --- /dev/null +++ b/docsrc/tutorials/complex_numerics/index.rst @@ -0,0 +1,10 @@ +Complex Numerics +=================== + +Compatibility support for numerical datatypes, such as complex numbers, which are not natively supported by TensorRT +
+.. toctree:: + :maxdepth: 1 + + complex_tensors + Example: 3D RoPE with Complex Numerics <../_rendered_examples/dynamo/torch_export_3d_rope> diff --git a/docsrc/tutorials/deployment/index.rst b/docsrc/tutorials/deployment/index.rst index 7df88922e5..40383bfd65 100644 --- a/docsrc/tutorials/deployment/index.rst +++ b/docsrc/tutorials/deployment/index.rst @@ -12,4 +12,3 @@ complex-valued model support. cross_compile_windows Example: Cross-runtime Compilation for Windows <../_rendered_examples/dynamo/cross_runtime_compilation_for_windows> distributed_inference - complex_tensors diff --git a/docsrc/tutorials/extensibility/lowering/index.rst b/docsrc/tutorials/extensibility/lowering/index.rst index 487fe5b4ec..e44dcb66c4 100644 --- a/docsrc/tutorials/extensibility/lowering/index.rst +++ b/docsrc/tutorials/extensibility/lowering/index.rst @@ -8,3 +8,4 @@ rewrite ATen ops before TensorRT compilation. :maxdepth: 1 writing_dynamo_aten_lowering_passes + subgraph_builder diff --git a/docsrc/tutorials/extensibility/lowering/subgraph_builder.rst b/docsrc/tutorials/extensibility/lowering/subgraph_builder.rst new file mode 100644 index 0000000000..b7f6131e0d --- /dev/null +++ b/docsrc/tutorials/extensibility/lowering/subgraph_builder.rst @@ -0,0 +1,105 @@ +.. 
_subgraph_builder: + +SubgraphBuilder — Cursor-Based FX Node Insertion +================================================= + +Writing lowering passes that replace one node with several new nodes requires +careful management of insertion order: each new node must be inserted +*after the previous one* so that the topological ordering of the graph is +preserved. Doing this by hand with repeated ``graph.inserting_after(cursor)`` +context managers is verbose and error-prone. + +``SubgraphBuilder`` is a lightweight context-manager helper in +``torch_tensorrt.dynamo.lowering._SubgraphBuilder`` that automates this +cursor-tracking pattern. + +Basic Usage +----------- + +Construct a ``SubgraphBuilder`` with the target graph and the *anchor* node — +the node immediately before where you want to start inserting. Then use it +as a callable inside a ``with`` block to add nodes one at a time: + +.. code-block:: python + + from torch_tensorrt.dynamo.lowering._SubgraphBuilder import SubgraphBuilder + import torch.ops.aten as aten + + # Inside a lowering pass, given a node `mul_node` to replace: + with SubgraphBuilder(gm.graph, mul_node) as b: + # Each call inserts a node after the current cursor and advances it. + re_a = b(aten.select.int, a, -1, 0) # a[..., 0] (real part of a) + im_a = b(aten.select.int, a, -1, 1) # a[..., 1] (imag part of a) + re_b = b(aten.select.int, b_node, -1, 0) + im_b = b(aten.select.int, b_node, -1, 1) + real = b(aten.sub.Tensor, b(aten.mul.Tensor, re_a, re_b), + b(aten.mul.Tensor, im_a, im_b)) # ac - bd + imag = b(aten.add.Tensor, b(aten.mul.Tensor, re_a, im_b), + b(aten.mul.Tensor, im_a, re_b)) # ad + bc + result = b(aten.stack, [real, imag], -1) + + mul_node.replace_all_uses_with(result) + gm.graph.erase_node(mul_node) + +On ``__exit__``, the builder automatically calls ``graph.lint()`` to validate +the modified graph. 
If your code raises an exception inside the block, the +lint is skipped so you see the original error rather than a secondary graph +validation failure. + +How It Works +------------ + +The builder maintains a *cursor* — initially the anchor node passed to +``__init__``. Every time you call it: + +1. A new ``call_function`` node is inserted via ``graph.inserting_after(cursor)``. +2. The cursor advances to the newly inserted node. +3. The new node is appended to an internal ``_inserted`` list for debug logging. + +This ensures that successive calls produce a correctly ordered chain: + +.. code-block:: text + + anchor → node_0 → node_1 → node_2 → ... + +without any manual bookkeeping. + +Debug Logging +------------- + +When the ``torch_tensorrt`` logger is set to ``DEBUG``, the builder emits a +compact summary of all inserted nodes after a successful block, for example:: + + rewrite %mul_17[(4, 32, 2),torch.float32] -> + %select_72[(4, 32),torch.float32] = select_int(%inp_0, -1, 0) + %select_73[(4, 32),torch.float32] = select_int(%inp_0, -1, 1) + %mul_18[(4, 32),torch.float32] = mul_Tensor(%select_72, %select_73) + ... + +This makes it easy to trace exactly which nodes were produced by a particular +rewrite rule. + +API Reference +------------- + +.. autoclass:: torch_tensorrt.dynamo.lowering._SubgraphBuilder.SubgraphBuilder + :members: + :undoc-members: + +When to Use SubgraphBuilder +--------------------------- + +Use ``SubgraphBuilder`` whenever a lowering pass needs to **expand one node into +a sequence of several nodes** in a single linear chain. Typical use cases: + +* Replacing a complex-arithmetic op with real-arithmetic equivalents + (e.g. the ``complex_mul_replacement`` in :ref:`complex_number_support_design`). +* Decomposing a high-level op (e.g. ``layer_norm``) into its ATen primitives + when a custom replacement strategy is needed beyond the standard decomposition + table. +* Inserting diagnostic nodes (shape probes, debug prints) around a target op. 
+ +If you only need to insert a *single* node, a plain +``graph.inserting_after(node)`` is simpler. If you need to insert into multiple +disconnected locations in the same pass, create a separate ``SubgraphBuilder`` +for each anchor. diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst index fade7a3ee5..219d825af3 100644 --- a/examples/dynamo/README.rst +++ b/examples/dynamo/README.rst @@ -25,4 +25,5 @@ Model Zoo * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`) * :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`) * :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`) -* :ref:`debugger_example`: Debugging Torch-TensorRT Compilation \ No newline at end of file +* :ref:`debugger_example`: Debugging Torch-TensorRT Compilation +* :ref:`torch_export_3d_rope`: Compiling a 3D RoPE video-transformer block with complex numerics support \ No newline at end of file diff --git a/examples/dynamo/torch_export_3d_rope.py b/examples/dynamo/torch_export_3d_rope.py new file mode 100644 index 0000000000..8851beb59e --- /dev/null +++ b/examples/dynamo/torch_export_3d_rope.py @@ -0,0 +1,369 @@ +""" +3D Rotary Position Embedding (RoPE) + Attention compiled with Torch-TensorRT +============================================================================= + +3D RoPE is the positional encoding used in video generation transformers such +as CogVideoX, Wan, and HunyuanVideo. Unlike 1D RoPE (used in language models) +which encodes a single sequence index, 3D RoPE independently encodes three +axes — temporal (T), height (H), and width (W) — and assigns each axis a +dedicated slice of the per-head frequency vector: + + head-dim slots 0 .. d//3-1 → temporal frequencies + head-dim slots d//3.. 2d//3-1 → height frequencies + head-dim slots 2d//3.. 
d-1 → width frequencies + +The rotation is expressed with complex arithmetic: + + xq_rotated = view_as_real(view_as_complex(xq) * freqs_cis) + +PyTorch complex ops (view_as_complex, complex mul) are not natively supported +by TensorRT. Torch-TensorRT's ``complex_graph_detection`` lowering pass +intercepts them before partitioning and rewrites the subgraph to equivalent +real arithmetic — splitting the last dimension into (..., 2) real/imag pairs +and computing (ac-bd, ad+bc) manually — so the TRT engine only sees standard +float32 ops and the caller never needs to change anything. + +This example: + 1. Defines a 3D-RoPE frequency precomputation helper (complex64 output). + 2. Defines a VideoAttentionBlock: linear QKV projection → 3D RoPE → SDPA. + 3. Runs a PyTorch baseline forward pass. + 4. Exports with torch.export.export() and dynamic T/H/W dimensions. + 5. Compiles to TensorRT via torch_tensorrt.dynamo.compile(). + 6. Verifies numerical accuracy (cosine similarity on the output tensor). + 7. (Optional) benchmarks latency of both backends. + +Usage +----- +# Quick correctness check (static shapes) +python examples/dynamo/torch_export_3d_rope.py + +# Dynamic T/H/W shapes +python examples/dynamo/torch_export_3d_rope.py --dynamic + +# Larger config + benchmark +python examples/dynamo/torch_export_3d_rope.py --heads 16 --head-dim 96 --t 8 --h 16 --w 16 --benchmark +""" + +import argparse +import timeit + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt +from torch.export import Dim + +DEVICE = torch.device("cuda:0") + + +# --------------------------------------------------------------------------- +# Frequency precomputation +# --------------------------------------------------------------------------- + + +def precompute_freqs_3d( + head_dim: int, + t: int, + h: int, + w: int, + theta: float = 10000.0, +) -> torch.Tensor: + """Pre-compute 3D RoPE unit-complex frequency tensor. 
+ + Returns a complex64 tensor of shape (t, h, w, head_dim // 2) where the + last dimension is split evenly across the three spatial axes. + + Args: + head_dim: Channels per attention head (must be even, head_dim//2 + must be divisible by 3). + t: Number of temporal frames. + h: Spatial height in patches. + w: Spatial width in patches. + theta: Base for the geometric frequency progression. + """ + half = head_dim // 2 + d_t = half // 3 + d_h = half // 3 + d_w = half - d_t - d_h # absorbs any remainder from integer division + + def _axis_freqs(d: int, n: int) -> torch.Tensor: + """1-D complex exponentials, shape (n, d).""" + inv_freq = 1.0 / (theta ** (torch.arange(0, d * 2, 2).float() / (d * 2))) + positions = torch.arange(n, dtype=torch.float32) + angles = torch.outer(positions, inv_freq) + return torch.polar(torch.ones_like(angles), angles) # complex64 + + freqs_t = _axis_freqs(d_t, t)[:, None, None, :].expand(t, h, w, d_t) + freqs_h = _axis_freqs(d_h, h)[None, :, None, :].expand(t, h, w, d_h) + freqs_w = _axis_freqs(d_w, w)[None, None, :, :].expand(t, h, w, d_w) + + # Concatenate along last dim → (t, h, w, half), complex64 + return torch.cat([freqs_t, freqs_h, freqs_w], dim=-1).contiguous() + + +# --------------------------------------------------------------------------- +# Model +# --------------------------------------------------------------------------- + + +class VideoAttentionBlock(nn.Module): + """Single attention block for video latents with 3D RoPE. + + Inputs + ------ + x : (B, T, H, W, C) float32 video patch features + freqs_cis_real: (T, H, W, C // n_heads) float32 + The RoPE frequency tensor pre-flattened from complex64 via + ``view_as_real(...).flatten(-2)``. The module reconstructs the + complex form internally with ``view_as_complex``. + + Passing frequencies as a plain real-valued input avoids exposing a + complex tensor at the model boundary (TRT inputs must be real). 
+ + Output + ------ + (B, T, H, W, C) float32 + """ + + def __init__(self, channels: int = 512, n_heads: int = 8) -> None: + super().__init__() + assert channels % n_heads == 0 + self.n_heads = n_heads + self.head_dim = channels // n_heads + self.norm = nn.LayerNorm(channels) + self.qkv = nn.Linear(channels, 3 * channels, bias=False) + self.proj = nn.Linear(channels, channels, bias=False) + + def _apply_rope(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """Apply 3D RoPE to a single Q or K tensor. + + The complex multiply ``xc * freqs_cis`` is what Torch-TensorRT rewrites + to real arithmetic via the complex_graph_detection lowering pass. + + Args: + x : (B, T, H, W, n_heads, head_dim) float32 + freqs_cis: (T, H, W, head_dim // 2) complex64 + Returns: + Rotated tensor, same shape as ``x``, float32. + """ + B, T, H, W, Nh, D = x.shape + # Interpret consecutive pairs of head-dim channels as complex numbers. + xc = torch.view_as_complex(x.reshape(B, T, H, W, Nh, D // 2, 2)) + # freqs_cis broadcast over batch (dim 0) and head (dim 4). + freqs = freqs_cis[None, :, :, :, None, :] # (1, T, H, W, 1, D//2) + return torch.view_as_real(xc * freqs).flatten(-2) # (B,T,H,W,Nh,D) + + def forward( + self, + x: torch.Tensor, + freqs_cis_real: torch.Tensor, + ) -> torch.Tensor: + B, T, H, W, C = x.shape + Nh, D = self.n_heads, self.head_dim + + h = self.norm(x) + qkv = self.qkv(h).reshape(B, T, H, W, 3, Nh, D) + q, k, v = qkv.unbind(dim=4) # each (B, T, H, W, Nh, D) + + # Recover complex frequencies from the real-valued input. 
+ # freqs_cis_real: (T, H, W, D) → reshape to (T, H, W, D//2, 2) → complex + freqs_cis = torch.view_as_complex(freqs_cis_real.reshape(T, H, W, D // 2, 2)) + + q = self._apply_rope(q, freqs_cis) + k = self._apply_rope(k, freqs_cis) + + # Flatten spatial dims for attention: (B, Nh, T*H*W, D) + N = T * H * W + q = q.reshape(B, N, Nh, D).permute(0, 2, 1, 3) + k = k.reshape(B, N, Nh, D).permute(0, 2, 1, 3) + v = v.reshape(B, N, Nh, D).permute(0, 2, 1, 3) + + out = F.scaled_dot_product_attention(q, k, v) # (B, Nh, N, D) + out = out.permute(0, 2, 1, 3).reshape(B, T, H, W, C) + return x + self.proj(out) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_inputs( + B: int, T: int, H: int, W: int, C: int, n_heads: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Return (x, freqs_cis_real) on DEVICE.""" + x = torch.randn(B, T, H, W, C, dtype=torch.float32, device=DEVICE) + freqs_cis = precompute_freqs_3d(C // n_heads, t=T, h=H, w=W).to(DEVICE) + freqs_cis_real = torch.view_as_real(freqs_cis).flatten(-2) # (T,H,W,D) + return x, freqs_cis_real + + +def benchmark(fn, *args, iterations: int = 20, label: str = "") -> float: + fn(*args) # warmup + torch.cuda.synchronize() + total = 0.0 + for _ in range(iterations): + t0 = timeit.default_timer() + fn(*args) + torch.cuda.synchronize() + total += timeit.default_timer() - t0 + avg_ms = total / iterations * 1000 + print(f"[{label}] avg latency over {iterations} iters: {avg_ms:.2f} ms") + return avg_ms + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def parse_args(): + p = argparse.ArgumentParser( + description="3D RoPE attention block compiled with Torch-TensorRT" + ) + p.add_argument("--heads", type=int, default=8, help="Number of attention heads") + p.add_argument( + 
"--head-dim", + dest="head_dim", + type=int, + default=48, + help="Channels per head. head_dim//2 must be divisible by 3 (default: 48)", + ) + p.add_argument("--t", type=int, default=4, help="Temporal frames (default: 4)") + p.add_argument( + "--h", type=int, default=8, help="Spatial height patches (default: 8)" + ) + p.add_argument( + "--w", type=int, default=8, help="Spatial width patches (default: 8)" + ) + p.add_argument( + "--dynamic", + action="store_true", + help="Export with dynamic T/H/W dims and compile with min/opt/max shapes", + ) + p.add_argument( + "--benchmark", action="store_true", help="Benchmark PyTorch vs TRT latency" + ) + p.add_argument("--iterations", type=int, default=20) + return p.parse_args() + + +def main(): + args = parse_args() + + if (args.head_dim // 2) % 3 != 0: + raise ValueError( + f"head_dim // 2 = {args.head_dim // 2} must be divisible by 3 " + "for the T/H/W frequency split. Try --head-dim 48, 60, 96, or 192." + ) + + B, T, H, W = 1, args.t, args.h, args.w + C = args.heads * args.head_dim + + print(f"VideoAttentionBlock with 3D RoPE") + print(f" heads={args.heads} head_dim={args.head_dim} channels={C}") + print(f" input shape: ({B}, {T}, {H}, {W}, {C})") + + model = VideoAttentionBlock(channels=C, n_heads=args.heads).eval().to(DEVICE) + + # ------------------------------------------------------------------ + # 1. Build inputs + # ------------------------------------------------------------------ + x, freqs_cis_real = make_inputs(B, T, H, W, C, args.heads) + inputs = (x, freqs_cis_real) + print(f"\n x shape : {x.shape}") + print(f" freqs_cis_real shape: {freqs_cis_real.shape}") + + # ------------------------------------------------------------------ + # 2. 
PyTorch baseline + # ------------------------------------------------------------------ + with torch.inference_mode(): + pyt_out = model(*inputs) + print(f"\n--- PyTorch baseline ---") + print(f" output shape: {pyt_out.shape} dtype: {pyt_out.dtype}") + + # ------------------------------------------------------------------ + # 3. Export + # ------------------------------------------------------------------ + print("\nExporting model ...") + if args.dynamic: + t_dim = Dim("T", min=1, max=32) + h_dim = Dim("H", min=4, max=64) + w_dim = Dim("W", min=4, max=64) + dynamic_shapes = ( + # x: (B, T, H, W, C) + {1: t_dim, 2: h_dim, 3: w_dim}, + # freqs_cis_real: (T, H, W, D) + {0: t_dim, 1: h_dim, 2: w_dim}, + ) + ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + print(" Exported with dynamic T / H / W dimensions.") + else: + ep = torch.export.export(model, inputs) + print(" Exported with static shapes.") + + # ------------------------------------------------------------------ + # 4. Compile with Torch-TensorRT + # + # No special flags are required for the complex arithmetic rewrite. + # The complex_graph_detection lowering pass automatically detects + # view_as_complex / complex-mul / view_as_real subgraphs and rewrites + # them to real-arithmetic ops before the TRT engine is built. 
+ # ------------------------------------------------------------------ + print("\nCompiling with Torch-TensorRT ...") + D = C // args.heads # freqs_cis_real last dim + if args.dynamic: + trt_inputs = [ + torch_tensorrt.Input( + min_shape=(B, 1, 4, 4, C), + opt_shape=(B, T, H, W, C), + max_shape=(B, 32, 64, 64, C), + dtype=torch.float32, + ), + torch_tensorrt.Input( + min_shape=(1, 4, 4, D), + opt_shape=(T, H, W, D), + max_shape=(32, 64, 64, D), + dtype=torch.float32, + ), + ] + else: + trt_inputs = list(inputs) + + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=trt_inputs, + enabled_precisions={torch.float32}, + min_block_size=1, + ) + + # ------------------------------------------------------------------ + # 5. TRT inference & accuracy check + # ------------------------------------------------------------------ + with torch.inference_mode(): + trt_out = trt_model(*inputs) + + pyt_flat = pyt_out.float().flatten() + trt_flat = trt_out.float().flatten() + cos_sim = (pyt_flat @ trt_flat / (pyt_flat.norm() * trt_flat.norm())).item() + max_diff = (pyt_out.float() - trt_out.float()).abs().max().item() + + print(f"\n--- TensorRT vs PyTorch ---") + print(f" output shape : {trt_out.shape}") + print(f" cosine sim : {cos_sim:.6f}") + print(f" max |Δ| : {max_diff:.2e}") + assert cos_sim > 0.99, f"Cosine similarity {cos_sim:.4f} below threshold 0.99!" + print(" PASSED") + + # ------------------------------------------------------------------ + # 6. 
(Optional) benchmark + # ------------------------------------------------------------------ + if args.benchmark: + print("\n--- Benchmarking ---") + with torch.inference_mode(): + benchmark(model, *inputs, iterations=args.iterations, label="PyTorch") + benchmark(trt_model, *inputs, iterations=args.iterations, label="TensorRT") + + +if __name__ == "__main__": + main() diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index bc3cdc5721..21d7e802c0 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -44,6 +44,7 @@ resource_partition, ) from torch_tensorrt.dynamo.utils import ( + COMPLEX_TO_REAL_DTYPE, deallocate_module, get_cpu_memory_usage, get_flat_args_with_check, @@ -801,6 +802,104 @@ def compile( return trt_gm +def _insert_complex_io_adapters( + partitioned_module: torch.fx.GraphModule, + gm: torch.fx.GraphModule, + settings: CompilationSettings, +) -> None: + """Insert view_as_real / view_as_complex boundary nodes for complex I/O. + + complex_graph_detection rewrites complex subgraphs to real arithmetic before + partitioning, but when a model has complex inputs or outputs the outer wrapper + graph still needs adapters at the TRT block boundary: + + Inputs: insert view_as_real (+ optional cast for complex128+truncate_double) + after each placeholder that was unpacked by the rewriter. + Outputs: insert view_as_complex before the output node for each originally-complex + output that comes from a TRT block. 
+ + Leverages metadata that was captured when the complex rewriter pass was run + """ + complex_input_names = gm.meta.get("complex_input_names", []) + complex_input_dtypes = gm.meta.get("complex_input_dtypes", {}) + complex_output_indices = gm.meta.get("complex_output_indices", []) + + if not complex_input_names and not complex_output_indices: + return + + graph_modified = False + + # --- Input boundary: view_as_real for complex inputs --- + # complex_graph_detection renames complex placeholder 'foo' to 'foo_unpacked_complex' + # with float dtype. The outer graph still has 'foo_unpacked_complex' as a placeholder, + # but the caller passes the original complex tensor. Insert view_as_real after + # each such placeholder so the graph unpacks it transparently. + reshaped_names = {f"{n}_unpacked_complex" for n in complex_input_names} + for node in list(partitioned_module.graph.nodes): + if node.op != "placeholder" or node.name not in reshaped_names: + continue + with partitioned_module.graph.inserting_after(node): + real_node = partitioned_module.graph.call_function( + torch.ops.aten.view_as_real.default, args=(node,) + ) + # For complex128 with truncate_double, the rewriter produced float32 + # TRT engine inputs but view_as_real gives float64 — add an explicit cast. + orig_name = node.name[: -len("_unpacked_complex")] + orig_dtype = complex_input_dtypes.get(orig_name, None) + + if orig_dtype == torch.complex128 and settings.truncate_double: + logger.info( + f"Input '{orig_name}' is complex128 with truncate_double=True: unpacked " + f"float64 components will be cast to float32." 
+ ) + with partitioned_module.graph.inserting_after(real_node): + cast_node = partitioned_module.graph.call_function( + torch.ops.aten.to.dtype, + args=(real_node, torch.float32), + ) + node.replace_all_uses_with(cast_node) + cast_node.args = (real_node, torch.float32) + real_node.args = (node,) + logger.info( + f"Inserted view_as_real + cast-to-float32 for complex128 input placeholder '{node.name}' (truncate_double=True)" + ) + else: + node.replace_all_uses_with(real_node) + # fix the self-reference created by replace_all_uses_with + real_node.args = (node,) + logger.info( + f"Inserted view_as_real for complex input placeholder '{node.name}'" + ) + graph_modified = True + + # --- Output boundary: view_as_complex for complex outputs from TRT blocks --- + if complex_output_indices: + output_node = list(partitioned_module.graph.nodes)[-1] + outputs = list(output_node.args[0]) + for idx in complex_output_indices: + if idx >= len(outputs): + continue + src = outputs[idx] + if not isinstance(src, torch.fx.Node): + continue + if src.op == "call_module" and "_run_on_acc" in str(src.target): + with partitioned_module.graph.inserting_before(output_node): + complex_node = partitioned_module.graph.call_function( + torch.ops.aten.view_as_complex.default, args=(src,) + ) + logger.info( + f"Inserted view_as_complex for complex output index {idx} " + f"from TRT block '{src.target}'" + ) + outputs[idx] = complex_node + graph_modified = True + output_node.args = (tuple(outputs),) + + if graph_modified: + partitioned_module.graph.lint() + partitioned_module.recompile() + + @fn_supports_debugger # type: ignore[misc] def compile_module( gm: torch.fx.GraphModule, @@ -1097,6 +1196,10 @@ def preserve_module_specs( trt_module = getattr(partitioned_module, name) trt_module.setup_engine() + # Post-partition complex I/O boundary pass — runs in both normal and dryrun mode + # so the wrapper graph reflects the exact graph that will be executed/built. 
+ _insert_complex_io_adapters(partitioned_module, gm, settings) + # Only set output tensors as unowned if not in dryrun mode (TRT modules exist) if not settings.dryrun: output_node = list(partitioned_module.graph.nodes)[-1] diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 0b6af849fa..cf1fe5a191 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -43,6 +43,7 @@ from torch_tensorrt.dynamo.utils import ( CPU_DEVICE, check_module_output, + check_output_equal, get_model_device, get_torch_inputs, to_torch_device, @@ -110,6 +111,17 @@ def construct_refit_mapping_from_weight_name_map( engine_weight_name.split(" ")[-1].lower() ) + elif isinstance(sd_weight_name, tuple): + # Buffer-slice mapping created by Stage 3 of _save_weight_mapping. + # Encodes (state_dict_key, dim, index) for weights that are slices + # of a source buffer (e.g. real/imag parts of an unpacked complex buffer). + sd_key, dim, idx = sd_weight_name + if sd_key not in state_dict: + continue + engine_weight_map[engine_weight_name] = ( + state_dict[sd_key].select(dim, idx).to(to_torch_device(settings.device)) + ) + elif sd_weight_name not in state_dict: # If weights is not in sd, we can leave it unchanged continue @@ -587,14 +599,33 @@ def refit_module_weights( if verify_output and arg_inputs is not None: new_gm.to(to_torch_device(settings.device)) - if check_module_output( - new_module=new_gm, - refitted_module=compiled_module, - arg_inputs=torch_inputs, - kwarg_inputs=torch_kwarg_inputs, - ): + # complex_graph_detection rewrites complex placeholders to real (view_as_real). + # The compiled TRT module handles complex→real internally, but the lowered + # PyTorch reference module (new_gm) expects real-unpacked inputs directly. 
+ has_complex_inputs = any( + isinstance(x, torch.Tensor) and x.is_complex() for x in torch_inputs + ) + if has_complex_inputs: + lowered_inputs = [ + ( + torch.view_as_real(x).contiguous() + if isinstance(x, torch.Tensor) and x.is_complex() + else x + ) + for x in torch_inputs + ] + trt_outputs = compiled_module(*torch_inputs) + ref_outputs = new_gm(*lowered_inputs, **torch_kwarg_inputs) + outputs_match = check_output_equal(trt_outputs, ref_outputs) + else: + outputs_match = check_module_output( + new_module=new_gm, + refitted_module=compiled_module, + arg_inputs=torch_inputs, + kwarg_inputs=torch_kwarg_inputs, + ) + if outputs_match: logger.info("Refitting Succeed!") - new_gm.to(CPU_DEVICE) else: if weight_name_map: logger.warning( @@ -610,7 +641,6 @@ def refit_module_weights( in_place=in_place, ) logger.error("Refitting Failed! The outputs do not match.") - new_gm.to(CPU_DEVICE) else: logger.info("Refitting Completed! Output verification skipped.") diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index d4735baa12..fa982bc6da 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -587,6 +587,41 @@ def _save_weight_mapping(self) -> None: weight_refit_map[engine_weight_name].dtype, ] + # Stage 3: Slice matching for unmatched non-scalar CONSTANT weights. + # complex_graph_detection unpacks complex buffers to real: + # freqs (S,D complex64) → freqs_unpacked_complex (S,D,2 float32) + # The real and imag slices (freqs_unpacked_complex[...,0] and [...,1]) are + # embedded as separate TRT constants, but their shapes differ from the source + # buffer, so Stage 2 value matching fails. Here we try selecting each slice + # along the last dimension of every sd entry to find the match. 
+ for engine_weight_name, val in list(weight_name_map.items()): + if not isinstance(val, list) or len(val) != 2: + continue + sd_weight_name, dtype_val = val + if sd_weight_name != "" or engine_weight_name not in weight_refit_map: + continue + ew_tensor = weight_refit_map[engine_weight_name].to(torch_device) + if ew_tensor.numel() <= 1: + continue # scalars are handled via constant_mapping + matched = False + for sd_key, sd_tensor in sd.items(): + if sd_tensor.dim() < 1 or sd_tensor.shape[-1] < 2: + continue + last_dim = sd_tensor.dim() - 1 + for idx in range(sd_tensor.shape[last_dim]): + sd_slice = sd_tensor.select(last_dim, idx) + if TRTInterpreter.check_weight_equal( + sd_slice, ew_tensor, torch_device + ): + weight_name_map[engine_weight_name] = [ + (sd_key, last_dim, idx), + dtype_val, + ] + matched = True + break + if matched: + break + weight_name_map["constant_mapping"] = constant_mapping self.weight_name_map = weight_name_map diff --git a/py/torch_tensorrt/dynamo/lowering/_SubgraphBuilder.py b/py/torch_tensorrt/dynamo/lowering/_SubgraphBuilder.py new file mode 100644 index 0000000000..e5ef07dd7e --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/_SubgraphBuilder.py @@ -0,0 +1,89 @@ +"""Cursor-based FX graph node builder.""" + +from __future__ import annotations + +import logging +from types import TracebackType +from typing import List, Optional, Type + +import torch +import torch.fx +from torch.fx.node import Node + +logger = logging.getLogger(__name__) + + +def _fmt_node(n: object) -> str: + """Return a compact string summary of an FX node or a plain value.""" + if not isinstance(n, Node): + return repr(n) + val = n.meta.get("val", None) + if val is not None and hasattr(val, "shape") and hasattr(val, "dtype"): + return f"%{n.name}[{tuple(val.shape)},{val.dtype}]" + return f"%{n.name}" + + +def _fmt_args(args: tuple) -> str: + return "(" + ", ".join(_fmt_node(a) for a in args) + ")" + + +# NB: Its pretty tedious to go through and hand write all the 
graph insert afters +# Could not find a Pytorch utility that simplifies this so we have this class. I want +# remove it if we find a PyTorch alternative +class SubgraphBuilder: + """Cursor-based helper for inserting a sequence of FX ``call_function`` nodes. + + Construct it with the graph and an anchor node, then call it like a + function to append each new node immediately after the current cursor:: + + with SubgraphBuilder(graph, node) as b: + re = b(aten.select.int, inp, -1, 0) + im = b(aten.select.int, inp, -1, 1) + out = b(aten.add.Tensor, re, im) + + Each call inserts one ``call_function`` node right after the cursor and + advances the cursor to that node. Scalar / list arguments are forwarded + as-is. + + On ``__exit__`` the graph is linted to catch any malformed nodes inserted + during the block. Exceptions from user code propagate normally; lint + errors are only raised when the block itself succeeds. + """ + + __slots__ = ("_g", "_anchor_desc", "_cursor", "_inserted") + + def __init__(self, graph: torch.fx.Graph, cursor: Node) -> None: + self._g = graph + # Snapshot the description now — the anchor node is erased inside the block. 
+ self._anchor_desc: str = _fmt_node(cursor) + self._cursor = cursor + self._inserted: List[Node] = [] + + @property + def cursor(self) -> Node: + return self._cursor + + def __call__(self, op: object, *args: object) -> Node: + with self._g.inserting_after(self._cursor): + node = self._g.call_function(op, args=args) + self._cursor = node + self._inserted.append(node) + return node + + def __enter__(self) -> "SubgraphBuilder": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc_type is None: + if logger.isEnabledFor(logging.DEBUG) and self._inserted: + lines = [f" rewrite {self._anchor_desc} ->"] + for n in self._inserted: + op_name = getattr(n.target, "__name__", str(n.target)) + lines.append(f" {_fmt_node(n)} = {op_name}{_fmt_args(n.args)}") + logger.debug("\n".join(lines)) + self._g.lint() diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py index bec5e407b5..3c73e42f86 100644 --- a/py/torch_tensorrt/dynamo/lowering/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/__init__.py @@ -5,3 +5,4 @@ ) from ._decompositions import get_decompositions # noqa: F401 from .passes import * +from torch_tensorrt.dynamo.lowering._SubgraphBuilder import SubgraphBuilder diff --git a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py index c3ead218aa..7c03080c8e 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py @@ -1,17 +1,106 @@ import logging -from typing import Callable, List, Set, Tuple +import math +import operator +from typing import Callable, List, Optional, Tuple import torch from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx import GraphModule, Node from torch.fx.experimental.proxy_tensor import 
unset_fake_temporarily + from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering._SubgraphBuilder import SubgraphBuilder from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( clean_up_graph_after_modifications, ) logger = logging.getLogger(__name__) +# Ops that are elementwise-safe on the [..., 2] real layout used to represent +# complex tensors. These ops apply independently to every scalar in the tensor +# (including both the real and imaginary components stored in the last dim) so +# no explicit rewrite is needed — the pass-through behaviour is correct. +# +# NOTE: add.Scalar / sub.Scalar are NOT in this set. (a+bi)+s = (a+s)+bi +# adds the scalar only to the real part, but on the [...,2] layout +# add.Scalar would add to both parts. Those need explicit rewrites. +_ELEMENTWISE_SAFE: frozenset = frozenset( + { + # Arithmetic — component-wise operations are correct by construction + torch.ops.aten.add.Tensor, + torch.ops.aten.sub.Tensor, + torch.ops.aten.neg.default, + torch.ops.aten.mul.Scalar, # scalar*(re,im) — both parts scaled equally + torch.ops.aten.div.Scalar, # (re,im)/scalar — both parts divided equally + # Structural / copy — operate on the whole tensor without touching content. + # Note: permute.default is NOT here; it needs an explicit rewrite to append + # the trailing real/imag dimension index to the dims list. + torch.ops.aten.clone.default, + torch.ops.aten.detach.default, + torch.ops.aten.alias.default, + # NOTE: expand.default is NOT here — it takes a shape arg that must + # include the trailing real/imag dim. It has an explicit handler below. + # NOTE: t.default is NOT here — it requires an explicit handler since t() + # raises on tensors with more than 2 dimensions (which the [..., 2] real + # layout always is). + # squeeze.default (no dim arg) squeezes all size-1 dims; the trailing + # real/imag dim is always size 2 so it is never accidentally squeezed. 
+ torch.ops.aten.squeeze.default, + # Construction — zeros_like is layout-neutral (zeros everywhere = 0+0i). + # ones_like is NOT here: ones([a, b]) in real layout = [1, 1] per element + # = 1+1i, but we want 1+0i. It has an explicit handler below. + torch.ops.aten.zeros_like.default, + # Conditional selection — correct on the real layout when mask broadcasts + torch.ops.aten.where.self, + # Rounding — applies to each float independently; complex rounding is + # undefined in PyTorch so these only appear after the rewrite anyway + torch.ops.aten.ceil.default, + torch.ops.aten.floor.default, + torch.ops.aten.round.default, + torch.ops.aten.trunc.default, + # Structural list indexing — extracts one element from a split/chunk output. + # The element is still in real [..., 2] complex layout; the flag is already + # set by the pre-rewrite annotation loop. No view_as_complex wrapping needed. + operator.getitem, + # Shape queries — sym_size.int reads a tensor's dimension value, which is not + # affected by the complex [..., 2] layout. Without this entry the fallback + # wrapper inserts view_as_complex before the sym_size node, causing the shape + # to be computed from a complex tensor in the PyTorch fallback and returning + # a raw SymInt backing value (garbage) to TRT for reshape dims. + torch.ops.aten.sym_size.int, + } +) + + +def _complex_unpacker(*ops: object) -> Callable: + """Decorator that registers a rewrite method for a complex aten op into a real value subgraph. + + Usage:: + + @_complex_unpacker(aten.sin.default, aten.cos.default) + def _rewrite_sin_cos(self, node): ... + + The ops are stored on the function as ``._complex_unpacker_ops`` and picked up by + ``@_register_unpackers`` when the class is fully defined. 
+ """ + + def decorator(fn: Callable) -> Callable: + fn._complex_unpacker_ops = ops + return fn + + return decorator + + +def _register_unpackers(cls: type) -> type: + """Class decorator that builds ``cls._DISPATCH`` from all methods tagged + with ``@_complex_unpacker``. Applied once at class-definition time.""" + dispatch: dict = {} + for attr in vars(cls).values(): + for op in getattr(attr, "_complex_unpacker_ops", ()): + dispatch[op] = attr + cls._DISPATCH = dispatch + return cls + class ComplexSubGraphInfo: def __init__( @@ -44,66 +133,54 @@ def is_complex_dtype(self, node: Node) -> bool: if hasattr(val, "dtype"): dtype = val.dtype - logger.debug(f"dtype of node: {dtype}") return dtype in {torch.complex64, torch.complex128} + def has_complex_input(self, node: Node) -> bool: + """Return True if any input to node has complex dtype.""" + return any(self.is_complex_dtype(inp) for inp in node.all_input_nodes) + def node_include_in_subgraph(self, node: Node) -> bool: - # Include only call_function ops on complex tensors - if node.op == "call_function" and self.is_complex_dtype(node): - logger.debug( - f"node.op is added to subgraph: {node.op}, node name: {node.name} is complex" - ) - return node.op == "call_function" and self.is_complex_dtype(node) + # Include call_function ops that either output complex OR consume complex inputs. + # The second condition catches real-output ops like abs, angle, real, imag whose + # inputs are complex and must be rewritten alongside the rest of the subgraph. + if node.op != "call_function": + return False + return self.is_complex_dtype(node) or self.has_complex_input(node) + + def find_all_complex_subgraphs(self, gm: GraphModule) -> List[ComplexSubGraphInfo]: + """Forward scan: collect all complex-dtype call_function nodes as one subgraph. 
- def subgraph_from_anchor(self, anchor_node: Node) -> ComplexSubGraphInfo: + Scans forward over every node and collects all call_function nodes whose + output is complex — regardless of whether they are bounded by view_as_real. + This ensures complex ops that feed directly into graph outputs (no view_as_real) + are still rewritten to real arithmetic. + """ subgraph_nodes: Set[Node] = set() input_nodes: Set[Node] = set() - stack = [anchor_node] - while stack: - n = stack.pop() - if n in subgraph_nodes: + for node in gm.graph.nodes: + if not self.node_include_in_subgraph(node): continue - subgraph_nodes.add(n) - logger.debug(f"node {n.name} is added to subgraph") - for inp in n.all_input_nodes: - if self.node_include_in_subgraph(inp): - stack.append(inp) - else: + subgraph_nodes.add(node) + for inp in node.all_input_nodes: + if not self.node_include_in_subgraph(inp): input_nodes.add(inp) - return ComplexSubGraphInfo( - [anchor_node], list(subgraph_nodes), list(input_nodes) - ) - - def find_complex_op_subgraphs( - self, gm: GraphModule, anchor_target: str - ) -> List[ComplexSubGraphInfo]: - complex_op_subgraphs: List[ComplexSubGraphInfo] = [] - for node in gm.graph.nodes: - if node.target == anchor_target: - new_sub = self.subgraph_from_anchor(node) - # if any intersecting nodes between seen and sub.subgraph_nodes they should be merged - merged = False - for existing_sub in complex_op_subgraphs: - if set(existing_sub.subgraph_nodes) & set(new_sub.subgraph_nodes): - logger.debug(f"merging subgraphs {existing_sub} {new_sub}") - # merge the two subgraphs - existing_sub.subgraph_nodes = list( - set(existing_sub.subgraph_nodes) - | set(new_sub.subgraph_nodes) - ) - existing_sub.input_nodes = list( - set(existing_sub.input_nodes) | set(new_sub.input_nodes) - ) - existing_sub.anchor_nodes = list( - set(existing_sub.anchor_nodes) | set(new_sub.anchor_nodes) - ) - merged = True - break - if not merged: - complex_op_subgraphs.append(new_sub) - return complex_op_subgraphs + if 
not subgraph_nodes: + return [] + # Sort in topological (graph) order so the rewriter processes producers + # before consumers, avoiding the case where e.g. a mul node is rewritten + # before its sin/cos inputs are rewritten (which causes wrong results). + node_order = {n: i for i, n in enumerate(gm.graph.nodes)} + ordered = sorted(subgraph_nodes, key=lambda n: node_order.get(n, 0)) + return [ + ComplexSubGraphInfo( + anchor_nodes=ordered, + subgraph_nodes=ordered, + input_nodes=list(input_nodes), + ) + ] +@_register_unpackers class ComplexGraphRewriter: def __init__(self, gm: GraphModule, truncate_double: bool = False) -> None: self.gm = gm @@ -146,21 +223,61 @@ def get_attr_tensor(self, target): # type: ignore f"Attribute {target} not found in gm parameters or buffers." ) - def replace_input_node(self, input_node: Node) -> None: + def replace_input_node( + self, input_node: Node, fake_mode: Optional[FakeTensorMode] = None + ) -> None: modified = False - logger.debug(f"Replacing input node: {input_node.name}") new_shape, new_dtype, device = self.extract_shape_dtype_device(input_node) - real_tensor = torch.empty(new_shape, dtype=new_dtype, device=device) if input_node.op == "placeholder": - with FakeTensorMode() as fake_mode: + if fake_mode is None: + fake_mode = FakeTensorMode() + # Preserve symbolic dimensions from the original placeholder's fake + # tensor so that dynamic-shape information (SymInt ranges from + # torch.export) survives the rewrite. We build the new fake tensor + # by appending a concrete 2 to the original symbolic shape. + # + # We use the *original* placeholder's FakeTensorMode + # (which owns the ShapeEnv with the export's range constraints) so + # that the new SymInt dimensions belong to the same ShapeEnv as all + # other nodes in the graph. Using shared_fake_mode would create a + # separate ShapeEnv and cause "symbol from different env" errors + # during FakeTensorProp. 
+ orig_fake = input_node.meta.get("val", None) + if orig_fake is not None and hasattr(orig_fake, "shape"): + # orig_fake.shape contains the symbolic sizes; append 2 for real/imag. + sym_shape = list(orig_fake.shape) + [2] + orig_mode = getattr(orig_fake, "fake_mode", None) + create_mode = orig_mode if orig_mode is not None else fake_mode + with create_mode: + fake_tensor = torch.empty(sym_shape, dtype=new_dtype, device=device) + else: + concrete_shape = tuple( + int(s) if not isinstance(s, int) else s for s in new_shape + ) + real_tensor = torch.empty( + concrete_shape, dtype=new_dtype, device=device + ) fake_tensor = fake_mode.from_tensor(real_tensor) with self.gm.graph.inserting_before(input_node): - new_node = self.gm.graph.placeholder(input_node.target + "_reshaped") + new_node = self.gm.graph.placeholder( + input_node.target + "_unpacked_complex" + ) new_node.meta["val"] = fake_tensor + new_node.meta["is_complex_layout"] = True + logger.debug( + " unpack placeholder %s%s -> %s%s", + input_node.name, + tuple(fake_tensor.shape[:-1]), + new_node.name, + tuple(fake_tensor.shape), + ) elif input_node.op == "get_attr": - new_attr_name = input_node.target + "_reshaped" + # Sanitize dots from nested-module targets (e.g. "block1.freq") + # so register_buffer does not raise KeyError on dotted names. 
+ sanitized = input_node.target.replace(".", "__") # type: ignore + new_attr_name = sanitized + "_unpacked_complex" with unset_fake_temporarily(): original_tensor = self.get_attr_tensor(input_node.target) # type: ignore stacked_tensor = torch.stack( @@ -169,93 +286,1530 @@ def replace_input_node(self, input_node: Node) -> None: self.gm.register_buffer(new_attr_name, stacked_tensor) with self.gm.graph.inserting_after(input_node): new_node = self.gm.graph.get_attr(new_attr_name) - else: - logger.debug( - f"Unsupported node type in replacement of input node: {input_node.op}" - ) + # Set fake-tensor metadata on the new node so that _is_complex_layout_node + # can identify it as a complex-layout [..., 2] tensor later when + # processing ops that use this buffer. + if fake_mode is not None: + try: + with unset_fake_temporarily(): + real_tensor = torch.empty( + stacked_tensor.shape, + dtype=stacked_tensor.dtype, + device=stacked_tensor.device, + ) + new_node.meta["val"] = fake_mode.from_tensor(real_tensor) + except Exception: + pass # best-effort + new_node.meta["is_complex_layout"] = True logger.debug( - "This complex subgraph inputnode type does not need to replaced" + " unpack get_attr %s%s -> %s%s", + input_node.target, + tuple(original_tensor.shape), + new_attr_name, + tuple(stacked_tensor.shape), ) + else: + pass # call_function inputs are rewritten in-place by the op handlers input_node.replace_all_uses_with(new_node) self.gm.graph.erase_node(input_node) clean_up_graph_after_modifications(self.gm) + # ------------------------------------------------------------------ + # Private graph-building helpers + # + # Each helper takes a SubgraphBuilder and emits a sub-sequence of nodes, + # advancing the builder's cursor. They return the last node(s) they + # inserted. 
+ # ------------------------------------------------------------------ + + @staticmethod + def _inline_select_re_im(b: SubgraphBuilder, inp: Node) -> Tuple[Node, Node]: + """Select re (index 0) and im (index 1) from a [..., 2] tensor.""" + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + return re, im + + @staticmethod + def _inline_cat_re_im(b: SubgraphBuilder, out_re: Node, out_im: Node) -> Node: + """Rebuild a [..., 2] complex-layout tensor from re and im nodes.""" + re_u = b(torch.ops.aten.unsqueeze.default, out_re, -1) + im_u = b(torch.ops.aten.unsqueeze.default, out_im, -1) + out = b(torch.ops.aten.cat.default, [re_u, im_u], -1) + out.meta["is_complex_layout"] = True + return out + + @staticmethod + def _inline_complex_log( + b: SubgraphBuilder, re: Node, im: Node + ) -> Tuple[Node, Node]: + """log(a+bi) = 0.5*log(a²+b²) + i*atan2(b, a)""" + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + r2 = b(torch.ops.aten.add.Tensor, re2, im2) + log_r2 = b(torch.ops.aten.log.default, r2) + log_re = b(torch.ops.aten.mul.Tensor, log_r2, 0.5) + log_im = b(torch.ops.aten.atan2.default, im, re) + return log_re, log_im + + @staticmethod + def _inline_complex_exp( + b: SubgraphBuilder, re: Node, im: Node + ) -> Tuple[Node, Node]: + """exp(a+bi) = e^a*cos(b) + i*e^a*sin(b)""" + ea = b(torch.ops.aten.exp.default, re) + cos_b = b(torch.ops.aten.cos.default, im) + sin_b = b(torch.ops.aten.sin.default, im) + exp_re = b(torch.ops.aten.mul.Tensor, ea, cos_b) + exp_im = b(torch.ops.aten.mul.Tensor, ea, sin_b) + return exp_re, exp_im + + @staticmethod + def _inline_complex_mul( + b: SubgraphBuilder, re1: Node, im1: Node, re2: Node, im2: Node + ) -> Tuple[Node, Node]: + """(a+bi)(c+di) = (ac-bd) + (ad+bc)i""" + ac = b(torch.ops.aten.mul.Tensor, re1, re2) + bd = b(torch.ops.aten.mul.Tensor, im1, im2) + ad = b(torch.ops.aten.mul.Tensor, re1, im2) + bc = b(torch.ops.aten.mul.Tensor, im1, re2) + out_re = 
b(torch.ops.aten.sub.Tensor, ac, bd) + out_im = b(torch.ops.aten.add.Tensor, ad, bc) + return out_re, out_im + + @staticmethod + def _inline_complex_div( + b: SubgraphBuilder, re1: Node, im1: Node, re2: Node, im2: Node + ) -> Tuple[Node, Node]: + """(a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c²+d²)""" + c2 = b(torch.ops.aten.mul.Tensor, re2, re2) + d2 = b(torch.ops.aten.mul.Tensor, im2, im2) + denom = b(torch.ops.aten.add.Tensor, c2, d2) + ac = b(torch.ops.aten.mul.Tensor, re1, re2) + bd = b(torch.ops.aten.mul.Tensor, im1, im2) + bc = b(torch.ops.aten.mul.Tensor, im1, re2) + ad = b(torch.ops.aten.mul.Tensor, re1, im2) + numer_re = b(torch.ops.aten.add.Tensor, ac, bd) + numer_im = b(torch.ops.aten.sub.Tensor, bc, ad) + out_re = b(torch.ops.aten.div.Tensor, numer_re, denom) + out_im = b(torch.ops.aten.div.Tensor, numer_im, denom) + return out_re, out_im + + @staticmethod + def _inline_complex_sqrt( + b: SubgraphBuilder, re: Node, im: Node + ) -> Tuple[Node, Node]: + """sqrt(z) = r^0.5 * (cos(θ/2) + i*sin(θ/2))""" + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + r2 = b(torch.ops.aten.add.Tensor, re2, im2) + r = b(torch.ops.aten.sqrt.default, r2) + r_sq = b(torch.ops.aten.pow.Tensor_Scalar, r, 0.5) + theta = b(torch.ops.aten.atan2.default, im, re) + half_theta = b(torch.ops.aten.mul.Tensor, theta, 0.5) + cos_ht = b(torch.ops.aten.cos.default, half_theta) + sin_ht = b(torch.ops.aten.sin.default, half_theta) + sq_re = b(torch.ops.aten.mul.Tensor, r_sq, cos_ht) + sq_im = b(torch.ops.aten.mul.Tensor, r_sq, sin_ht) + return sq_re, sq_im + + # ------------------------------------------------------------------ + # Per-op rewrite handlers + # + # Each method receives the node to rewrite and returns True if it + # modified the graph. They are registered in _build_dispatch_table() + # which is called at the end of __init__. 
+ # ------------------------------------------------------------------ + + @_complex_unpacker(torch.ops.aten.view_as_complex.default) + def _rewrite_view_as_complex(self, node: Node) -> bool: + inp = node.args[0] + # The input to view_as_complex is a (..., 2) real-layout tensor that + # represents a complex tensor. After erasing view_as_complex, downstream + # consumers (e.g. mul.Tensor) need to know that this node is in complex + # layout so the correct rewrite branch is chosen. + if isinstance(inp, torch.fx.Node): + inp.meta["is_complex_layout"] = True + node.replace_all_uses_with(inp) + self.gm.graph.erase_node(node) + # Return True so the caller triggers propagate_metadata + gm.recompile(). + # Without recompile the compiled forward still calls the erased node. + return True + + @_complex_unpacker(torch.ops.aten.view_as_real.default) + def _rewrite_view_as_real(self, node: Node) -> bool: + node.replace_all_uses_with(node.args[0]) + self.gm.graph.erase_node(node) + return True # triggers recompile, same reason as above + + @_complex_unpacker(torch.ops.aten.permute.default) + def _rewrite_permute(self, node: Node) -> bool: + # permute on a complex tensor: after rewrite the tensor has an extra + # trailing dim of size 2 (real/imag). Append the index for that + # trailing dim so the permutation stays valid. 
+ inp = node.args[0] + orig_dims = list(node.args[1]) + n_orig = len(orig_dims) + new_dims = [d % n_orig for d in orig_dims] + [n_orig] + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.permute.default, inp, new_dims) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.mul.Tensor, torch.ops.aten.div.Tensor) + def _rewrite_mul_div_tensor(self, node: Node) -> bool: + arg0_is_node = isinstance(node.args[0], torch.fx.Node) + arg1_is_node = isinstance(node.args[1], torch.fx.Node) + + if not arg0_is_node and not arg1_is_node: + return False # both scalars + + if node.target == torch.ops.aten.mul.Tensor and ( + not arg0_is_node or not arg1_is_node + ): + return False # scalar * complex — elementwise-safe + + if node.target == torch.ops.aten.div.Tensor and not arg1_is_node: + return False # complex / scalar — elementwise-safe + + if node.target == torch.ops.aten.div.Tensor and not arg0_is_node: + # scalar / complex: s/(a+bi) = (s*a/(a²+b²)) + i*(-s*b/(a²+b²)) + scalar_val = node.args[0] + z_node = node.args[1] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, z_node, -1, 0) + im = b(torch.ops.aten.select.int, z_node, -1, 1) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + denom = b(torch.ops.aten.add.Tensor, re2, im2) + re_s = b(torch.ops.aten.mul.Tensor, re, scalar_val) + out_re = b(torch.ops.aten.div.Tensor, re_s, denom) + im_s = b(torch.ops.aten.mul.Tensor, im, scalar_val) + neg_im_s = b(torch.ops.aten.neg.default, im_s) + out_im = b(torch.ops.aten.div.Tensor, neg_im_s, denom) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + # Both args are Nodes from here on. 
+ if node.target == torch.ops.aten.div.Tensor: + arg0_layout = self._is_complex_layout_node(node.args[0]) + arg1_layout = self._is_complex_layout_node(node.args[1]) + + if arg0_layout and not arg1_layout: + # complex_layout / real — unsqueeze denom for correct broadcast + with SubgraphBuilder(self.gm.graph, node) as b: + denom_unsq = b(torch.ops.aten.unsqueeze.default, node.args[1], -1) + out = b(torch.ops.aten.div.Tensor, node.args[0], denom_unsq) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + elif not arg0_layout and not arg1_layout: + return False # both real — elementwise-safe + else: + # complex / complex — full div rewrite + x_pf = node.args[0].op != "get_attr" + y_pf = node.args[1].op != "get_attr" + original_div, replacement = complex_div_replacement(x_pf, y_pf) + + def match_complex_div( + match: torch.fx.subgraph_rewriter.Match, + original_graph: object, + pattern_graph: object, + ) -> bool: + for original_node in match.nodes_map.values(): + if not isinstance(original_node, torch.fx.Node): + continue + if original_node.name == node.name: + return True + return False + + torch.fx.subgraph_rewriter.replace_pattern_with_filters( + self.gm, + original_div, + replacement, + match_filters=[match_complex_div], + ignore_literals=True, + ) + return True + + # mul.Tensor, both nodes + # Use SubgraphBuilder directly rather than replace_pattern_with_filters so + # that self-multiplication (mul(x, x)) is handled correctly. + # replace_pattern_with_filters requires distinct placeholder nodes for x and y, + # so it silently produces no matches when both args are the same node. 
+ if node.meta.get("is_complex_layout", False): + x, y = node.args[0], node.args[1] + x_is_get_attr = x.op == "get_attr" + y_is_get_attr = y.op == "get_attr" + x_is_complex = self._is_complex_layout_node(x) + y_is_complex = self._is_complex_layout_node(y) + + # complex × real (or real × complex): just scale both components + if x_is_complex and not y_is_complex and not x_is_get_attr: + with SubgraphBuilder(self.gm.graph, node) as b: + x_re = b(torch.ops.aten.select.int, x, -1, 0) + x_im = b(torch.ops.aten.select.int, x, -1, 1) + out_re = b(torch.ops.aten.mul.Tensor, x_re, y) + out_im = b(torch.ops.aten.mul.Tensor, x_im, y) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + if not x_is_complex and y_is_complex and not y_is_get_attr: + with SubgraphBuilder(self.gm.graph, node) as b: + y_re = b(torch.ops.aten.select.int, y, -1, 0) + y_im = b(torch.ops.aten.select.int, y, -1, 1) + out_re = b(torch.ops.aten.mul.Tensor, x, y_re) + out_im = b(torch.ops.aten.mul.Tensor, x, y_im) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + if not x_is_get_attr and not y_is_get_attr: + # Both are ITensors — use select.int (TRT-compatible) + with SubgraphBuilder(self.gm.graph, node) as b: + x_re = b(torch.ops.aten.select.int, x, -1, 0) + x_im = b(torch.ops.aten.select.int, x, -1, 1) + y_re = b(torch.ops.aten.select.int, y, -1, 0) + y_im = b(torch.ops.aten.select.int, y, -1, 1) + ac = b(torch.ops.aten.mul.Tensor, x_re, y_re) + bd = b(torch.ops.aten.mul.Tensor, x_im, y_im) + ad = b(torch.ops.aten.mul.Tensor, x_re, y_im) + bc = b(torch.ops.aten.mul.Tensor, x_im, y_re) + out_re = b(torch.ops.aten.sub.Tensor, ac, bd) + out_im = b(torch.ops.aten.add.Tensor, ad, bc) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + 
out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + else: + # At least one arg is a get_attr buffer — fall back to the + # pattern rewriter which uses tensor indexing for get_attr nodes. + x_pf = not x_is_get_attr + y_pf = not y_is_get_attr + original_mul, replacement = complex_mul_replacement(x_pf, y_pf) + + def match_complex_mul( + match: torch.fx.subgraph_rewriter.Match, + original_graph: object, + pattern_graph: object, + ) -> bool: + for original_node in match.nodes_map.values(): + if not isinstance(original_node, torch.fx.Node): + continue + if original_node.name == node.name: + return True + return False + + torch.fx.subgraph_rewriter.replace_pattern_with_filters( + self.gm, + original_mul, + replacement, + match_filters=[match_complex_mul], + ignore_literals=True, + ) + return True + return False + + @_complex_unpacker(torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor) + def _rewrite_add_sub_tensor_scalar(self, node: Node) -> bool: + # add.Tensor(z, scalar) / sub.Tensor(z, scalar): scalar applies to real part only. + if len(node.args) < 2 or isinstance(node.args[1], torch.fx.Node): + return False + inp, scalar = node.args[0], node.args[1] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + new_re = b(node.target, re, scalar) + out = self._inline_cat_re_im(b, new_re, im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.ones_like.default) + def _rewrite_ones_like(self, node: Node) -> bool: + # ones_like in [..., 2] layout produces [1, 1] = 1+1i. We want 1+0i. 
+ inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re_slice = b(torch.ops.aten.select.int, inp, -1, 0) + out_re = b(torch.ops.aten.ones_like.default, re_slice) + out_im = b(torch.ops.aten.zeros_like.default, re_slice) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.full_like.default) + def _rewrite_full_like(self, node: Node) -> bool: + # full_like(z, fill_value) in [..., 2] layout fills both re and im with + # fill_value → fill_value + fill_value*i. We want fill_value + 0i. + inp = node.args[0] + fill_value = node.args[1] + with SubgraphBuilder(self.gm.graph, node) as b: + re_slice = b(torch.ops.aten.select.int, inp, -1, 0) + out_re = b(torch.ops.aten.full_like.default, re_slice, fill_value) + out_im = b(torch.ops.aten.zeros_like.default, re_slice) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.sum.dim_IntList) + def _rewrite_sum_dim(self, node: Node) -> bool: + # sum.dim_IntList(inp, dim_list, keepdim=False, dtype=None) + # Negative dims must be shifted by -1 to skip the trailing real/imag dim. 
+ inp = node.args[0] + dims = list(node.args[1]) + new_dims = [d - 1 if d < 0 else d for d in dims] + if new_dims == dims: + return False # all positive — pass-through is correct + keepdim = node.args[2] if len(node.args) > 2 else False + extra = list(node.args[3:]) + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.sum.dim_IntList, inp, new_dims, keepdim, *extra) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.mean.dim) + def _rewrite_mean_dim(self, node: Node) -> bool: + # mean.dim(inp, dim_list, keepdim=False, dtype=None) + inp = node.args[0] + dims = list(node.args[1]) + new_dims = [d - 1 if d < 0 else d for d in dims] + if new_dims == dims: + return False # all positive — pass-through is correct + keepdim = node.args[2] if len(node.args) > 2 else False + extra = list(node.args[3:]) + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.mean.dim, inp, new_dims, keepdim, *extra) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.prod.dim_int) + def _rewrite_prod_dim(self, node: Node) -> bool: + # prod.dim_int(inp, dim, keepdim=False, dtype=None) + inp = node.args[0] + dim = node.args[1] + if dim >= 0: + return False # positive dim — pass-through is correct + new_dim = dim - 1 + keepdim = node.args[2] if len(node.args) > 2 else False + extra = list(node.args[3:]) + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.prod.dim_int, inp, new_dim, keepdim, *extra) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.narrow.default) + def _rewrite_narrow(self, node: Node) -> bool: + # narrow(inp, dim, start, length) — shift negative dim by -1 + inp, dim, start, length = node.args + if dim 
>= 0: + return False + new_dim = dim - 1 + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.narrow.default, inp, new_dim, start, length) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.roll.default) + def _rewrite_roll(self, node: Node) -> bool: + # roll(inp, shifts, dims) — shift negative dims by -1 + inp = node.args[0] + shifts = node.args[1] + dims = list(node.args[2]) if len(node.args) > 2 else [] + new_dims = [d - 1 if d < 0 else d for d in dims] + if new_dims == dims: + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.roll.default, inp, shifts, new_dims) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.flip.default) + def _rewrite_flip(self, node: Node) -> bool: + # flip(inp, dims) — shift negative dims by -1 + inp = node.args[0] + dims = list(node.args[1]) + new_dims = [d - 1 if d < 0 else d for d in dims] + if new_dims == dims: + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.flip.default, inp, new_dims) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.repeat.default) + def _rewrite_repeat(self, node: Node) -> bool: + # repeat(inp, repeats) — repeats must include a trailing 1 for the + # real/imag dim so the layout is not disrupted. 
+ inp = node.args[0] + repeats = list(node.args[1]) + new_repeats = repeats + [1] + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.repeat.default, inp, new_repeats) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten._conj.default) + def _rewrite_conj(self, node: Node) -> bool: + # conj(a+bi) = a - bi + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + neg_im = b(torch.ops.aten.neg.default, im) + out = self._inline_cat_re_im(b, re, neg_im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.reciprocal.default) + def _rewrite_reciprocal(self, node: Node) -> bool: + # 1/(a+bi) = a/(a²+b²) - ib/(a²+b²) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + denom = b(torch.ops.aten.add.Tensor, re2, im2) + out_re = b(torch.ops.aten.div.Tensor, re, denom) + neg_im = b(torch.ops.aten.neg.default, im) + out_im = b(torch.ops.aten.div.Tensor, neg_im, denom) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.abs.default) + def _rewrite_abs(self, node: Node) -> bool: + # |a+bi| = sqrt(a²+b²) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + sum_ = 
b(torch.ops.aten.add.Tensor, re2, im2) + out = b(torch.ops.aten.sqrt.default, sum_) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.exp.default) + def _rewrite_exp(self, node: Node) -> bool: + # exp(a+bi) = e^a*cos(b) + i*e^a*sin(b) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + exp_re, exp_im = self._inline_complex_exp(b, re, im) + out = self._inline_cat_re_im(b, exp_re, exp_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.log.default) + def _rewrite_log(self, node: Node) -> bool: + # log(a+bi) = 0.5*log(a²+b²) + i*atan2(b, a) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + log_re, log_im = self._inline_complex_log(b, re, im) + out = self._inline_cat_re_im(b, log_re, log_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.pow.Tensor_Scalar, torch.ops.aten.sqrt.default) + def _rewrite_pow_sqrt(self, node: Node) -> bool: + # pow(a+bi, n) / sqrt via polar form: r^n*(cos(n*θ) + i*sin(n*θ)) + inp = node.args[0] + exponent = ( + node.args[1] if node.target == torch.ops.aten.pow.Tensor_Scalar else 0.5 + ) + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + r2 = b(torch.ops.aten.add.Tensor, re2, im2) + r = b(torch.ops.aten.sqrt.default, r2) + rn = b(torch.ops.aten.pow.Tensor_Scalar, r, exponent) + theta = b(torch.ops.aten.atan2.default, im, re) + n_theta = b(torch.ops.aten.mul.Tensor, theta, exponent) + cos_n = 
b(torch.ops.aten.cos.default, n_theta) + sin_n = b(torch.ops.aten.sin.default, n_theta) + out_re = b(torch.ops.aten.mul.Tensor, rn, cos_n) + out_im = b(torch.ops.aten.mul.Tensor, rn, sin_n) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.real.default) + def _rewrite_real(self, node: Node) -> bool: + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.select.int, node.args[0], -1, 0) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.imag.default) + def _rewrite_imag(self, node: Node) -> bool: + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.select.int, node.args[0], -1, 1) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.angle.default) + def _rewrite_angle(self, node: Node) -> bool: + # angle(a+bi) = atan2(b, a) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + out = b(torch.ops.aten.atan2.default, im, re) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.polar.default) + def _rewrite_polar(self, node: Node) -> bool: + # polar(r, theta) = r*cos(theta) + i*r*sin(theta) + r_arg, theta_arg = node.args[0], node.args[1] + with SubgraphBuilder(self.gm.graph, node) as b: + cos_t = b(torch.ops.aten.cos.default, theta_arg) + sin_t = b(torch.ops.aten.sin.default, theta_arg) + out_re = b(torch.ops.aten.mul.Tensor, r_arg, cos_t) + out_im = b(torch.ops.aten.mul.Tensor, r_arg, sin_t) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.add.Scalar, torch.ops.aten.sub.Scalar) + def 
_rewrite_add_sub_scalar(self, node: Node) -> bool: + # (a+bi) ± s = (a±s) + bi — scalar applies to real part only + inp, scalar = node.args[0], node.args[1] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + new_re = b(node.target, re, scalar) + out = self._inline_cat_re_im(b, new_re, im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.log2.default, torch.ops.aten.log10.default) + def _rewrite_log2_log10(self, node: Node) -> bool: + # log_b(z) = log(z) / log(b) + base_val = ( + math.log(2.0) + if node.target == torch.ops.aten.log2.default + else math.log(10.0) + ) + inp = node.args[0] + inv_base = 1.0 / base_val + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + log_re, log_im = self._inline_complex_log(b, re, im) + out_re = b(torch.ops.aten.mul.Tensor, log_re, inv_base) + out_im = b(torch.ops.aten.mul.Tensor, log_im, inv_base) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.isnan.default, torch.ops.aten.isinf.default) + def _rewrite_isnan_isinf(self, node: Node) -> bool: + # isnan/isinf(z) = isnan/isinf(re) | isnan/isinf(im) + inp = node.args[0] + op = node.target + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + re_flag = b(op, re) + im_flag = b(op, im) + out = b(torch.ops.aten.logical_or.default, re_flag, im_flag) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.log1p.default) + def _rewrite_log1p(self, node: Node) -> bool: + # log1p(a+bi) = log((a+1) + bi) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, 
node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + re1 = b(torch.ops.aten.add.Tensor, re, 1.0) + log_re, log_im = self._inline_complex_log(b, re1, im) + out = self._inline_cat_re_im(b, log_re, log_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.expm1.default) + def _rewrite_expm1(self, node: Node) -> bool: + # expm1(a+bi) = (exp(a)*cos(b) - 1) + i*(exp(a)*sin(b)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + exp_re, exp_im = self._inline_complex_exp(b, re, im) + out_re = b(torch.ops.aten.sub.Tensor, exp_re, 1.0) + out = self._inline_cat_re_im(b, out_re, exp_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.sin.default) + def _rewrite_sin(self, node: Node) -> bool: + # sin(a+bi) = sin(a)*cosh(b) + i*cos(a)*sinh(b) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + sin_a = b(torch.ops.aten.sin.default, re) + cosh_b = b(torch.ops.aten.cosh.default, im) + cos_a = b(torch.ops.aten.cos.default, re) + sinh_b = b(torch.ops.aten.sinh.default, im) + out_re = b(torch.ops.aten.mul.Tensor, sin_a, cosh_b) + out_im = b(torch.ops.aten.mul.Tensor, cos_a, sinh_b) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.cos.default) + def _rewrite_cos(self, node: Node) -> bool: + # cos(a+bi) = cos(a)*cosh(b) - i*sin(a)*sinh(b) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + cos_a = b(torch.ops.aten.cos.default, re) 
+ cosh_b = b(torch.ops.aten.cosh.default, im) + sin_a = b(torch.ops.aten.sin.default, re) + sinh_b = b(torch.ops.aten.sinh.default, im) + out_re = b(torch.ops.aten.mul.Tensor, cos_a, cosh_b) + raw_im = b(torch.ops.aten.mul.Tensor, sin_a, sinh_b) + out_im = b(torch.ops.aten.neg.default, raw_im) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.sinh.default) + def _rewrite_sinh(self, node: Node) -> bool: + # sinh(a+bi) = sinh(a)*cos(b) + i*cosh(a)*sin(b) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + sinh_a = b(torch.ops.aten.sinh.default, re) + cos_b = b(torch.ops.aten.cos.default, im) + cosh_a = b(torch.ops.aten.cosh.default, re) + sin_b = b(torch.ops.aten.sin.default, im) + out_re = b(torch.ops.aten.mul.Tensor, sinh_a, cos_b) + out_im = b(torch.ops.aten.mul.Tensor, cosh_a, sin_b) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.cosh.default) + def _rewrite_cosh(self, node: Node) -> bool: + # cosh(a+bi) = cosh(a)*cos(b) + i*sinh(a)*sin(b) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + cosh_a = b(torch.ops.aten.cosh.default, re) + cos_b = b(torch.ops.aten.cos.default, im) + sinh_a = b(torch.ops.aten.sinh.default, re) + sin_b = b(torch.ops.aten.sin.default, im) + out_re = b(torch.ops.aten.mul.Tensor, cosh_a, cos_b) + out_im = b(torch.ops.aten.mul.Tensor, sinh_a, sin_b) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.tan.default) + def _rewrite_tan(self, node: Node) -> bool: 
+ # tan(a+bi) = sin(2a)/(cos(2a)+cosh(2b)) + i*sinh(2b)/(cos(2a)+cosh(2b)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + two_re = b(torch.ops.aten.mul.Tensor, re, 2.0) + two_im = b(torch.ops.aten.mul.Tensor, im, 2.0) + sin_2a = b(torch.ops.aten.sin.default, two_re) + cos_2a = b(torch.ops.aten.cos.default, two_re) + sinh_2b = b(torch.ops.aten.sinh.default, two_im) + cosh_2b = b(torch.ops.aten.cosh.default, two_im) + denom = b(torch.ops.aten.add.Tensor, cos_2a, cosh_2b) + out_re = b(torch.ops.aten.div.Tensor, sin_2a, denom) + out_im = b(torch.ops.aten.div.Tensor, sinh_2b, denom) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.tanh.default) + def _rewrite_tanh(self, node: Node) -> bool: + # tanh(a+bi) = sinh(2a)/(cosh(2a)+cos(2b)) + i*sin(2b)/(cosh(2a)+cos(2b)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re = b(torch.ops.aten.select.int, inp, -1, 0) + im = b(torch.ops.aten.select.int, inp, -1, 1) + two_re = b(torch.ops.aten.mul.Tensor, re, 2.0) + two_im = b(torch.ops.aten.mul.Tensor, im, 2.0) + sinh_2a = b(torch.ops.aten.sinh.default, two_re) + cosh_2a = b(torch.ops.aten.cosh.default, two_re) + sin_2b = b(torch.ops.aten.sin.default, two_im) + cos_2b = b(torch.ops.aten.cos.default, two_im) + denom = b(torch.ops.aten.add.Tensor, cosh_2a, cos_2b) + out_re = b(torch.ops.aten.div.Tensor, sinh_2a, denom) + out_im = b(torch.ops.aten.div.Tensor, sin_2b, denom) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.asinh.default) + def _rewrite_asinh(self, node: Node) -> bool: + # asinh(z) = log(z + sqrt(z² + 1)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = 
self._inline_select_re_im(b, inp) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + z2_re = b(torch.ops.aten.sub.Tensor, re2, im2) + re_im = b(torch.ops.aten.mul.Tensor, re, im) + z2_im = b(torch.ops.aten.mul.Tensor, re_im, 2.0) + w_re = b(torch.ops.aten.add.Scalar, z2_re, 1.0) # w = z²+1 + sq_re, sq_im = self._inline_complex_sqrt(b, w_re, z2_im) + sum_re = b(torch.ops.aten.add.Tensor, re, sq_re) + sum_im = b(torch.ops.aten.add.Tensor, im, sq_im) + log_re, log_im = self._inline_complex_log(b, sum_re, sum_im) + out = self._inline_cat_re_im(b, log_re, log_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.acosh.default) + def _rewrite_acosh(self, node: Node) -> bool: + # acosh(z) = log(z + sqrt(z² - 1)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = self._inline_select_re_im(b, inp) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + z2_re = b(torch.ops.aten.sub.Tensor, re2, im2) + re_im = b(torch.ops.aten.mul.Tensor, re, im) + z2_im = b(torch.ops.aten.mul.Tensor, re_im, 2.0) + w_re = b(torch.ops.aten.sub.Scalar, z2_re, 1.0) # w = z²-1 + sq_re, sq_im = self._inline_complex_sqrt(b, w_re, z2_im) + sum_re = b(torch.ops.aten.add.Tensor, re, sq_re) + sum_im = b(torch.ops.aten.add.Tensor, im, sq_im) + log_re, log_im = self._inline_complex_log(b, sum_re, sum_im) + out = self._inline_cat_re_im(b, log_re, log_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.atanh.default) + def _rewrite_atanh(self, node: Node) -> bool: + # atanh(z) = (1/2) * log((1+z) / (1-z)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = self._inline_select_re_im(b, inp) + p_re = b(torch.ops.aten.add.Scalar, re, 1.0) # 1+re + q_re = b(torch.ops.aten.sub.Scalar, re, 1.0) # re-1 + neg_q_re = 
b(torch.ops.aten.neg.default, q_re) # 1-re + neg_im = b(torch.ops.aten.neg.default, im) + div_re, div_im = self._inline_complex_div(b, p_re, im, neg_q_re, neg_im) + log_re, log_im = self._inline_complex_log(b, div_re, div_im) + out_re = b(torch.ops.aten.mul.Tensor, log_re, 0.5) + out_im = b(torch.ops.aten.mul.Tensor, log_im, 0.5) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.asin.default) + def _rewrite_asin(self, node: Node) -> bool: + # asin(z) = -i * log(iz + sqrt(1 - z²)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = self._inline_select_re_im(b, inp) + iz_re = b(torch.ops.aten.neg.default, im) # iz = (-im, re) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + z2_re = b(torch.ops.aten.sub.Tensor, re2, im2) + re_im = b(torch.ops.aten.mul.Tensor, re, im) + z2_im = b(torch.ops.aten.mul.Tensor, re_im, 2.0) + ones = b(torch.ops.aten.ones_like.default, z2_re) + w_re = b(torch.ops.aten.sub.Tensor, ones, z2_re) # 1-z² + w_im = b(torch.ops.aten.neg.default, z2_im) + sq_re, sq_im = self._inline_complex_sqrt(b, w_re, w_im) + sum_re = b(torch.ops.aten.add.Tensor, iz_re, sq_re) + sum_im = b(torch.ops.aten.add.Tensor, re, sq_im) # iz_im = re + log_re, log_im = self._inline_complex_log(b, sum_re, sum_im) + # -i*(log_re + i*log_im) = log_im + i*(-log_re) + out_im = b(torch.ops.aten.neg.default, log_re) + out = self._inline_cat_re_im(b, log_im, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.acos.default) + def _rewrite_acos(self, node: Node) -> bool: + # acos(z) = -i * log(z + i*sqrt(1 - z²)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = self._inline_select_re_im(b, inp) + re2 = b(torch.ops.aten.mul.Tensor, re, re) + im2 = b(torch.ops.aten.mul.Tensor, im, im) + z2_re 
= b(torch.ops.aten.sub.Tensor, re2, im2) + re_im = b(torch.ops.aten.mul.Tensor, re, im) + z2_im = b(torch.ops.aten.mul.Tensor, re_im, 2.0) + ones = b(torch.ops.aten.ones_like.default, z2_re) + w_re = b(torch.ops.aten.sub.Tensor, ones, z2_re) # 1-z² + w_im = b(torch.ops.aten.neg.default, z2_im) + sq_re, sq_im = self._inline_complex_sqrt(b, w_re, w_im) + isq_re = b(torch.ops.aten.neg.default, sq_im) # i*sqrt = (-sq_im, sq_re) + sum_re = b(torch.ops.aten.add.Tensor, re, isq_re) + sum_im = b(torch.ops.aten.add.Tensor, im, sq_re) + log_re, log_im = self._inline_complex_log(b, sum_re, sum_im) + # -i*(log_re + i*log_im) = log_im + i*(-log_re) + out_im = b(torch.ops.aten.neg.default, log_re) + out = self._inline_cat_re_im(b, log_im, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.atan.default) + def _rewrite_atan(self, node: Node) -> bool: + # atan(z) = (i/2) * log((1-iz) / (1+iz)) + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + re, im = self._inline_select_re_im(b, inp) + iz_re = b(torch.ops.aten.neg.default, im) # iz = (-im, re) + ones = b(torch.ops.aten.ones_like.default, re) + p_re = b(torch.ops.aten.sub.Tensor, ones, iz_re) # 1-iz + p_im = b(torch.ops.aten.neg.default, re) + q_re = b(torch.ops.aten.add.Tensor, ones, iz_re) # 1+iz + q_im = re # iz_im = re + div_re, div_im = self._inline_complex_div(b, p_re, p_im, q_re, q_im) + log_re, log_im = self._inline_complex_log(b, div_re, div_im) + # (i/2)*(log_re+i*log_im) = (-log_im/2) + i*(log_re/2) + out_re = b(torch.ops.aten.mul.Tensor, log_im, -0.5) + out_im = b(torch.ops.aten.mul.Tensor, log_re, 0.5) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.pow.Tensor_Tensor) + def _rewrite_pow_tensor_tensor(self, node: Node) -> bool: + # z1**z2 = exp(z2 * log(z1)) + z1_inp, z2_inp = node.args[0], node.args[1] 
+ with SubgraphBuilder(self.gm.graph, node) as b: + re1, im1 = self._inline_select_re_im(b, z1_inp) + re2 = b(torch.ops.aten.select.int, z2_inp, -1, 0) + im2 = b(torch.ops.aten.select.int, z2_inp, -1, 1) + log_re, log_im = self._inline_complex_log(b, re1, im1) + mul_re, mul_im = self._inline_complex_mul(b, re2, im2, log_re, log_im) + exp_re, exp_im = self._inline_complex_exp(b, mul_re, mul_im) + out = self._inline_cat_re_im(b, exp_re, exp_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.scalar_tensor.default) + def _rewrite_scalar_tensor(self, node: Node) -> bool: + # scalar_tensor(val, dtype=complex64) → scalar_tensor(0.0, float32) + if dict(node.kwargs).get("dtype") not in (torch.complex64, torch.complex128): + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.scalar_tensor.default, 0.0) + out.kwargs = {"dtype": torch.float32} # type: ignore[assignment] + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + # ------------------------------------------------------------------ + # Shape-manipulation handlers + # + # All of these work on the same principle: in the [..., 2] real layout + # the trailing dimension stores real/imag. Dimension indices that refer + # to the *last* complex dimension (dim=-1) must be shifted by -1 to + # avoid touching or conflating with the trailing 2 dim. + # + # Rule: new_dim = dim - 1 if dim < 0 else dim + # ------------------------------------------------------------------ + + @_complex_unpacker( + torch.ops.aten.reshape.default, + torch.ops.aten.view.default, + torch.ops.aten._unsafe_view.default, + ) + def _rewrite_reshape_view(self, node: Node) -> bool: + # Append 2 to the target shape so the trailing real/imag dim is + # preserved after the reshape. E.g. complex [a,b] reshaped to [c] + # becomes float [a,b,2] reshaped to [c,2]. 
+ inp = node.args[0] + new_shape = list(node.args[1]) + [2] + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(node.target, inp, new_shape) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.flatten.using_ints) + def _rewrite_flatten(self, node: Node) -> bool: + inp = node.args[0] + start_dim = node.args[1] if len(node.args) > 1 else 0 + end_dim = node.args[2] if len(node.args) > 2 else -1 + # Shift negative dims by -1 so end_dim=-1 (last complex dim) maps to + # the second-to-last dim in the real layout, keeping the trailing 2 intact. + new_start = start_dim - 1 if start_dim < 0 else start_dim + new_end = end_dim - 1 if end_dim < 0 else end_dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.flatten.using_ints, inp, new_start, new_end) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.unsqueeze.default) + def _rewrite_unsqueeze(self, node: Node) -> bool: + inp = node.args[0] + dim = node.args[1] + # Negative dims: shift by -1 so dim=-1 inserts *before* the trailing + # real/imag dim rather than *after* it. 
+ new_dim = dim - 1 if dim < 0 else dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.unsqueeze.default, inp, new_dim) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.squeeze.dim, torch.ops.aten.squeeze.dims) + def _rewrite_squeeze_dim(self, node: Node) -> bool: + inp = node.args[0] + # squeeze.dim(inp, int) vs squeeze.dims(inp, List[int]) + is_multi = node.target == torch.ops.aten.squeeze.dims + raw_dim = node.args[1] + dims_list = list(raw_dim) if is_multi else [raw_dim] + # Shift negative dims so that complex dim=-1 (last complex dim) maps to + # real-layout dim=-2 (second-to-last), keeping the trailing real/imag dim. + # A squeeze on a valid complex dim can never accidentally hit the trailing + # 2 dim: for rank-n complex, valid dims are [-n, n-1]; after the shift + # they land in [-n-1, n-1], all safely before the trailing dim at index n. + new_dims = [d - 1 if d < 0 else d for d in dims_list] + with SubgraphBuilder(self.gm.graph, node) as b: + if is_multi: + out = b(torch.ops.aten.squeeze.dims, inp, new_dims) + else: + out = b(torch.ops.aten.squeeze.dim, inp, new_dims[0]) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.cat.default) + def _rewrite_cat(self, node: Node) -> bool: + tensors = node.args[0] + dim = node.args[1] if len(node.args) > 1 else 0 + # Negative dims: shift by -1 to avoid concatenating into the trailing + # real/imag dim. E.g. cat(tensors, dim=-1) on complex tensors should + # concat along the last *complex* dimension, not the trailing 2. 
+ new_dim = dim - 1 if dim < 0 else dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.cat.default, list(tensors), new_dim) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.stack.default) + def _rewrite_stack(self, node: Node) -> bool: + tensors = node.args[0] + dim = node.args[1] if len(node.args) > 1 else 0 + # Negative dims: shift by -1 so a new dim inserted at position -1 lands + # before the trailing real/imag dim, not after it. + new_dim = dim - 1 if dim < 0 else dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.stack.default, list(tensors), new_dim) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.t.default) + def _rewrite_t(self, node: Node) -> bool: + # t() is the 2-D transpose shorthand. After unpacking, the tensor is + # 3-D ([..., 2]) so t() would raise. Replace with transpose(0, 1). + inp = node.args[0] + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.transpose.int, inp, 0, 1) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.transpose.int) + def _rewrite_transpose(self, node: Node) -> bool: + inp = node.args[0] + dim0, dim1 = node.args[1], node.args[2] + # Get the original complex rank from node metadata (not yet re-propagated). + node_val = node.meta.get("val", None) + if node_val is None or not hasattr(node_val, "shape"): + logger.warning( + "transpose on complex tensor '%s': no metadata, skipping rewrite. 
" + "This may produce incorrect results or fail TRT compilation.", + node.name, + ) + return False + n = len(node_val.shape) # original complex rank + # Normalize dims to absolute indices in [0, n-1]: same indices are valid + # in the real layout too (both are before the trailing 2). + abs0 = dim0 % n if dim0 < 0 else dim0 + abs1 = dim1 % n if dim1 < 0 else dim1 + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.transpose.int, inp, abs0, abs1) + node.replace_all_uses_with(out) + out.meta["is_complex_layout"] = True + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.select.int) + def _rewrite_select(self, node: Node) -> bool: + # select.int on a complex tensor selects along a batch/sequence dim. + # In the real layout the trailing dim encodes real/imag, so negative + # dim indices must be shifted by -1 to avoid selecting from that dim. + inp = node.args[0] + dim = node.args[1] + idx = node.args[2] + if dim >= 0: + return False # non-negative dims are unchanged in real layout + new_dim = dim - 1 + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.select.int, inp, new_dim, idx) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.slice.Tensor) + def _rewrite_slice(self, node: Node) -> bool: + inp = node.args[0] + dim = node.args[1] if len(node.args) > 1 else 0 + start = node.args[2] if len(node.args) > 2 else None + end = node.args[3] if len(node.args) > 3 else None + step = node.args[4] if len(node.args) > 4 else 1 + if dim >= 0: + return False # non-negative dims are safe in real layout + new_dim = dim - 1 + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.slice.Tensor, inp, new_dim, start, end, step) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker( + 
torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes.default, + ) + def _rewrite_split(self, node: Node) -> bool: + inp = node.args[0] + size_or_sizes = node.args[1] + dim = node.args[2] if len(node.args) > 2 else 0 + new_dim = dim - 1 if dim < 0 else dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(node.target, inp, size_or_sizes, new_dim) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.chunk.default) + def _rewrite_chunk(self, node: Node) -> bool: + inp = node.args[0] + chunks = node.args[1] + dim = node.args[2] if len(node.args) > 2 else 0 + new_dim = dim - 1 if dim < 0 else dim + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.chunk.default, inp, chunks, new_dim) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.expand.default) + def _rewrite_expand(self, node: Node) -> bool: + # expand(input, size) — size must include the trailing real/imag dim. + # Append 2 to the size list. Negative sizes (-1 = keep dim) are left as-is; + # only the trailing 2 is appended for the real/imag encoding dim. + inp = node.args[0] + size = list(node.args[1]) + new_size = size + [2] + with SubgraphBuilder(self.gm.graph, node) as b: + out = b(torch.ops.aten.expand.default, inp, new_size) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + # ------------------------------------------------------------------ + # Matrix-multiplication handlers + # + # Complex mm: (A+iB)(C+iD) = (AC-BD) + i(AD+BC) — 4 real matmuls. 
+ # ------------------------------------------------------------------ + + def _inline_complex_mm_op( + self, + b: "SubgraphBuilder", + matmul_op: object, + x: Node, + y: Node, + x_was_complex: bool, + y_was_complex: bool, + ) -> "Tuple[Node, Node]": + """Emit real/imag components of a complex matmul using *matmul_op*.""" + if x_was_complex and y_was_complex: + x_re = b(torch.ops.aten.select.int, x, -1, 0) + x_im = b(torch.ops.aten.select.int, x, -1, 1) + y_re = b(torch.ops.aten.select.int, y, -1, 0) + y_im = b(torch.ops.aten.select.int, y, -1, 1) + ac = b(matmul_op, x_re, y_re) + bd = b(matmul_op, x_im, y_im) + ad = b(matmul_op, x_re, y_im) + bc = b(matmul_op, x_im, y_re) + out_re = b(torch.ops.aten.sub.Tensor, ac, bd) + out_im = b(torch.ops.aten.add.Tensor, ad, bc) + elif x_was_complex: + # x is complex, y is real: (A+iB)*C = AC + iBC + x_re = b(torch.ops.aten.select.int, x, -1, 0) + x_im = b(torch.ops.aten.select.int, x, -1, 1) + out_re = b(matmul_op, x_re, y) + out_im = b(matmul_op, x_im, y) + else: + # x is real, y is complex: A*(C+iD) = AC + iAD + y_re = b(torch.ops.aten.select.int, y, -1, 0) + y_im = b(torch.ops.aten.select.int, y, -1, 1) + out_re = b(matmul_op, x, y_re) + out_im = b(matmul_op, x, y_im) + return out_re, out_im + + def _is_complex_layout_node(self, n: Node) -> bool: + """True if *n* is in real [..., 2] layout representing a complex tensor. + + All complex nodes are annotated with node.meta["is_complex_layout"] = True + during the detection phase (or by each rewrite handler as it emits new + nodes), so this is a direct metadata lookup — no shape heuristics needed. 
+ """ + return n.meta.get("is_complex_layout", False) + + @_complex_unpacker(torch.ops.aten.mm.default) + def _rewrite_mm(self, node: Node) -> bool: + if not node.meta.get("is_complex_layout", False): + return False + x, y = node.args[0], node.args[1] + x_c = self._is_complex_layout_node(x) + y_c = self._is_complex_layout_node(y) + if not x_c and not y_c: + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out_re, out_im = self._inline_complex_mm_op( + b, torch.ops.aten.mm.default, x, y, x_c, y_c + ) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.bmm.default) + def _rewrite_bmm(self, node: Node) -> bool: + if not node.meta.get("is_complex_layout", False): + return False + x, y = node.args[0], node.args[1] + x_c = self._is_complex_layout_node(x) + y_c = self._is_complex_layout_node(y) + if not x_c and not y_c: + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out_re, out_im = self._inline_complex_mm_op( + b, torch.ops.aten.bmm.default, x, y, x_c, y_c + ) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.matmul.default) + def _rewrite_matmul(self, node: Node) -> bool: + if not node.meta.get("is_complex_layout", False): + return False + x, y = node.args[0], node.args[1] + x_c = self._is_complex_layout_node(x) + y_c = self._is_complex_layout_node(y) + if not x_c and not y_c: + return False + with SubgraphBuilder(self.gm.graph, node) as b: + out_re, out_im = self._inline_complex_mm_op( + b, torch.ops.aten.matmul.default, x, y, x_c, y_c + ) + out = self._inline_cat_re_im(b, out_re, out_im) + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + + @_complex_unpacker(torch.ops.aten.where.self) + def _rewrite_where(self, node: Node) -> bool: + # where.self: unsqueeze mask and 
optionally expand true-branch for complex layout. + if len(node.args) != 3: + return False + node_val = node.meta.get("val", None) + if node_val is None or not hasattr(node_val, "dtype"): + return False + if node_val.dtype not in (torch.complex64, torch.complex128): + return False + mask_node, true_node, other_node = node.args + target_shape = list(node_val.shape) + [2] + with SubgraphBuilder(self.gm.graph, node) as b: + mask_unsq = b(torch.ops.aten.unsqueeze.default, mask_node, -1) + true_arg = true_node + if isinstance(true_node, torch.fx.Node): + true_val = true_node.meta.get("val", None) + if ( + true_val is not None + and hasattr(true_val, "shape") + and list(true_val.shape) == [2] + ): + true_arg = b(torch.ops.aten.expand.default, true_node, target_shape) + out = b(torch.ops.aten.where.self, mask_unsq, true_arg, other_node) + out.meta["is_complex_layout"] = True + node.replace_all_uses_with(out) + self.gm.graph.erase_node(node) + return True + def rewrite_subgraph_nodes(self, subgraphs: List[ComplexSubGraphInfo]) -> None: modified = False + # Detect the existing FakeTensorMode from the graph's placeholders + # *before* any rewrites. We pass this to replace_input_node so that + # new placeholder fake tensors are created under the same mode as the + # rest of the graph. Using a fresh FakeTensorMode would cause "mode + # mismatch" assertions under torch.compile (where a mode is already + # active) and would lose SymInt information for torch.export graphs. + detected_fake_mode = torch._export.utils._detect_fake_mode_from_gm(self.gm) + + # Annotate all nodes that have complex dtype BEFORE any rewriting. + # We stamp node.meta["is_complex_layout"] = True on every complex-dtype node + # so that later passes can reliably distinguish real [..., 2] layout tensors + # (created by this rewriter) from coincidentally-shaped real tensors. 
+ # This is stable across rewrites: after replace_input_node changes dtype to + # float32, is_complex_dtype() would return False, but the metadata flag persists. + detector = ComplexOpDetector() + for subgraph in subgraphs: + for node in subgraph.input_nodes: + if detector.is_complex_dtype(node): + node.meta["is_complex_layout"] = True + for node in subgraph.subgraph_nodes: + if detector.is_complex_dtype(node): + node.meta["is_complex_layout"] = True + + # _DISPATCH maps op -> unbound method; bind self here once per call. + dispatch = {op: method.__get__(self) for op, method in self._DISPATCH.items()} + + logger.debug( + "complex_graph_rewrite begin subgraphs=%d nodes=%s", + len(subgraphs), + [n.name for s in subgraphs for n in s.subgraph_nodes], + ) + for subgraph in subgraphs: for input_node in subgraph.input_nodes: - logger.debug(f"Input node rewrite: {input_node.name}") if input_node.op not in ("call_function"): - self.replace_input_node(input_node) + # Only rewrite inputs that are themselves complex — real inputs + # to complex-output ops (e.g. r, theta for polar) must NOT be + # renamed to *_unpacked_complex. + if not detector.is_complex_dtype(input_node): + continue + self.replace_input_node(input_node, fake_mode=detected_fake_mode) for node in subgraph.subgraph_nodes: - logger.debug(f"Subgraph Node rewrite: {node.name}") - if node.target == torch.ops.aten.view_as_complex.default: - node.replace_all_uses_with(node.args[0]) - self.gm.graph.erase_node(node) - elif node.target == torch.ops.aten.mul.Tensor: - # this is complex mul where inputs = a+ib and output = c+id. 
- # complex mul returns (ac - bd) + (ad + bc)i - # which is then view_as_real as (ac-bd), (ad+bc) stacked along the last dimension with last dimension size 2 - x_placeholder_or_func = ( - True if node.args[0].op != "get_attr" else False - ) - y_placeholder_or_func = ( - True if node.args[1].op != "get_attr" else False - ) - - replaced_nodes = [] - original_mul, replacement = complex_mul_replacement( - x_placeholder_or_func, y_placeholder_or_func + # Skip nodes that were already erased by a previous pattern replacement + if node.graph is not self.gm.graph: + continue + handler = dispatch.get(node.target) + if handler is not None: + if handler(node): + modified = True + elif node.target in _ELEMENTWISE_SAFE: + logger.debug(" pass-through %s (elementwise-safe)", node.name) + else: + logger.warning( + "Complex op '%s' has no explicit rewrite rule. " + "Wrapping with view_as_complex/view_as_real so the op " + "receives genuine complex tensors and TRT graph-breaks " + "around it into a PyTorch fallback block.", + node.target, ) - - def match_complex_mul( # type: ignore[no-untyped-def] - match: torch.fx.subgraph_rewriter.Match, - original_graph, - pattern_graph, - ) -> bool: - for original_node in match.nodes_map.values(): - if original_node.name == node.name: - return True - return False - - nodes = torch.fx.subgraph_rewriter.replace_pattern_with_filters( - self.gm, - original_mul, - replacement, - match_filters=[match_complex_mul], - ignore_literals=True, + # Generic fallback: for each arg that is a real-layout + # complex node, insert view_as_complex before the node so + # the op sees genuine complex-dtype tensors (correct + # semantics); then, if the node itself originally produced + # a complex-layout output, wrap it with view_as_real and + # redirect all users back onto the real [..., 2] path. 
+ # TRT has no complex-dtype support so it will refuse to + # compile the view_as_complex/op/view_as_real cluster, + # causing the partitioner to create a PyTorch fallback + # block around it — exactly the graph break we want. + new_args = list(node.args) + any_complexified = False + for i, arg in enumerate(node.args): + if not isinstance(arg, torch.fx.Node): + continue + if not arg.meta.get("is_complex_layout", False): + continue + # Skip when val is a list/tuple (e.g. a residual split + # output that wasn't caught by the getitem pass-through). + # Allow None (newly created node without metadata yet). + arg_val = arg.meta.get("val") + if isinstance(arg_val, (list, tuple)): + continue + with self.gm.graph.inserting_before(node): + vc = self.gm.graph.call_function( + torch.ops.aten.view_as_complex.default, + (arg,), + ) + # view_as_complex produces a genuine complex node — + # do NOT set is_complex_layout; it is not a + # real-layout stand-in. + new_args[i] = vc + any_complexified = True + if any_complexified: + node.args = tuple(new_args) + if any_complexified and node.meta.get("is_complex_layout", False): + with self.gm.graph.inserting_after(node): + vr = self.gm.graph.call_function( + torch.ops.aten.view_as_real.default, + (node,), + ) + vr.meta["is_complex_layout"] = True + node.replace_all_uses_with( + vr, + delete_user_cb=lambda user: user is not vr, + ) + modified = True + if modified: + # After rewriting complex ops, any view_as_real node that now receives a + # real tensor must be erased. The subgraph_rewriter replaces the original + # complex mul with a cat of real/imag parts; view_as_real on that result + # is invalid. We detect this by checking whether the input to view_as_real + # is no longer complex-typed (its meta val dtype is real, or has no val yet + # but its target is the real-arithmetic cat output). 
+ for node in list(self.gm.graph.nodes): + if node.target != torch.ops.aten.view_as_real.default: + continue + inp = node.args[0] + if not isinstance(inp, torch.fx.Node): + continue + inp_val = inp.meta.get("val", None) + # If meta is available and dtype is real, erase view_as_real + is_real_input = ( + inp_val is not None + and hasattr(inp_val, "dtype") + and inp_val.dtype not in {torch.complex64, torch.complex128} + ) + # If meta not yet propagated, use the target as a heuristic: + # the real-arithmetic replacement ends with aten.cat.default + if inp_val is None: + is_real_input = inp.target == torch.ops.aten.cat.default + if is_real_input: + inp_desc = ( + f"{inp.name}[{tuple(inp_val.shape)},{inp_val.dtype}]" + if inp_val is not None and hasattr(inp_val, "shape") + else inp.name ) - replaced_nodes += nodes - modified = True - elif node.target == torch.ops.aten.view_as_real.default: - node.replace_all_uses_with(node.args[0]) - self.gm.graph.erase_node(node) - else: - logger.debug(f"Unsupported node target: {node.target}") logger.debug( - "This complex subgraphnode type does not need to replaced" + " erase view_as_real %s (input %s is already real)", + node.name, + inp_desc, ) - - if modified: - self.propagate_metadata() + node.replace_all_uses_with(inp) + self.gm.graph.erase_node(node) + logger.debug("complex_graph_rewrite propagating metadata") + self.propagate_metadata(detected_fake_mode) self.gm.graph.lint() self.gm.recompile() + logger.debug("complex_graph_rewrite done") - def propagate_metadata(self) -> None: - fake_inputs = [] - from torch._subclasses.fake_tensor import FakeTensorMode + def propagate_metadata( + self, existing_fake_mode: Optional[FakeTensorMode] = None + ) -> None: + """Re-propagate FakeTensor metadata after graph rewrites via FakeTensorProp. + + Uses *existing_fake_mode* (detected from the graph's placeholder fake + tensors) when available. 
This ensures the propagation mode matches the + mode under which the graph was originally traced — critical for both + torch.compile (where a FakeTensorMode is already active on the thread) + and torch.export (where we must preserve the ShapeEnv / SymInt ranges). + + Falls back to a fresh FakeTensorMode only for plain FX graphs that have + no fake tensor metadata at all. + """ from torch.fx.passes.fake_tensor_prop import FakeTensorProp + fake_inputs = [] for node in self.gm.graph.nodes: if node.op == "placeholder": if "val" in node.meta: - with FakeTensorMode(allow_non_fake_inputs=True): - fake_val = node.meta["val"] - fake_inputs.append( - fake_val.to("cuda") - if fake_val.device.type == "cuda" - else fake_val - ) + fake_val = node.meta["val"] + fake_inputs.append( + fake_val.to("cuda") + if fake_val.device.type == "cuda" + else fake_val + ) else: fake_tensor = torch.empty( [s if s != 0 else 1 for s in node.meta["tensor_meta"].shape], @@ -263,9 +1817,13 @@ def propagate_metadata(self) -> None: device=node.meta["tensor_meta"].device, ) fake_inputs.append(fake_tensor) - FakeTensorProp( - self.gm, mode=FakeTensorMode(allow_non_fake_inputs=True) - ).propagate(*fake_inputs) + + prop_mode = ( + existing_fake_mode + if existing_fake_mode is not None + else FakeTensorMode(allow_non_fake_inputs=True) + ) + FakeTensorProp(self.gm, mode=prop_mode).propagate(*fake_inputs) def extract_real_imag(input, placeholder_or_func: bool = True): # type: ignore @@ -337,6 +1895,108 @@ def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return (original_mul, replacement) +def complex_div_replacement( + x_placeholder_or_func: bool = True, y_placeholder_or_func: bool = True +) -> Tuple[ + Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + Callable[[torch.Tensor, torch.Tensor], torch.Tensor], +]: + """Constructs the original and replacement functions for complex division. 
+ + (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c²+d²) + """ + + def original_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.ops.aten.div.Tensor(x, y) + + def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + x_real, x_imag = extract_real_imag(x, x_placeholder_or_func) + y_real, y_imag = extract_real_imag(y, y_placeholder_or_func) + + denom = torch.ops.aten.add.Tensor( + torch.ops.aten.mul.Tensor(y_real, y_real), + torch.ops.aten.mul.Tensor(y_imag, y_imag), + ) + real = torch.ops.aten.div.Tensor( + torch.ops.aten.add.Tensor( + torch.ops.aten.mul.Tensor(x_real, y_real), + torch.ops.aten.mul.Tensor(x_imag, y_imag), + ), + denom, + ) + imag = torch.ops.aten.div.Tensor( + torch.ops.aten.sub.Tensor( + torch.ops.aten.mul.Tensor(x_imag, y_real), + torch.ops.aten.mul.Tensor(x_real, y_imag), + ), + denom, + ) + + return torch.ops.aten.cat.default( + [ + torch.ops.aten.unsqueeze.default(real, -1), + torch.ops.aten.unsqueeze.default(imag, -1), + ], + -1, + ) + + return (original_div, replacement) + + +def _get_complex_output_indices(gm: GraphModule) -> List[int]: + """Return indices of output nodes that have complex dtype, before rewriting.""" + complex_dtypes = {torch.complex64, torch.complex128} + output_node = next((n for n in reversed(gm.graph.nodes) if n.op == "output"), None) + if output_node is None: + return [] + # output args is a tuple of the return values + outputs = output_node.args[0] + if not isinstance(outputs, (list, tuple)): + outputs = (outputs,) + indices = [] + for i, out in enumerate(outputs): + if isinstance(out, torch.fx.Node) and "val" in out.meta: + val = out.meta["val"] + if hasattr(val, "dtype") and val.dtype in complex_dtypes: + indices.append(i) + return indices + + +def _get_complex_input_names(gm: GraphModule) -> List[str]: + """Return the original names of placeholder nodes that have complex dtype, before rewriting. 
+ + complex_graph_detection renames complex placeholders from 'name' to 'name_unpacked_complex' + and changes their dtype to float. This captures the original names so the post-partition + pass can insert view_as_real at the graph input boundary. + """ + complex_dtypes = {torch.complex64, torch.complex128} + names = [] + for node in gm.graph.nodes: + if node.op != "placeholder": + continue + val = node.meta.get("val", None) + if val is not None and hasattr(val, "dtype") and val.dtype in complex_dtypes: + names.append(node.name) + return names + + +def _get_complex_input_dtypes(gm: GraphModule) -> dict: + """Return a mapping of placeholder name -> complex dtype for complex-dtype inputs. + + Used by the post-partition boundary pass to know which inputs were complex128 + so it can insert float32 casts when truncate_double=True. + """ + complex_dtypes = {torch.complex64, torch.complex128} + dtypes = {} + for node in gm.graph.nodes: + if node.op != "placeholder": + continue + val = node.meta.get("val", None) + if val is not None and hasattr(val, "dtype") and val.dtype in complex_dtypes: + dtypes[node.name] = val.dtype + return dtypes + + # This lowering pass is used to detect and rewrite complex subgraphs in the graph def complex_graph_detection( gm: GraphModule, settings: CompilationSettings @@ -350,10 +2010,20 @@ def complex_graph_detection( Returns: The modified GraphModule with complex subgraphs rewritten """ + # Capture I/O signature before rewriting — used post-partition to restore + # the complex tensor interface at the graph boundaries. 
+ gm.meta["complex_output_indices"] = _get_complex_output_indices(gm) + gm.meta["complex_input_names"] = _get_complex_input_names(gm) + gm.meta["complex_input_dtypes"] = _get_complex_input_dtypes(gm) + if gm.meta["complex_output_indices"]: + logger.debug( + f"Complex output indices captured: {gm.meta['complex_output_indices']}" + ) + if gm.meta["complex_input_names"]: + logger.debug(f"Complex input names captured: {gm.meta['complex_input_names']}") + complex_op_detector = ComplexOpDetector() - complex_subgraphs = complex_op_detector.find_complex_op_subgraphs( - gm, anchor_target=torch.ops.aten.view_as_real.default - ) + complex_subgraphs = complex_op_detector.find_all_complex_subgraphs(gm) for subgraph in complex_subgraphs: logger.debug(f"Complex subgraph info: {subgraph}") complex_graph_rewriter = ComplexGraphRewriter(gm, settings.truncate_double) diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 72d0be42c7..04c8c50dbe 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -22,6 +22,9 @@ from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( ConverterRegistry, ) +from torch_tensorrt.dynamo.partitioning._global_partitioner import ( + TorchTensorRTOperatorSupport, +) logger = logging.getLogger(__name__) @@ -42,6 +45,14 @@ def is_node_supported( ) -> bool: node_name = ConverterRegistry.qualified_name_or_str(node.target) + if TorchTensorRTOperatorSupport._has_complex_dtype(node): + # Complex-dtype tensors are not supported by TensorRT; force PyTorch fallback + if not node.is_impure(): + self.unsupported_operators[node_name] = ( + self.unsupported_operators.get(node_name, 0) + 1 + ) + return False + if ( (node in CONVERTERS or node.op == "get_attr") and node_name not in self.torch_executed_ops diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py 
b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index 707497b227..8d02076607 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -1,5 +1,5 @@ import logging -from typing import Collection, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Collection, Dict, List, Mapping, Optional, Sequence, Tuple, Set import torch from torch.fx.graph_module import GraphModule @@ -144,11 +144,41 @@ def __init__( self.unsupported_operators: Dict[str, int] = {} self.torch_executed_ops: Collection[Target] = torch_executed_ops + @staticmethod + def _has_complex_dtype(node: torch.fx.Node) -> bool: + """Return True if the node output or any of its tensor inputs is complex-dtype. + + TensorRT has no native complex-type support. Any node that produces or + consumes a complex tensor must run in the PyTorch fallback so the graph + breaks naturally around it. + """ + _COMPLEX = {torch.complex64, torch.complex128} + + def _dtype(n: torch.fx.Node) -> Optional[torch.dtype]: + val = n.meta.get("val") + return getattr(val, "dtype", None) if val is not None else None + + if _dtype(node) in _COMPLEX: + return True + for arg in node.all_input_nodes: + if _dtype(arg) in _COMPLEX: + return True + return False + def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: node_name = ConverterRegistry.qualified_name_or_str(node.target) + if self._has_complex_dtype(node): + # Complex-dtype tensors are not supported by TensorRT; force PyTorch fallback + # so the graph breaks around the complex cluster inserted by complex_graph_detection. 
+ if not node.is_impure(): + self.unsupported_operators[node_name] = ( + self.unsupported_operators.get(node_name, 0) + 1 + ) + return False + if ( (node in CONVERTERS or node.op == "get_attr") and node_name not in self.torch_executed_ops diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py index 3a250085f1..98b478db77 100644 --- a/py/torch_tensorrt/dynamo/partitioning/common.py +++ b/py/torch_tensorrt/dynamo/partitioning/common.py @@ -6,7 +6,11 @@ from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Input import Input -from torch_tensorrt.dynamo.utils import contains_sym_int, extract_var_range_info +from torch_tensorrt.dynamo.utils import ( + COMPLEX_TO_REAL_DTYPE, + contains_sym_int, + extract_var_range_info, +) logger = logging.getLogger(__name__) @@ -85,6 +89,17 @@ def get_input( """ Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs """ + if dtype in COMPLEX_TO_REAL_DTYPE: + real_dtype = COMPLEX_TO_REAL_DTYPE[dtype] + real_shape = torch.Size(list(input_shape) + [2]) + logger.info( + f"Input '{name}' has complex dtype {dtype}. TensorRT does not support complex " + f"tensors natively; it will be implicitly unpacked to a real tensor of shape " + f"{real_shape} and dtype {real_dtype} (last dim = [real, imag])." 
+ ) + dtype = real_dtype + input_shape = real_shape + if contains_sym_int(input_shape): return construct_dynamic_input( input_shape, diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 0de257f7c6..5c797a3940 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -27,6 +27,8 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.utils._sympy.numbers import int_oo + +from packaging import version from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -37,8 +39,6 @@ from torch_tensorrt.dynamo._engine_cache import BaseEngineCache from torch_tensorrt.dynamo._settings import CompilationSettings -from packaging import version - from .types import TRTDataType logger = logging.getLogger(__name__) @@ -99,6 +99,12 @@ class Frameworks(Enum): } +COMPLEX_TO_REAL_DTYPE: Dict[torch.dtype, torch.dtype] = { + torch.complex64: torch.float32, + torch.complex128: torch.float64, +} + + def unified_dtype_converter( dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks ) -> Union[np.dtype, torch.dtype, TRTDataType]: @@ -313,6 +319,20 @@ def prepare_inputs( return inputs elif isinstance(inputs, (torch.Tensor, int, float, bool)): + if isinstance(inputs, torch.Tensor) and inputs.is_complex(): + # Complex tensors are lowered to real tensors with an extra last + # dimension of size 2 (real, imag) by complex_graph_detection. + # Build an Input whose shape/dtype reflects the lowered representation + # while keeping the original complex tensor for tracing (torch.export + # needs the complex tensor to trace the model correctly). 
+ real_view = torch.view_as_real(inputs.contiguous()) + inp = Input.from_tensor( + real_view, disable_memory_format_check=disable_memory_format_check + ) + # Restore the original complex tensor so dynamo_trace can export + # the model with the correct input dtype. + inp.torch_tensor = inputs + return inp return Input.from_tensor( torch.tensor(inputs), disable_memory_format_check=disable_memory_format_check, @@ -859,10 +879,13 @@ def get_output_dtypes(output: Any, truncate_double: bool = False) -> List[dtype] # Placeholder output (e.g. unused slot in flash attention return tuple) pass elif isinstance(output_meta, (FakeTensor, torch.Tensor)): - if truncate_double and output_meta.dtype == torch.float64: + out_dtype = output_meta.dtype + if out_dtype in COMPLEX_TO_REAL_DTYPE: + out_dtype = COMPLEX_TO_REAL_DTYPE[out_dtype] + if truncate_double and out_dtype == torch.float64: output_dtypes.append(dtype.float32) else: - output_dtypes.append(dtype._from(output_meta.dtype)) + output_dtypes.append(dtype._from(out_dtype)) elif isinstance(output_meta, torch.SymInt): output_dtypes.append(dtype.int64) elif "tensor_meta" in output.meta: diff --git a/pyproject.toml b/pyproject.toml index 47d18ed8fe..911e22d3e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ lint = [ dev = [ {include-group = "lint"}, + {include-group = "test"}, "pre-commit>=2.20.0", "typos", "mypy", @@ -77,6 +78,7 @@ debug = [ test = [ "pytest", "pytest-xdist", + "pytest-forked>=1.6.0", "parameterized>=0.2.0", "expecttest==0.1.6", ] @@ -114,6 +116,7 @@ include-package-data = false [tool.pytest.ini_options] testpaths = ["tests/py"] +addopts = "-n auto --dist=loadfile" norecursedirs = [ "bazel-*", ".venv", diff --git a/tests/py/dynamo/hlo/__init__.py b/tests/py/dynamo/hlo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/py/dynamo/hlo/test_complex_graph_break.py b/tests/py/dynamo/hlo/test_complex_graph_break.py new file mode 100644 index 0000000000..e48d749cef --- 
/dev/null +++ b/tests/py/dynamo/hlo/test_complex_graph_break.py @@ -0,0 +1,252 @@ +"""Tests for complex tensor graph-break behavior in torch-tensorrt. + +These tests verify that when a model contains complex tensor operations mixed with +ops that have no handler in the complex-lowering rewriter, the compiler: + + 1. Wraps the unsupported op with ``view_as_complex`` / ``view_as_real`` so it + receives genuine complex-dtype inputs and returns a real-layout output. + 2. TRT, which has no complex-dtype support, naturally graph-breaks around the + wrapped cluster and runs it as a PyTorch fallback block. + 3. The lowerable complex ops on both sides compile to TRT via + ``complex_graph_detection``. + 4. The overall model produces numerically correct results end-to-end. + +Background +---------- +``complex_graph_detection`` rewrites complex-dtype ATen ops to equivalent +real-arithmetic ops before TRT compilation. When an op is *not* registered +with ``@_complex_unpacker`` and is not in ``_ELEMENTWISE_SAFE`` the rewriter +inserts ``view_as_complex`` before each complex-layout input and +``view_as_real`` after the output, preserving correct semantics and letting +TRT's lack of complex support create the graph break automatically. + +``cumsum`` is used as the representative unsupported op: it has well-defined +PyTorch semantics on complex tensors but has no handler in the rewriter. 
+""" + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from torch_tensorrt.dynamo.lowering.passes.complex_graph_rewrite import ( + complex_graph_detection, +) +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +try: + from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule + + _PYTHON_RUNTIME_AVAILABLE = True +except ImportError: # pragma: no cover + _PYTHON_RUNTIME_AVAILABLE = False + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _make_freqs(seq: int, dim: int, theta: float = 10000.0) -> torch.Tensor: + """Complex unit-magnitude frequency tensor on CUDA, shape ``(seq, dim//2)``.""" + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(seq, dtype=torch.float32) + freqs = torch.outer(t, freqs) + return torch.polar(torch.ones_like(freqs), freqs).cuda() + + +def _cossim_real(py_out: torch.Tensor, trt_out: torch.Tensor, tag: str) -> None: + """Assert cosine similarity > COSINE_THRESHOLD on a real-valued output.""" + assert not trt_out.is_complex(), f"{tag}: expected real output, got {trt_out.dtype}" + s = cosine_similarity(py_out.contiguous(), trt_out.contiguous()) + assert s > COSINE_THRESHOLD, f"{tag}: cosine sim {s:.4f} < {COSINE_THRESHOLD}" + + +def _count_trt_modules(mod: torch.nn.Module) -> int: + """Return the number of ``PythonTorchTensorRTModule`` submodules (-1 if unavailable).""" + if not _PYTHON_RUNTIME_AVAILABLE: + return -1 + return sum( + 1 for _, m in mod.named_modules() if isinstance(m, PythonTorchTensorRTModule) + ) + + +def _export_and_lower(model: nn.Module, inputs: tuple) -> torch.fx.GraphModule: + """Export model and apply complex_graph_detection lowering pass.""" + with torch.no_grad(): + ep = torch.export.export(model.eval(), inputs) 
+ gm = ep.module() + complex_graph_detection(gm, CompilationSettings()) + return gm + + +# =========================================================================== +# Test 1 — unsupported op gets view_as_complex/view_as_real wrapper +# =========================================================================== + + +class ComplexMulThenCumsum(nn.Module): + """Complex mul (lowerable) followed by cumsum (no rewriter handler). + + After ``complex_graph_detection`` the rewriter cannot handle ``cumsum``. + It inserts ``view_as_complex`` before cumsum's input and ``view_as_real`` + after its output so the op runs in PyTorch with correct complex semantics + while TRT compiles the surrounding real-arithmetic blocks. + """ + + def forward(self, z: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + rotated = z * freqs # complex mul — lowered to real arithmetic by rewriter + accumulated = torch.cumsum(rotated, dim=0) # no handler → PyTorch fallback + return torch.view_as_real(accumulated).flatten(-2) + + +@pytest.mark.unit +def test_unsupported_op_gets_complexify_wrap() -> None: + """The rewriter wraps cumsum with view_as_complex/view_as_real. + + Structural check (no TRT required): + - After lowering, the graph contains ``view_as_complex`` immediately + before ``cumsum`` and ``view_as_real`` immediately after. + - The ``view_as_complex`` input is the real-layout output of the + rewritten complex mul — confirming it is a float32 ``(..., 2)`` node. + - The ``view_as_real`` output feeds the downstream flatten. + - The PyTorch cumsum receives a complex-dtype tensor (correct semantics). 
+ """ + model = ComplexMulThenCumsum().eval().cuda() + z = _make_freqs(8, 64) + freqs = _make_freqs(8, 64) + + gm = _export_and_lower(model, (z, freqs)) + + nodes_by_target: dict = {} + for n in gm.graph.nodes: + nodes_by_target.setdefault(n.target, []).append(n) + + # view_as_complex must be present (inserted by the fallback wrapper) + assert ( + torch.ops.aten.view_as_complex.default in nodes_by_target + ), "Expected view_as_complex to be inserted before cumsum, but it was not found" + + # cumsum must still be present (it was NOT removed) + assert ( + torch.ops.aten.cumsum.default in nodes_by_target + ), "cumsum should remain in the graph (runs as PyTorch fallback)" + + # The view_as_complex output feeds directly into cumsum + vc_node = nodes_by_target[torch.ops.aten.view_as_complex.default][0] + cumsum_node = nodes_by_target[torch.ops.aten.cumsum.default][0] + assert ( + cumsum_node.args[0] is vc_node + ), f"cumsum's first arg should be the view_as_complex node, got {cumsum_node.args[0]}" + + # The view_as_complex input is a real-layout (is_complex_layout) node + vc_input = vc_node.args[0] + assert isinstance(vc_input, torch.fx.Node), "view_as_complex input must be a Node" + assert vc_input.meta.get( + "is_complex_layout", False + ), "view_as_complex input should be a real-layout complex node (is_complex_layout=True)" + + # view_as_real must follow cumsum + assert ( + torch.ops.aten.view_as_real.default in nodes_by_target + ), "Expected view_as_real to be inserted after cumsum, but it was not found" + vr_node = nodes_by_target[torch.ops.aten.view_as_real.default][0] + assert ( + vr_node.args[0] is cumsum_node + ), f"view_as_real's arg should be the cumsum node, got {vr_node.args[0]}" + + # After metadata propagation, cumsum receives a complex-dtype tensor + vc_val = vc_node.meta.get("val") + if vc_val is not None: + assert vc_val.dtype in ( + torch.complex64, + torch.complex128, + ), f"view_as_complex output should be complex, got {vc_val.dtype}" + + +# 
=========================================================================== +# Test 2 — lowerable ops TRT, unsupported op PyTorch (with complex input), +# lowerable ops TRT again; end-to-end numerical correctness +# =========================================================================== + + +class ComplexTwoTRTBlocksAroundCumsum(nn.Module): + """Two complex-rotation TRT blocks with cumsum (PyTorch) in between. + + Expected graph after ``complex_graph_detection``: + + [Block A — TRT] + z_real, freqs_real → re/im arithmetic for z * freqs → rotated_real + + [PyTorch fallback — complex inputs] + view_as_complex(rotated_real) → cumsum(complex) → view_as_real → acc_real + + [Block B — TRT] + acc_real, freqs_real → re/im arithmetic for acc * freqs → result_real + result_real → view_as_real substitute → flatten → output + """ + + def forward(self, z: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + # Block A: complex rotate — lowered to real arithmetic + rotated = z * freqs + + # Unsupported complex op — rewriter inserts view_as_complex/view_as_real; + # TRT graph-breaks here; cumsum runs in PyTorch on a complex tensor + accumulated = torch.cumsum(rotated, dim=0) + + # Block B: second complex rotate — lowered to real arithmetic + result = accumulated * freqs + return torch.view_as_real(result).flatten(-2) + + +@pytest.mark.unit +def test_complex_partial_lowering_with_graph_break() -> None: + """Lowerable complex ops compile to TRT; cumsum runs in PyTorch on complex input. + + Asserts: + 1. The compiled model is numerically correct (cosine sim > threshold). + 2. At least one ``PythonTorchTensorRTModule`` submodule exists — confirming + the lowerable complex ops were compiled to TRT, not all relegated to + PyTorch fallback. + 3. After lowering, cumsum receives a complex-dtype tensor (the + view_as_complex wrapper was inserted correctly). 
+ """ + model = ComplexTwoTRTBlocksAroundCumsum().eval().cuda() + z = _make_freqs(8, 64) + freqs = _make_freqs(8, 64) + inputs = (z, freqs) + + # Structural check: verify cumsum gets a complex input after lowering + gm = _export_and_lower(model, inputs) + for n in gm.graph.nodes: + if n.target == torch.ops.aten.cumsum.default: + vc_val = n.args[0].meta.get("val") + if vc_val is not None: + assert vc_val.dtype in ( + torch.complex64, + torch.complex128, + ), f"cumsum should receive a complex tensor, got {vc_val.dtype}" + break + + # End-to-end: compile and verify numerical correctness + ep = torch.export.export(model, inputs) + trt_model = torchtrt.dynamo.compile( + ep, + inputs=inputs, + min_block_size=1, + pass_through_build_failures=True, + use_python_runtime=True, + ) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_partial_lowering_with_graph_break") + + # Verify at least one TRT block was created for the lowerable complex ops + n_trt = _count_trt_modules(trt_model) + if n_trt >= 0: + assert n_trt >= 1, ( + f"Expected at least one TRT submodule (lowerable complex ops should " + f"compile to TRT) but found {n_trt}." + ) + + torch._dynamo.reset() diff --git a/tests/py/dynamo/hlo/test_complex_ops.py b/tests/py/dynamo/hlo/test_complex_ops.py new file mode 100644 index 0000000000..27d135c4f1 --- /dev/null +++ b/tests/py/dynamo/hlo/test_complex_ops.py @@ -0,0 +1,2070 @@ +""" +Numerical accuracy stress tests for complex tensor decomposition in torch-tensorrt. + +The complex_graph_detection lowering pass rewrites complex-dtype ops to equivalent +real-arithmetic ops before TRT compilation. 
These tests verify correctness across: + + - I/O boundaries: complex inputs, complex outputs, mixed real/complex I/O + - Internal subgraphs: complex ops entirely within a TRT block + - Operator coverage: mul, add, sub, abs, angle, conj, real/imag extraction, + gather/scatter (select, slice, index), reshape/view, cat/stack, where, + unsqueeze/squeeze, expand/broadcast, type casting + - Chains: multiple sequential complex ops + - Multiple complex tensors interacting in one graph + - Dynamic shapes: batch and seq_len as symbolic dims + +All tests compare PyTorch (CPU/CUDA reference) vs TRT compiled output via +cosine similarity > COSINE_THRESHOLD on both real and imaginary parts. +""" + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from torch.export import Dim +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_freqs(seq: int, dim: int, theta: float = 10000.0) -> torch.Tensor: + """Complex unit-magnitude frequency tensor, shape (seq, dim//2).""" + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(seq, dtype=torch.float32) + freqs = torch.outer(t, freqs) + return torch.polar(torch.ones_like(freqs), freqs).cuda() + + +def _cossim_complex(py_out: torch.Tensor, trt_out: torch.Tensor, tag: str) -> None: + """Assert cosine similarity on real and imaginary parts separately.""" + assert trt_out.is_complex(), f"{tag}: expected complex output, got {trt_out.dtype}" + assert ( + trt_out.shape == py_out.shape + ), f"{tag}: shape mismatch {trt_out.shape} vs {py_out.shape}" + r = cosine_similarity(py_out.real.contiguous(), trt_out.real.contiguous()) + i = cosine_similarity(py_out.imag.contiguous(), trt_out.imag.contiguous()) + assert ( + r > COSINE_THRESHOLD + ), f"{tag}: real part cosine sim {r:.4f} 
< {COSINE_THRESHOLD}" + assert ( + i > COSINE_THRESHOLD + ), f"{tag}: imag part cosine sim {i:.4f} < {COSINE_THRESHOLD}" + + +def _cossim_real(py_out: torch.Tensor, trt_out: torch.Tensor, tag: str) -> None: + """Assert cosine similarity on a real-valued output.""" + assert not trt_out.is_complex(), f"{tag}: expected real output, got {trt_out.dtype}" + s = cosine_similarity(py_out.contiguous(), trt_out.contiguous()) + assert s > COSINE_THRESHOLD, f"{tag}: cosine sim {s:.4f} < {COSINE_THRESHOLD}" + + +_COMPILE = dict(ir="dynamo", min_block_size=1, pass_through_build_failures=True) + + +# =========================================================================== +# 1. I/O boundary tests +# =========================================================================== + + +class ComplexInputRealOutput(nn.Module): + """Complex input → real output (magnitude).""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # z: complex, output: real magnitude + r = torch.view_as_real(z) + real = r[..., 0] + imag = r[..., 1] + return torch.sqrt(real * real + imag * imag) + + +@pytest.mark.unit +def test_complex_input_real_output(): + model = ComplexInputRealOutput().eval().cuda() + z = _make_freqs(8, 64) # (8, 32) complex64 + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_input_real_output") + torch._dynamo.reset() + + +class RealInputComplexOutput(nn.Module): + """Real input → complex output (no view_as_real at graph output).""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + # x: real (B, S, H, D), freqs: complex (S, D//2) + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + return xc * freqs[None, :, None, :] # complex output + + +@pytest.mark.unit +def test_real_input_complex_output(): + model = RealInputComplexOutput().eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + freqs = _make_freqs(8, 64) + 
inputs = (x, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "real_input_complex_output") + torch._dynamo.reset() + + +class ComplexInputComplexOutput(nn.Module): + """Complex input × complex input → complex output.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a * b + + +@pytest.mark.unit +def test_complex_input_complex_output(): + model = ComplexInputComplexOutput().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + inputs = (a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_input_complex_output") + torch._dynamo.reset() + + +class MixedRealComplexInputRealOutput(nn.Module): + """One real input, one complex input, real output.""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + prod = xc * freqs[None, :, None, :] + return torch.view_as_real(prod).flatten(3) + + +@pytest.mark.unit +def test_mixed_real_complex_input_real_output(): + model = MixedRealComplexInputRealOutput().eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + freqs = _make_freqs(8, 64) + inputs = (x, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "mixed_io_real_output") + torch._dynamo.reset() + + +# =========================================================================== +# 2. 
Operator coverage +# =========================================================================== + + +class ComplexAdd(nn.Module): + """Complex addition: (a+bi) + (c+di) = (a+c) + (b+d)i.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a + b + + +@pytest.mark.unit +def test_complex_add_output(): + model = ComplexAdd().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + inputs = (a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_add") + torch._dynamo.reset() + + +class ComplexSub(nn.Module): + """Complex subtraction.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a - b + + +@pytest.mark.unit +def test_complex_sub_output(): + model = ComplexSub().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + inputs = (a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_sub") + torch._dynamo.reset() + + +class ComplexMulChain(nn.Module): + """Chain of two complex multiplications: (a * b) * c.""" + + def forward( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor + ) -> torch.Tensor: + return (a * b) * c + + +@pytest.mark.unit +def test_complex_mul_chain(): + model = ComplexMulChain().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + c = _make_freqs(8, 64) + inputs = (a, b, c) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_mul_chain") + torch._dynamo.reset() + + +class ComplexDiv(nn.Module): + """Complex division: (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c²+d²).""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a / b + + +@pytest.mark.unit +def test_complex_div(): + 
model = ComplexDiv().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs( + 8, 64, theta=500.0 + ) # different theta → different angles → non-trivial imaginary + inputs = (a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_div") + torch._dynamo.reset() + + +class ComplexScalarMul(nn.Module): + """Scale a complex tensor by a real scalar.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # Scale by 2.0 — real * complex + r = torch.view_as_real(z) + scaled = r * 2.0 + return torch.view_as_complex(scaled) + + +@pytest.mark.unit +def test_complex_scalar_mul_output(): + model = ComplexScalarMul().eval().cuda() + z = _make_freqs(8, 64) + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_scalar_mul") + torch._dynamo.reset() + + +class ComplexAbs(nn.Module): + """Complex magnitude: |z| = sqrt(re^2 + im^2) — real output.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + r = torch.view_as_real(z) + return (r * r).sum(-1).sqrt() + + +@pytest.mark.unit +def test_complex_abs(): + model = ComplexAbs().eval().cuda() + z = _make_freqs(8, 64) + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_abs") + torch._dynamo.reset() + + +class ComplexAbsNative(nn.Module): + """torch.abs on a complex tensor — exercises the aten.abs.default rewrite.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.abs(z) + + +@pytest.mark.unit +def test_complex_abs_native(): + model = ComplexAbsNative().eval().cuda() + z = torch.polar(2 * torch.ones(8, 32), torch.randn(8, 32)).cuda() + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) 
+ trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_abs_native") + torch._dynamo.reset() + + +class ComplexExp(nn.Module): + """torch.exp on a complex tensor: exp(a+bi) = e^a*(cos b + i sin b).""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.exp(z) + + +@pytest.mark.unit +def test_complex_exp(): + model = ComplexExp().eval().cuda() + # small magnitudes to keep exp from overflowing + z = torch.polar(0.1 * torch.ones(8, 32), torch.randn(8, 32)).cuda() + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_exp") + torch._dynamo.reset() + + +class ComplexLog(nn.Module): + """torch.log on a complex tensor: log(a+bi) = log|z| + i*angle(z).""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.log(z) + + +@pytest.mark.unit +def test_complex_log(): + model = ComplexLog().eval().cuda() + z = torch.polar(2 * torch.ones(8, 32), torch.randn(8, 32)).cuda() + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_log") + torch._dynamo.reset() + + +class ComplexPow(nn.Module): + """z**n via polar form: r^n * (cos(nθ) + i*sin(nθ)).""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return z**3 + + +@pytest.mark.unit +def test_complex_pow(): + model = ComplexPow().eval().cuda() + z = torch.polar(torch.ones(8, 32), torch.randn(8, 32)).cuda() + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_pow") + torch._dynamo.reset() + + +class ComplexSqrt(nn.Module): + """torch.sqrt on a complex tensor: z**0.5.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.sqrt(z) + + +@pytest.mark.unit +def test_complex_sqrt(): + 
model = ComplexSqrt().eval().cuda() + z = torch.polar(4 * torch.ones(8, 32), torch.randn(8, 32)).cuda() + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_sqrt") + torch._dynamo.reset() + + +class ComplexConj(nn.Module): + """torch.conj on a complex tensor — exercises the _conj rewrite.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.conj(z) + + +@pytest.mark.unit +def test_complex_conj(): + model = ComplexConj().eval().cuda() + z = _make_freqs(8, 64) + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs).resolve_conj() + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_conj") + torch._dynamo.reset() + + +class ComplexConjMul(nn.Module): + """z * conj(z) = |z|^2 — real-valued result returned as complex.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + r = torch.view_as_real(z) + re, im = r[..., 0], r[..., 1] + # conj(z) has same real, negated imag + real_part = re * re + im * im # ac - b(-d) = ac + bd when c=a, d=-b + imag_part = torch.zeros_like(real_part) + return torch.view_as_complex(torch.stack([real_part, imag_part], dim=-1)) + + +@pytest.mark.unit +def test_complex_conj_mul(): + model = ComplexConjMul().eval().cuda() + z = _make_freqs(8, 64) + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_conj_mul") + torch._dynamo.reset() + + +# =========================================================================== +# 3. 
Gather/scatter: select, slice, index +# =========================================================================== + + +class ComplexSelect(nn.Module): + """Select a slice along a dimension from a complex tensor.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # z: (S, D) complex — select first half of S, then mul + half = z[:4, :] # slice along seq dim + return half * z[4:, :] # element-wise complex mul, real output via view_as_real + # returns complex — covered by complex output test + + +@pytest.mark.unit +def test_complex_select_and_mul(): + model = ComplexSelect().eval().cuda() + z = _make_freqs(8, 64) # (8, 32) complex64 + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_select_mul") + torch._dynamo.reset() + + +class ComplexSlice(nn.Module): + """Slice two halves of a complex tensor and multiply them.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # z: (S, D) complex — split into first and second half along D + half_d = z.shape[-1] // 2 + a = z[:, :half_d] # (S, D//2) complex + b = z[:, half_d:] # (S, D//2) complex + return a * b # complex output + + +@pytest.mark.unit +def test_complex_slice_and_mul(): + model = ComplexSlice().eval().cuda() + z = _make_freqs(8, 64) # (8, 32) complex64 + inputs = (z,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_slice_mul") + torch._dynamo.reset() + + +# =========================================================================== +# 4. 
Shape manipulation: reshape, unsqueeze, squeeze, expand, flatten +# =========================================================================== + + +class ComplexReshapeAndMul(nn.Module): + """Reshape a complex tensor, then multiply.""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + # x: real (B, S, H*D), freqs: complex (S, D//2) + B, S, HD = x.shape + H = 4 + D = HD // H + xr = x.view(B, S, H, D) + xc = torch.view_as_complex(xr.reshape(B, S, H, -1, 2)) # (B,S,H,D//2) complex + return torch.view_as_real(xc * freqs[None, :, None, :]).flatten(3) + + +@pytest.mark.unit +def test_complex_reshape_and_mul(): + model = ComplexReshapeAndMul().eval().cuda() + x = torch.randn(2, 8, 64).cuda() + freqs = _make_freqs(8, 16) # (8, 8) complex, head_dim=16 -> D//2=8 + inputs = (x, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_reshape_mul") + torch._dynamo.reset() + + +class ComplexUnsqueezeExpand(nn.Module): + """Unsqueeze and expand a complex tensor before multiplication.""" + + def forward(self, z: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + # z: (S, D) complex, freqs: (D,) complex + # unsqueeze freqs to broadcast over S + return z * freqs.unsqueeze(0) # (S,D) complex output + + +@pytest.mark.unit +def test_complex_unsqueeze_expand(): + model = ComplexUnsqueezeExpand().eval().cuda() + z = _make_freqs(8, 64) # (8, 32) + freqs = _make_freqs(1, 64).squeeze(0) # (32,) + inputs = (z, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_unsqueeze_expand") + torch._dynamo.reset() + + +# =========================================================================== +# 5. 
Concatenation and stacking +# =========================================================================== + + +class ComplexCat(nn.Module): + """Concatenate two complex tensors along the sequence dimension.""" + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return torch.cat([a, b], dim=0) # (2S, D) complex output + + +@pytest.mark.unit +def test_complex_cat(): + model = ComplexCat().eval().cuda() + a = _make_freqs(4, 64) + b = _make_freqs(4, 64) + inputs = (a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_cat") + torch._dynamo.reset() + + +class ComplexCatThenMul(nn.Module): + """Concatenate two complex tensors then multiply by a third.""" + + def forward( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor + ) -> torch.Tensor: + ab = torch.cat([a, b], dim=0) # (2S, D) + return ab * c # complex output + + +@pytest.mark.unit +def test_complex_cat_then_mul(): + model = ComplexCatThenMul().eval().cuda() + a = _make_freqs(4, 64) + b = _make_freqs(4, 64) + c = _make_freqs(8, 64) + inputs = (a, b, c) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_cat_then_mul") + torch._dynamo.reset() + + +class ComplexStackRealView(nn.Module): + """Stack real-view representations of two complex tensors, then multiply. + + Tests that the rewriter correctly handles complex ops on stacked real tensors: + view_as_real(a) and view_as_real(b) are stacked, then used to form two + independent complex multiplications. 
+ """ + + def forward( + self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor + ) -> torch.Tensor: + # a, b: (S, D) complex, c: (S, D) complex + # Multiply each independently and add — tests multiple complex paths + return torch.view_as_real(a * c).flatten(-2) + torch.view_as_real( + b * c + ).flatten(-2) + + +@pytest.mark.unit +def test_complex_stack_real_view(): + model = ComplexStackRealView().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + c = _make_freqs(8, 64) + inputs = (a, b, c) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_stack_real_view") + torch._dynamo.reset() + + +# =========================================================================== +# 6. Where / masked selection +# =========================================================================== + + +class ComplexWhere(nn.Module): + """Conditional selection between two complex tensors.""" + + def forward( + self, mask: torch.Tensor, a: torch.Tensor, b: torch.Tensor + ) -> torch.Tensor: + # Operate on real/imag separately — where doesn't support complex natively + ar = torch.view_as_real(a) + br = torch.view_as_real(b) + m = mask.unsqueeze(-1) # broadcast over last (2,) dim + out = torch.where(m, ar, br) + return torch.view_as_complex(out.contiguous()) + + +@pytest.mark.unit +def test_complex_where(): + model = ComplexWhere().eval().cuda() + a = _make_freqs(8, 64) + b = _make_freqs(8, 64) + mask = (torch.randn(8, 32) > 0).cuda() + inputs = (mask, a, b) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_complex(py_out, trt_out, "complex_where") + torch._dynamo.reset() + + +# =========================================================================== +# 7. 
Multiple complex subgraphs in one model +# =========================================================================== + + +class DualComplexPath(nn.Module): + """Two independent complex multiplications merged at the output. + + freqs is passed already broadcast-ready (same shape as the complex view of x/y) + so no indexing/unsqueeze is needed on the complex tensor. + """ + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + freqs: torch.Tensor, + ) -> torch.Tensor: + # Path A: x rotated by freqs + xa = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + out_a = torch.view_as_real(xa * freqs).flatten(3) + # Path B: y rotated by same freqs + xb = torch.view_as_complex(y.reshape(*y.shape[:-1], -1, 2)) + out_b = torch.view_as_real(xb * freqs).flatten(3) + return out_a + out_b # real output + + +@pytest.mark.unit +def test_dual_complex_path(): + model = DualComplexPath().eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + y = torch.randn(2, 8, 4, 64).cuda() + # freqs must match the complex view shape (2,8,4,32) — broadcast via register_buffer + freqs = ( + _make_freqs(8, 64).unsqueeze(0).unsqueeze(2).expand(2, 8, 4, 32).contiguous() + ) + inputs = (x, y, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "dual_complex_path") + torch._dynamo.reset() + + +# =========================================================================== +# 8. Complex ops interleaved with real ops +# =========================================================================== + + +class ComplexSandwich(nn.Module): + """Real → complex → real → linear → complex → real sandwich. + + Uses a buffer for freqs so the complex tensor is a get_attr (not placeholder), + which the rewriter handles via stacked real tensor. 
+ """ + + def __init__(self, freqs: torch.Tensor) -> None: + super().__init__() + self.register_buffer("freqs", freqs) + self.linear = nn.Linear(64, 64, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # real → complex rotation using buffer freqs (get_attr path) + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + rotated = torch.view_as_real(xc * self.freqs).flatten(3) # (B,S,H,D) real + # real linear + out = self.linear(rotated) + # another complex rotation + outc = torch.view_as_complex(out.reshape(*out.shape[:-1], -1, 2)) + return torch.view_as_real(outc * self.freqs).flatten(3) + + +@pytest.mark.unit +def test_complex_sandwich(): + freqs = ( + _make_freqs(8, 64).unsqueeze(0).unsqueeze(2).expand(2, 8, 4, 32).contiguous() + ) + model = ComplexSandwich(freqs).eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + inputs = (x,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_sandwich") + torch._dynamo.reset() + + +# =========================================================================== +# 9. 
Complex nn.Parameter (get_attr path) +# =========================================================================== + + +class ComplexParamMul(nn.Module): + """Complex weight stored as nn.Parameter — exercises the get_attr rewrite path.""" + + def __init__(self, freqs: torch.Tensor) -> None: + super().__init__() + # nn.Parameter, not register_buffer — still a get_attr node in the exported graph + self.freqs = nn.Parameter(freqs, requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + return torch.view_as_real(xc * self.freqs).flatten(3) + + +@pytest.mark.unit +def test_complex_param_get_attr(): + freqs = ( + _make_freqs(8, 64).unsqueeze(0).unsqueeze(2).expand(2, 8, 4, 32).contiguous() + ) + model = ComplexParamMul(freqs).eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + inputs = (x,) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_param_get_attr") + torch._dynamo.reset() + + +# =========================================================================== +# 10. 
Dynamic shapes +# =========================================================================== + + +class ComplexMulDynamic(nn.Module): + """Complex RoPE with dynamic seq_len.""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + return torch.view_as_real(xc * freqs[None, :, None, :]).flatten(3) + + +@pytest.mark.unit +def test_complex_mul_dynamic_seqlen(): + """Dynamic seq_len: x has shape (B, seq, H, D), freqs has shape (seq, D//2).""" + model = ComplexMulDynamic().eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + freqs = _make_freqs(8, 64) # (8, 32) + inputs = (x, freqs) + + # x dim-1 and freqs dim-0 are both the seq dimension — share the same Dim + seq = Dim("seq", min=2, max=64) + dynamic_shapes = ({1: seq}, {0: seq}) + ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + trt_model = torchtrt.dynamo.compile( + ep, inputs=inputs, min_block_size=1, pass_through_build_failures=True + ) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex_mul_dynamic_seqlen") + torch._dynamo.reset() + + +class ComplexOutputDynamic(nn.Module): + """Complex output with dynamic batch.""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + return xc * freqs[None, :, None, :] # complex output + + +@pytest.mark.unit +def test_complex_output_dynamic_batch(): + model = ComplexOutputDynamic().eval().cuda() + x = torch.randn(2, 8, 4, 64).cuda() + freqs = _make_freqs(8, 64) + inputs = (x, freqs) + + batch = Dim("batch", min=1, max=8) + dynamic_shapes = ({0: batch}, {}) + ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + trt_model = torchtrt.dynamo.compile( + ep, inputs=inputs, min_block_size=1, pass_through_build_failures=True + ) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + 
_cossim_complex(py_out, trt_out, "complex_output_dynamic_batch") + torch._dynamo.reset() + + +# =========================================================================== +# 11. Numerical precision: complex64 vs truncated complex128 +# =========================================================================== + + +class Complex128Model(nn.Module): + """Uses complex128 (double precision).""" + + def forward(self, z: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + return torch.view_as_real(z * w).flatten(-2) + + +@pytest.mark.unit +def test_complex128_truncated_to_float32(): + """complex128 with truncate_double=True should compile to float32 arithmetic.""" + model = Complex128Model().eval().cuda() + z = torch.polar( + torch.ones(8, 32, dtype=torch.float64), + torch.randn(8, 32, dtype=torch.float64), + ).cuda() + w = torch.polar( + torch.ones(8, 32, dtype=torch.float64), + torch.randn(8, 32, dtype=torch.float64), + ).cuda() + inputs = (z, w) + trt_model = torchtrt.compile( + model, + inputs=inputs, + ir="dynamo", + min_block_size=1, + pass_through_build_failures=True, + truncate_double=True, + ) + py_out = model(*inputs).float() # cast reference to float32 for comparison + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "complex128_truncated") + torch._dynamo.reset() + + +# =========================================================================== +# 12. 
End-to-end: full attention-style block with complex RoPE +# =========================================================================== + + +class AttentionWithComplexRoPE(nn.Module): + """Multi-head self-attention with complex-number RoPE and real output.""" + + def __init__(self, d_model: int = 64, n_heads: int = 4) -> None: + super().__init__() + self.n_heads = n_heads + self.head_dim = d_model // n_heads + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.out_proj = nn.Linear(d_model, d_model, bias=False) + + def _apply_rope(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + B, S, H, D = x.shape + xc = torch.view_as_complex(x.reshape(B, S, H, -1, 2)) + return torch.view_as_real(xc * freqs[None, :, None, :]).flatten(3) + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + B, S, C = x.shape + H, D = self.n_heads, self.head_dim + q = self.q_proj(x).view(B, S, H, D) + k = self.k_proj(x).view(B, S, H, D) + v = self.v_proj(x).view(B, S, H, D) + q = self._apply_rope(q, freqs) + k = self._apply_rope(k, freqs) + # Scaled dot-product attention + scale = D**-0.5 + attn = torch.einsum("bshd,bthd->bhst", q, k) * scale + attn = torch.softmax(attn, dim=-1) + out = torch.einsum("bhst,bthd->bshd", attn, v).reshape(B, S, C) + return self.out_proj(out) + + +@pytest.mark.unit +def test_attention_with_complex_rope_static(): + model = AttentionWithComplexRoPE(d_model=64, n_heads=4).eval().cuda() + x = torch.randn(2, 8, 64).cuda() + freqs = _make_freqs(8, 16) # head_dim=16, D//2=8 + inputs = (x, freqs) + trt_model = torchtrt.compile(model, inputs=inputs, **_COMPILE) + py_out = model(*inputs) + trt_out = trt_model(*inputs) + _cossim_real(py_out, trt_out, "attention_with_complex_rope") + torch._dynamo.reset() + + +# =========================================================================== +# 13. 
Elementwise-safe structural ops (clone, permute) +# =========================================================================== + + +class ComplexClone(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return z.clone() * z + + +@pytest.mark.unit +def test_complex_clone(): + model = ComplexClone().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_clone") + torch._dynamo.reset() + + +class ComplexPermute(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + # permute spatial dims only, then apply mul so complex subgraph is detected + return z.permute(1, 0) * z.permute(1, 0) + + +@pytest.mark.unit +def test_complex_permute(): + model = ComplexPermute().eval().cuda() + z = _make_freqs(8, 32) # (8, 16) complex64 + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_permute") + torch._dynamo.reset() + + +# =========================================================================== +# 14. 
Extraction / construction ops +# =========================================================================== + + +class ComplexRealExtract(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return z.real + + +@pytest.mark.unit +def test_complex_real(): + model = ComplexRealExtract().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real(py_out, trt_out, "complex_real") + torch._dynamo.reset() + + +class ComplexImagExtract(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return z.imag + + +@pytest.mark.unit +def test_complex_imag(): + model = ComplexImagExtract().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real(py_out, trt_out, "complex_imag") + torch._dynamo.reset() + + +class ComplexAngle(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.angle(z) + + +@pytest.mark.unit +def test_complex_angle(): + model = ComplexAngle().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real(py_out, trt_out, "complex_angle") + torch._dynamo.reset() + + +class ComplexPolar(nn.Module): + def forward(self, r: torch.Tensor, theta: torch.Tensor) -> torch.Tensor: + return torch.polar(r, theta) + + +@pytest.mark.unit +def test_complex_polar(): + r = torch.rand(8, 16, device="cuda") + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + model = ComplexPolar().eval().cuda() + trt_model = torchtrt.compile(model, inputs=(r, theta), **_COMPILE) + py_out = model(r, theta) + trt_out = trt_model(r, theta) + _cossim_complex(py_out, trt_out, "complex_polar") + torch._dynamo.reset() + + +class ComplexReciprocal(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return 
torch.reciprocal(z) + + +@pytest.mark.unit +def test_complex_reciprocal(): + model = ComplexReciprocal().eval().cuda() + # Use non-unit magnitude to avoid trivial 1/z=conj(z) for |z|=1 + z = _make_freqs(8, 32) * 2.0 + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_reciprocal") + torch._dynamo.reset() + + +class ComplexRsqrt(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.rsqrt(z) + + +@pytest.mark.unit +def test_complex_rsqrt(): + model = ComplexRsqrt().eval().cuda() + # Use polar form with r > 0 so rsqrt is well-defined + r = torch.rand(8, 16, device="cuda") + 0.5 + theta = torch.rand(8, 16, device="cuda") * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_rsqrt") + torch._dynamo.reset() + + +class ComplexAddScalar(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + # Exercise add.Scalar: (a+bi)+2 = (a+2)+bi + return torch.view_as_complex(torch.view_as_real(z).add(0.0)) + 2.0 + + +@pytest.mark.unit +def test_complex_add_scalar(): + """add.Scalar: scalar adds to real part only — (a+2) + bi.""" + model = ComplexAddScalar().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_add_scalar") + torch._dynamo.reset() + + +class ComplexSgn(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.sgn(z) + + +@pytest.mark.unit +def test_complex_sgn(): + """sgn(z) = z/|z|, sgn(0) = 0.""" + model = ComplexSgn().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + # Include one zero entry + r[0, 0] = 0.0 + theta[0, 0] = 0.0 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_sgn") + 
torch._dynamo.reset() + + +class ComplexLog2(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.log2(z) + + +@pytest.mark.unit +def test_complex_log2(): + model = ComplexLog2().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.5 + theta = torch.rand(8, 16, device="cuda") * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_log2") + torch._dynamo.reset() + + +class ComplexLog10(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.log10(z) + + +@pytest.mark.unit +def test_complex_log10(): + model = ComplexLog10().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.5 + theta = torch.rand(8, 16, device="cuda") * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_log10") + torch._dynamo.reset() + + +class ComplexLog1p(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.log1p(z) + + +@pytest.mark.unit +def test_complex_log1p(): + model = ComplexLog1p().eval().cuda() + # |z| < 1 for numerical stability + r = torch.rand(8, 16, device="cuda") * 0.5 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_log1p") + torch._dynamo.reset() + + +class ComplexExpm1(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.expm1(z) + + +@pytest.mark.unit +def test_complex_expm1(): + model = ComplexExpm1().eval().cuda() + # Small magnitude to avoid exp overflow + r = torch.rand(8, 16, device="cuda") * 0.3 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_expm1") + 
torch._dynamo.reset() + + +class ComplexIsnan(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + # Output bool → cast to float for cosine sim + return torch.isnan(z).float() + + +@pytest.mark.unit +def test_complex_isnan(): + model = ComplexIsnan().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + # All-zero output: check element-wise equality + assert torch.allclose(py_out, trt_out), "complex_isnan: output mismatch" + torch._dynamo.reset() + + +class ComplexIsinf(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.isinf(z).float() + + +@pytest.mark.unit +def test_complex_isinf(): + model = ComplexIsinf().eval().cuda() + z = _make_freqs(8, 32) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + assert torch.allclose(py_out, trt_out), "complex_isinf: output mismatch" + torch._dynamo.reset() + + +# =========================================================================== +# 15. 
Trigonometric ops +# =========================================================================== + + +class ComplexSin(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.sin(z) + + +@pytest.mark.unit +def test_complex_sin(): + model = ComplexSin().eval().cuda() + r = torch.ones(8, 16, device="cuda") * 0.5 + theta = torch.linspace(0.1, 1.5, 16, device="cuda").unsqueeze(0).expand(8, -1) + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_sin") + torch._dynamo.reset() + + +class ComplexCos(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.cos(z) + + +@pytest.mark.unit +def test_complex_cos(): + model = ComplexCos().eval().cuda() + r = torch.ones(8, 16, device="cuda") * 0.5 + theta = torch.linspace(0.1, 1.5, 16, device="cuda").unsqueeze(0).expand(8, -1) + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_cos") + torch._dynamo.reset() + + +class ComplexSinh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.sinh(z) + + +@pytest.mark.unit +def test_complex_sinh(): + model = ComplexSinh().eval().cuda() + # Small imaginary part to avoid cosh overflow + r = torch.rand(8, 16, device="cuda") * 0.5 + theta = torch.rand(8, 16, device="cuda") * 0.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_sinh") + torch._dynamo.reset() + + +class ComplexCosh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.cosh(z) + + +@pytest.mark.unit +def test_complex_cosh(): + model = ComplexCosh().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.5 + theta = torch.rand(8, 16, device="cuda") * 0.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + 
_cossim_complex(model(z), trt_model(z), "complex_cosh") + torch._dynamo.reset() + + +class ComplexTan(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.tan(z) + + +@pytest.mark.unit +def test_complex_tan(): + model = ComplexTan().eval().cuda() + # Avoid a = ±pi/4 where denom → 0 + r = torch.rand(8, 16, device="cuda") * 0.4 + theta = torch.rand(8, 16, device="cuda") * 0.3 + 0.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_tan") + torch._dynamo.reset() + + +class ComplexTanh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.tanh(z) + + +@pytest.mark.unit +def test_complex_tanh(): + model = ComplexTanh().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.5 + theta = torch.rand(8, 16, device="cuda") * 0.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_tanh") + torch._dynamo.reset() + + +# =========================================================================== +# 16. 
Inverse trigonometric ops +# =========================================================================== + + +class ComplexAsinh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.asinh(z) + + +@pytest.mark.unit +def test_complex_asinh(): + model = ComplexAsinh().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.8 + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_asinh") + torch._dynamo.reset() + + +class ComplexAcosh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.acosh(z) + + +@pytest.mark.unit +def test_complex_acosh(): + model = ComplexAcosh().eval().cuda() + # |Re(z)| > 1 for non-trivial (non-purely-imaginary) result + r = torch.rand(8, 16, device="cuda") * 0.5 + 1.5 + theta = torch.rand(8, 16, device="cuda") * 0.4 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_acosh") + torch._dynamo.reset() + + +class ComplexAtanh(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.atanh(z) + + +@pytest.mark.unit +def test_complex_atanh(): + model = ComplexAtanh().eval().cuda() + # |z| < 1 to stay within principal domain + r = torch.rand(8, 16, device="cuda") * 0.6 + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_atanh") + torch._dynamo.reset() + + +class ComplexAsin(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.asin(z) + + +@pytest.mark.unit +def test_complex_asin(): + """asin(z) = -i*log(iz + sqrt(1-z²)). 
+ Tested with |z| < 1 to avoid branch-cut ambiguity on the real axis.""" + model = ComplexAsin().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.6 + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_asin") + torch._dynamo.reset() + + +class ComplexAcos(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.acos(z) + + +@pytest.mark.unit +def test_complex_acos(): + """acos(z) = -i*log(z + i*sqrt(1-z²)). + Tested with |z| < 1 to avoid branch-cut ambiguity on the real axis.""" + model = ComplexAcos().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.6 + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_acos") + torch._dynamo.reset() + + +class ComplexAtan(nn.Module): + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.atan(z) + + +@pytest.mark.unit +def test_complex_atan(): + """atan(z) = (i/2)*log((1-iz)/(1+iz)). + Tested with |z| < 1.""" + model = ComplexAtan().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.6 + 0.1 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_atan") + torch._dynamo.reset() + + +# =========================================================================== +# 17. 
Complex-complex power (pow.Tensor_Tensor) +# =========================================================================== + + +class ComplexPowTensorTensor(nn.Module): + def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor: + return torch.pow(z1, z2) + + +@pytest.mark.unit +def test_complex_pow_tensor_tensor(): + """z1**z2 = exp(z2 * log(z1)), both complex.""" + model = ComplexPowTensorTensor().eval().cuda() + # Use unit-magnitude base to keep values bounded + r1 = torch.ones(8, 16, device="cuda") + theta1 = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z1 = torch.polar(r1, theta1) + # Small exponent magnitude to avoid overflow + r2 = torch.rand(8, 16, device="cuda") * 0.3 + theta2 = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z2 = torch.polar(r2, theta2) + trt_model = torchtrt.compile(model, inputs=(z1, z2), **_COMPILE) + _cossim_complex(model(z1, z2), trt_model(z1, z2), "complex_pow_tensor_tensor") + torch._dynamo.reset() + + +# =========================================================================== +# 18. 
Composite complex-only multi-op chains +# =========================================================================== + + +class ComplexLogExp(nn.Module): + """exp(log(z)) ≈ z — round-trip through log and exp.""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return torch.exp(torch.log(z)) + + +@pytest.mark.unit +def test_complex_log_exp(): + """exp(log(z)) ≈ z: round-trip verifies log and exp rewrites compose correctly.""" + model = ComplexLogExp().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.8 + 0.2 + theta = torch.rand(8, 16, device="cuda") * 1.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_log_exp") + torch._dynamo.reset() + + +class ComplexMulAddSub(nn.Module): + """(a*b)+c-d — four complex operands, two muls and add/sub.""" + + def forward(self, a, b, c, d): + return (a * b) + c - d + + +@pytest.mark.unit +def test_complex_mul_add_sub(): + """(a*b)+c-d with four complex inputs.""" + model = ComplexMulAddSub().eval().cuda() + + def _rc(): + return torch.polar( + torch.rand(8, 16, device="cuda") + 0.2, + torch.rand(8, 16, device="cuda") * 2 * 3.14159, + ) + + a, b, c, d = _rc(), _rc(), _rc(), _rc() + trt_model = torchtrt.compile(model, inputs=(a, b, c, d), **_COMPILE) + _cossim_complex(model(a, b, c, d), trt_model(a, b, c, d), "complex_mul_add_sub") + torch._dynamo.reset() + + +class ComplexConjThenMul(nn.Module): + """conj(a) * b.""" + + def forward(self, a, b): + return torch.conj(a) * b + + +@pytest.mark.unit +def test_complex_conj_then_mul(): + """conj(a)*b: conjugate followed by complex multiply.""" + model = ComplexConjThenMul().eval().cuda() + + def _rc(): + return torch.polar( + torch.rand(8, 16, device="cuda") + 0.2, + torch.rand(8, 16, device="cuda") * 2 * 3.14159, + ) + + a, b = _rc(), _rc() + trt_model = torchtrt.compile(model, inputs=(a, b), **_COMPILE) + _cossim_complex(model(a, b), trt_model(a, b), "complex_conj_then_mul") + 
torch._dynamo.reset() + + +class ComplexAbsThenLog(nn.Module): + """log(abs(z)) — chain ending in real output.""" + + def forward(self, z): + return torch.log(torch.abs(z)) + + +@pytest.mark.unit +def test_complex_abs_then_log(): + """log(|z|): abs(complex) → log(real), result is real.""" + model = ComplexAbsThenLog().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.8 + 0.2 + z = torch.polar(r, torch.rand(8, 16, device="cuda") * 2 * 3.14159) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real(py_out, trt_out, "complex_abs_then_log") + torch._dynamo.reset() + + +class ComplexSqrtThenMul(nn.Module): + """sqrt(a) * sqrt(b) — two sqrt rewrites in one graph.""" + + def forward(self, a, b): + return torch.sqrt(a) * torch.sqrt(b) + + +@pytest.mark.unit +def test_complex_sqrt_then_mul(): + """sqrt(a)*sqrt(b) ≈ sqrt(a*b) — exercises two sqrt rewrites in one graph.""" + model = ComplexSqrtThenMul().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.5 + a = torch.polar(r, torch.rand(8, 16, device="cuda") * 3.14159) + b = torch.polar(r, torch.rand(8, 16, device="cuda") * 3.14159) + trt_model = torchtrt.compile(model, inputs=(a, b), **_COMPILE) + _cossim_complex(model(a, b), trt_model(a, b), "complex_sqrt_then_mul") + torch._dynamo.reset() + + +class ComplexPowThenAdd(nn.Module): + """z**2 + z — polynomial evaluation via pow + add.""" + + def forward(self, z): + return z**2 + z + + +@pytest.mark.unit +def test_complex_pow_then_add(): + """z² + z — quadratic in z, exercises pow.Tensor_Scalar → add chain.""" + model = ComplexPowThenAdd().eval().cuda() + z = torch.polar( + torch.rand(8, 16, device="cuda") * 0.8 + 0.2, + torch.rand(8, 16, device="cuda") * 2 * 3.14159, + ) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_pow_then_add") + torch._dynamo.reset() + + +class ComplexSinCosPythagorean(nn.Module): + """sin(z)² + cos(z)² — 
Pythagorean identity over ℂ.""" + + def forward(self, z): + s = torch.sin(z) + c = torch.cos(z) + return s * s + c * c + + +@pytest.mark.unit +def test_complex_sin_cos_pythagorean(): + """sin²(z) + cos²(z): TRT vs PyTorch agree numerically.""" + model = ComplexSinCosPythagorean().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.4 + theta = ( + torch.linspace(0.1, 1.2, 16, device="cuda") + .unsqueeze(0) + .expand(8, -1) + .contiguous() + ) + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real( + py_out.real.contiguous(), + trt_out.real.contiguous(), + "complex_sin_cos_pythagorean", + ) + torch._dynamo.reset() + + +class ComplexExpThenAbs(nn.Module): + """|exp(z)| = exp(Re(z)) — chain: exp → abs, result is real.""" + + def forward(self, z): + return torch.abs(torch.exp(z)) + + +@pytest.mark.unit +def test_complex_exp_then_abs(): + """|exp(z)| = exp(Re(z)): exercises exp rewrite feeding into abs rewrite.""" + model = ComplexExpThenAbs().eval().cuda() + r = torch.rand(8, 16, device="cuda") * 0.3 + theta = torch.rand(8, 16, device="cuda") * 2 * 3.14159 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + py_out = model(z) + trt_out = trt_model(z) + _cossim_real(py_out, trt_out, "complex_exp_then_abs") + torch._dynamo.reset() + + +class ComplexNormalize(nn.Module): + """z / |z| — normalize to unit circle via abs + divide.""" + + def forward(self, z): + mag = torch.abs(z) + # Avoid aten.complex.default — build complex divisor via view_as_complex. 
+ mag_c = torch.view_as_complex(torch.stack([mag, torch.zeros_like(mag)], dim=-1)) + return z / mag_c + + +@pytest.mark.unit +def test_complex_normalize(): + """z/|z|: unit-normalize a complex tensor.""" + model = ComplexNormalize().eval().cuda() + z = torch.polar( + torch.rand(8, 16, device="cuda") * 0.8 + 0.2, + torch.rand(8, 16, device="cuda") * 2 * 3.14159, + ) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_normalize") + torch._dynamo.reset() + + +# =========================================================================== +# 19. Complex + real interleaved computations +# =========================================================================== + + +class ComplexMulThenRealLinear(nn.Module): + """Complex rotation followed by a real-valued linear projection (core RoPE pattern).""" + + def __init__(self) -> None: + super().__init__() + self.proj = nn.Linear(64, 32, bias=False) + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + rotated = torch.view_as_real(xc * freqs).flatten(-2) + return self.proj(rotated) + + +@pytest.mark.unit +def test_complex_mul_then_real_linear(): + """Complex RoPE rotation followed by a real linear layer.""" + model = ComplexMulThenRealLinear().eval().cuda() + x = torch.randn(2, 8, 64, device="cuda") + freqs = _make_freqs(8, 64) + trt_model = torchtrt.compile(model, inputs=(x, freqs), **_COMPILE) + py_out = model(x, freqs) + trt_out = trt_model(x, freqs) + _cossim_real(py_out, trt_out, "complex_mul_then_real_linear") + torch._dynamo.reset() + + +class RealNormThenComplexMul(nn.Module): + """LayerNorm on the real input, then rotate with complex freqs.""" + + def __init__(self, d: int = 64) -> None: + super().__init__() + self.norm = nn.LayerNorm(d) + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + xc = 
torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + return torch.view_as_real(xc * freqs).flatten(-2) + + +@pytest.mark.unit +def test_real_norm_then_complex_mul(): + """LayerNorm (real) → view_as_complex → complex mul → view_as_real.""" + model = RealNormThenComplexMul(d=64).eval().cuda() + x = torch.randn(2, 8, 64, device="cuda") + freqs = _make_freqs(8, 64) + trt_model = torchtrt.compile(model, inputs=(x, freqs), **_COMPILE) + py_out = model(x, freqs) + trt_out = trt_model(x, freqs) + _cossim_real(py_out, trt_out, "real_norm_then_complex_mul") + torch._dynamo.reset() + + +class ComplexMulThenRealActivation(nn.Module): + """Complex rotation → real view → GELU activation.""" + + def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + real_out = torch.view_as_real(xc * freqs).flatten(-2) + return torch.nn.functional.gelu(real_out) + + +@pytest.mark.unit +def test_complex_mul_then_gelu(): + """Complex rotation followed by GELU on the real-valued output.""" + model = ComplexMulThenRealActivation().eval().cuda() + x = torch.randn(2, 8, 64, device="cuda") + freqs = _make_freqs(8, 64) + trt_model = torchtrt.compile(model, inputs=(x, freqs), **_COMPILE) + py_out = model(x, freqs) + trt_out = trt_model(x, freqs) + _cossim_real(py_out, trt_out, "complex_mul_then_gelu") + torch._dynamo.reset() + + +class RealScaleThenComplexAddSub(nn.Module): + """Scale two real tensors, pack as complex, do add and sub.""" + + def __init__(self) -> None: + super().__init__() + self.scale = nn.Parameter(torch.ones(1)) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + # x, y: real (B, D, 2) — pack as complex + xa = x * self.scale + ya = y * self.scale + zx = torch.view_as_complex(xa) + zy = torch.view_as_complex(ya) + return torch.view_as_real(zx + zy - zx) + + +@pytest.mark.unit +def test_real_scale_then_complex_add_sub(): + """Real scale → pack as complex → add/sub → unpack.""" + 
model = RealScaleThenComplexAddSub().eval().cuda() + x = torch.randn(4, 16, 2, device="cuda") + y = torch.randn(4, 16, 2, device="cuda") + trt_model = torchtrt.compile(model, inputs=(x, y), **_COMPILE) + py_out = model(x, y) + trt_out = trt_model(x, y) + _cossim_real(py_out, trt_out, "real_scale_then_complex_add_sub") + torch._dynamo.reset() + + +class ComplexMagPhaseRecompose(nn.Module): + """Decompose into magnitude + phase, apply real ops to each, recompose.""" + + def __init__(self) -> None: + super().__init__() + self.mag_scale = nn.Parameter(torch.ones(1)) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + mag = torch.abs(z) + phase = torch.angle(z) + mag2 = mag * self.mag_scale.abs() + phase2 = torch.clamp(phase, -1.5, 1.5) + return torch.polar(mag2, phase2) + + +@pytest.mark.unit +def test_complex_mag_phase_recompose(): + """Decompose z → (|z|, angle) → scale+clip → polar recompose.""" + model = ComplexMagPhaseRecompose().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.3 + theta = torch.rand(8, 16, device="cuda") * 2.0 - 1.0 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_mag_phase_recompose") + torch._dynamo.reset() + + +class ComplexResidual(nn.Module): + """Complex residual: z + exp(log(z)).""" + + def forward(self, z: torch.Tensor) -> torch.Tensor: + return z + torch.exp(torch.log(z)) + + +@pytest.mark.unit +def test_complex_residual(): + """z + exp(log(z)) ≈ 2z — residual connection through complex ops.""" + model = ComplexResidual().eval().cuda() + r = torch.rand(8, 16, device="cuda") + 0.5 + theta = torch.rand(8, 16, device="cuda") * 1.5 + z = torch.polar(r, theta) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_residual") + torch._dynamo.reset() + + +class ComplexGatedMul(nn.Module): + """Real sigmoid gate applied to a complex tensor.""" + + def __init__(self) -> None: + 
super().__init__() + self.gate_proj = nn.Linear(32, 16, bias=False) + + def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + gate = torch.sigmoid(self.gate_proj(x)) + gate_c = torch.view_as_complex( + torch.stack([gate, torch.zeros_like(gate)], dim=-1) + ) + return z * gate_c + + +@pytest.mark.unit +def test_complex_gated_mul(): + """Real sigmoid gate × complex tensor — real and complex subgraphs in one model.""" + model = ComplexGatedMul().eval().cuda() + x = torch.randn(4, 32, device="cuda") + z = torch.polar( + torch.rand(4, 16, device="cuda") + 0.3, + torch.rand(4, 16, device="cuda") * 2 * 3.14159, + ) + trt_model = torchtrt.compile(model, inputs=(x, z), **_COMPILE) + _cossim_complex(model(x, z), trt_model(x, z), "complex_gated_mul") + torch._dynamo.reset() + + +# =========================================================================== +# 20. Multi-layer and branching subgraph integration tests +# =========================================================================== + + +class MultiHeadRoPE(nn.Module): + """Apply independent RoPE rotations to Q, K, V and compute attention logits.""" + + def __init__(self, seq: int = 8, dim: int = 32) -> None: + super().__init__() + self.freq_q = nn.Parameter(_make_freqs(seq, dim).detach()) + self.freq_k = nn.Parameter(_make_freqs(seq, dim).detach()) + self.freq_v = nn.Parameter(_make_freqs(seq, dim).detach()) + + def forward(self, q, k, v): + def rope(x, freq): + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + return torch.view_as_real(xc * freq).flatten(-2) + + q_r = rope(q, self.freq_q) + k_r = rope(k, self.freq_k) + v_r = rope(v, self.freq_v) + scores = torch.bmm(q_r, k_r.transpose(1, 2)) / (q_r.shape[-1] ** 0.5) + return torch.bmm(scores, v_r) + + +@pytest.mark.unit +def test_multi_head_rope(): + """Q/K/V independently rotated by RoPE, then bmm attention — 3 complex subgraphs.""" + model = MultiHeadRoPE(seq=8, dim=32).eval().cuda() + B, S, D = 2, 8, 32 + q = torch.randn(B, S, D, 
device="cuda") + k = torch.randn(B, S, D, device="cuda") + v = torch.randn(B, S, D, device="cuda") + trt_model = torchtrt.compile(model, inputs=(q, k, v), **_COMPILE) + _cossim_real(model(q, k, v), trt_model(q, k, v), "multi_head_rope") + torch._dynamo.reset() + + +class ParallelComplexBranches(nn.Module): + """One complex input forks into two independent rotation paths, then concat + project.""" + + def __init__(self, dim: int = 16) -> None: + super().__init__() + self.freq_a = nn.Parameter(_make_freqs(8, dim * 2).detach()) + self.freq_b = nn.Parameter(_make_freqs(8, dim * 2).detach()) + self.proj = nn.Linear(dim * 4, dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + z = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + real_a = torch.view_as_real(z * self.freq_a).flatten(-2) + real_b = torch.view_as_real(z * self.freq_b).flatten(-2) + return self.proj(torch.cat([real_a, real_b], dim=-1)) + + +@pytest.mark.unit +def test_parallel_complex_branches(): + """One complex input forks into two rotation paths, concat, then project.""" + model = ParallelComplexBranches(dim=16).eval().cuda() + x = torch.randn(2, 8, 32, device="cuda") + trt_model = torchtrt.compile(model, inputs=(x,), **_COMPILE) + _cossim_real(model(x), trt_model(x), "parallel_complex_branches") + torch._dynamo.reset() + + +class TransformerLikeBlock(nn.Module): + """One layer: RoPE rotation + real FFN with residual.""" + + def __init__(self, d: int = 32) -> None: + super().__init__() + self.freq = nn.Parameter(_make_freqs(8, d).detach()) + self.norm = nn.LayerNorm(d) + self.ff1 = nn.Linear(d, d * 2, bias=False) + self.ff2 = nn.Linear(d * 2, d, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + rotated = torch.view_as_real(xc * self.freq).flatten(-2) + h = self.norm(rotated) + h = torch.nn.functional.gelu(self.ff1(h)) + h = self.ff2(h) + return rotated + h + + +class 
StackedTransformerBlocks(nn.Module): + """Two sequential transformer-like blocks, each with complex RoPE.""" + + def __init__(self, d: int = 32) -> None: + super().__init__() + self.block1 = TransformerLikeBlock(d) + self.block2 = TransformerLikeBlock(d) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block2(self.block1(x)) + + +@pytest.mark.unit +def test_stacked_transformer_blocks(): + """Two stacked transformer-like blocks, each containing a complex RoPE sub-graph.""" + model = StackedTransformerBlocks(d=32).eval().cuda() + x = torch.randn(2, 8, 32, device="cuda") + trt_model = torchtrt.compile(model, inputs=(x,), **_COMPILE) + _cossim_real(model(x), trt_model(x), "stacked_transformer_blocks") + torch._dynamo.reset() + + +class FourComplexInputsMulAdd(nn.Module): + """z1*z2 + z3*z4 — four distinct complex runtime inputs.""" + + def forward(self, z1, z2, z3, z4): + return z1 * z2 + z3 * z4 + + +@pytest.mark.unit +def test_four_complex_inputs_mul_add(): + """z1*z2 + z3*z4 — four complex runtime inputs, two muls and one add.""" + model = FourComplexInputsMulAdd().eval().cuda() + + def _rc(shape): + return torch.polar( + torch.rand(*shape, device="cuda") * 0.8 + 0.2, + torch.rand(*shape, device="cuda") * 2 * 3.14159, + ) + + z1, z2, z3, z4 = [_rc((4, 16)) for _ in range(4)] + trt_model = torchtrt.compile(model, inputs=(z1, z2, z3, z4), **_COMPILE) + _cossim_complex( + model(z1, z2, z3, z4), trt_model(z1, z2, z3, z4), "four_complex_inputs_mul_add" + ) + torch._dynamo.reset() + + +class CrossAttentionComplexQ(nn.Module): + """Cross-attention: complex-rotated queries, real key/value projections.""" + + def __init__(self, d_q: int = 32, d_kv: int = 64) -> None: + super().__init__() + self.freq = nn.Parameter(_make_freqs(8, d_q).detach()) + self.norm_q = nn.LayerNorm(d_q) + self.Wk = nn.Linear(d_kv, d_q, bias=False) + self.Wv = nn.Linear(d_kv, d_q, bias=False) + + def forward(self, q_real, kv): + qc = 
torch.view_as_complex(q_real.reshape(*q_real.shape[:-1], -1, 2)) + q = self.norm_q(torch.view_as_real(qc * self.freq).flatten(-2)) + k = self.Wk(kv) + v = self.Wv(kv) + scores = torch.bmm(q, k.transpose(1, 2)) / (q.shape[-1] ** 0.5) + return torch.bmm(scores, v) + + +@pytest.mark.unit +def test_cross_attention_complex_q(): + """Cross-attention: complex-rotated query, real key/value projections.""" + model = CrossAttentionComplexQ(d_q=32, d_kv=64).eval().cuda() + q_real = torch.randn(2, 8, 32, device="cuda") + kv = torch.randn(2, 12, 64, device="cuda") + trt_model = torchtrt.compile(model, inputs=(q_real, kv), **_COMPILE) + _cossim_real(model(q_real, kv), trt_model(q_real, kv), "cross_attention_complex_q") + torch._dynamo.reset() + + +class ComplexRotator(nn.Module): + """Single complex rotation layer wrapping a learnable frequency buffer.""" + + def __init__(self, seq: int = 8, dim: int = 32) -> None: + super().__init__() + self.freq = nn.Parameter(_make_freqs(seq, dim).detach()) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + return torch.view_as_real(xc * self.freq).flatten(-2) + + +class NestedComplexRotators(nn.Module): + """Two ComplexRotator sub-modules with a real LayerNorm between them.""" + + def __init__(self, d: int = 32) -> None: + super().__init__() + self.rot1 = ComplexRotator(seq=8, dim=d) + self.norm = nn.LayerNorm(d) + self.rot2 = ComplexRotator(seq=8, dim=d) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.rot1(x) + x = self.norm(x) + return self.rot2(x) + + +@pytest.mark.unit +def test_nested_complex_rotators(): + """Two nested ComplexRotator sub-modules with a real LayerNorm between them.""" + model = NestedComplexRotators(d=32).eval().cuda() + x = torch.randn(2, 8, 32, device="cuda") + trt_model = torchtrt.compile(model, inputs=(x,), **_COMPILE) + _cossim_real(model(x), trt_model(x), "nested_complex_rotators") + torch._dynamo.reset() + + +class 
ComplexNormThenProject(nn.Module): + """abs(z) → LayerNorm → rescale z: real and complex subgraphs share an edge.""" + + def __init__(self, dim: int = 16) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + mag = torch.abs(z) + scale = self.norm(mag) + scale_c = torch.view_as_complex( + torch.stack([scale, torch.zeros_like(scale)], dim=-1) + ) + return z * scale_c + + +@pytest.mark.unit +def test_complex_norm_then_project(): + """abs(z) → LayerNorm → rescale z: real and complex subgraphs share an edge.""" + model = ComplexNormThenProject(dim=16).eval().cuda() + z = torch.polar( + torch.rand(4, 16, device="cuda") * 0.8 + 0.2, + torch.rand(4, 16, device="cuda") * 2 * 3.14159, + ) + trt_model = torchtrt.compile(model, inputs=(z,), **_COMPILE) + _cossim_complex(model(z), trt_model(z), "complex_norm_then_project") + torch._dynamo.reset() + + +class ComplexRotateProject(nn.Module): + """Two complex rotations separated by a real linear layer.""" + + def __init__(self, d: int = 32) -> None: + super().__init__() + self.freq1 = nn.Parameter(_make_freqs(8, d).detach()) + self.freq2 = nn.Parameter(_make_freqs(8, d).detach()) + self.proj = nn.Linear(d, d, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2)) + r1 = torch.view_as_real(xc * self.freq1).flatten(-2) + r2 = self.proj(r1) + xc2 = torch.view_as_complex(r2.reshape(*r2.shape[:-1], -1, 2)) + return torch.view_as_real(xc2 * self.freq2).flatten(-2) + + +@pytest.mark.unit +def test_complex_rotate_project(): + """Two complex rotations separated by a real linear layer.""" + model = ComplexRotateProject(d=32).eval().cuda() + x = torch.randn(2, 8, 32, device="cuda") + trt_model = torchtrt.compile(model, inputs=(x,), **_COMPILE) + _cossim_real(model(x), trt_model(x), "complex_rotate_project") + torch._dynamo.reset() diff --git a/tests/py/dynamo/hlo/test_rope_embedding.py 
b/tests/py/dynamo/hlo/test_rope_embedding.py new file mode 100644 index 0000000000..bbd76c7a92 --- /dev/null +++ b/tests/py/dynamo/hlo/test_rope_embedding.py @@ -0,0 +1,526 @@ +""" +Tests for Rotary Position Embedding (RoPE) compilation with torch-tensorrt. + +RoPE is a critical subgraph used in modern LLMs (LLaMA, Qwen, Mistral, etc.). +Two common forms are tested: + 1. HuggingFace-style: rotate_half + apply_rotary_pos_emb using cos/sin tensors + 2. Complex-number form: view_as_complex + complex multiply + view_as_real + +Both static and dynamic shapes (varying seq_len, batch) are covered, as well as +RoPE embedded inside a larger attention block (a common failure mode). +""" + +import os + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from torch.export import Dim +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +# --------------------------------------------------------------------------- +# Shared helper modules +# --------------------------------------------------------------------------- + + +class HFRotaryEmbedding(nn.Module): + """HuggingFace-style RoPE as used in LLaMA / Qwen / Mistral. 
+ + Identical to ``apply_rotary_pos_emb`` in transformers: + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + """ + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + # cos/sin shape: (batch, seq_len, head_dim) – unsqueeze head dim + cos = cos.unsqueeze(1) # (batch, 1, seq_len, head_dim) + sin = sin.unsqueeze(1) + q_embed = (q * cos) + (self.rotate_half(q) * sin) + k_embed = (k * cos) + (self.rotate_half(k) * sin) + return q_embed, k_embed + + +class ComplexRotaryEmbedding(nn.Module): + """Complex-number RoPE as used in original LLaMA / Meta models. + + Applies pre-computed complex frequency tensor via: + x_complex = view_as_complex(x.reshape(..., -1, 2)) + out = view_as_real(x_complex * freqs_cis).flatten(-2) + """ + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + # x: (batch, seq_len, n_heads, head_dim) + # freqs_cis: (seq_len, head_dim // 2) complex + x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis[None, :, None, :] # broadcast over batch and heads + x_out = torch.view_as_real(x_ * freqs_cis).flatten(3) + return x_out.type_as(x) + + +def _make_freqs_cis( + seq_len: int, head_dim: int, theta: float = 10000.0 +) -> torch.Tensor: + """Pre-compute complex frequency tensor on CUDA.""" + freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) + t = torch.arange(seq_len, dtype=torch.float32) + freqs = torch.outer(t, freqs) + return torch.polar(torch.ones_like(freqs), freqs).cuda() + + +# --------------------------------------------------------------------------- +# Test 1: HuggingFace-style RoPE – static shapes +# --------------------------------------------------------------------------- + + 
+@pytest.mark.unit +def test_rope_hf_style_static(): + """HF rotate_half RoPE compiles and produces correct outputs (static shapes).""" + model = HFRotaryEmbedding().eval().cuda() + + q = torch.randn(1, 12, 5, 128, dtype=torch.float32).cuda() + k = torch.randn(1, 12, 5, 128, dtype=torch.float32).cuda() + # cos/sin: (batch, seq_len, head_dim) + cos = torch.randn(1, 5, 128, dtype=torch.float32).cuda() + sin = torch.randn(1, 5, 128, dtype=torch.float32).cuda() + inputs = (q, k, cos, sin) + + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.compile(model, **compile_spec) + + py_q, py_k = model(*inputs) + trt_q, trt_k = trt_model(*inputs) + + cos_sim_q = cosine_similarity(py_q, trt_q) + cos_sim_k = cosine_similarity(py_k, trt_k) + assert cos_sim_q > COSINE_THRESHOLD, ( + f"test_rope_hf_style_static: q outputs differ. " + f"Cosine sim: {cos_sim_q:.4f} < threshold {COSINE_THRESHOLD}" + ) + assert cos_sim_k > COSINE_THRESHOLD, ( + f"test_rope_hf_style_static: k outputs differ. 
" + f"Cosine sim: {cos_sim_k:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 2: HuggingFace-style RoPE – dynamic seq_len +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_rope_hf_style_dynamic(): + """HF rotate_half RoPE compiles and produces correct outputs (dynamic seq_len).""" + model = HFRotaryEmbedding().eval().cuda() + + q = torch.randn(1, 12, 5, 128, dtype=torch.float32).cuda() + k = torch.randn(1, 12, 5, 128, dtype=torch.float32).cuda() + cos = torch.randn(1, 5, 128, dtype=torch.float32).cuda() + sin = torch.randn(1, 5, 128, dtype=torch.float32).cuda() + inputs = (q, k, cos, sin) + + seq_len = Dim("seq_len", min=2, max=2048) + # q/k: (batch, n_heads, seq_len, head_dim) – seq_len is dim 2 + # cos/sin: (batch, seq_len, head_dim) – seq_len is dim 1 + dynamic_shapes = ( + {2: seq_len}, + {2: seq_len}, + {1: seq_len}, + {1: seq_len}, + ) + exp_program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + + compile_spec = { + "inputs": inputs, + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.dynamo.compile(exp_program, **compile_spec) + + py_q, py_k = model(*inputs) + trt_q, trt_k = trt_model(*inputs) + + cos_sim_q = cosine_similarity(py_q, trt_q) + cos_sim_k = cosine_similarity(py_k, trt_k) + assert cos_sim_q > COSINE_THRESHOLD, ( + f"test_rope_hf_style_dynamic: q outputs differ. " + f"Cosine sim: {cos_sim_q:.4f} < threshold {COSINE_THRESHOLD}" + ) + assert cos_sim_k > COSINE_THRESHOLD, ( + f"test_rope_hf_style_dynamic: k outputs differ. 
" + f"Cosine sim: {cos_sim_k:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 3: Complex-number RoPE – static shapes +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_rope_complex_form_static(): + """Complex (view_as_complex/view_as_real) RoPE compiles correctly (static shapes).""" + BATCH, SEQ_LEN, N_HEADS, HEAD_DIM = 2, 8, 4, 64 + model = ComplexRotaryEmbedding().eval().cuda() + + x = torch.randn(BATCH, SEQ_LEN, N_HEADS, HEAD_DIM, dtype=torch.float32).cuda() + freqs_cis = _make_freqs_cis(SEQ_LEN, HEAD_DIM) + inputs = (x, freqs_cis) + + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.compile(model, **compile_spec) + + py_out = model(*inputs) + trt_out = trt_model(*inputs) + + cos_sim = cosine_similarity(py_out, trt_out) + assert cos_sim > COSINE_THRESHOLD, ( + f"test_rope_complex_form_static: outputs differ. 
" + f"Cosine sim: {cos_sim:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 4: Complex-number RoPE – dynamic batch and seq_len +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_rope_complex_form_dynamic(): + """Complex RoPE compiles correctly with dynamic batch and seq_len.""" + BATCH, SEQ_LEN, N_HEADS, HEAD_DIM = 2, 8, 4, 64 + model = ComplexRotaryEmbedding().eval().cuda() + + x = torch.randn(BATCH, SEQ_LEN, N_HEADS, HEAD_DIM, dtype=torch.float32).cuda() + freqs_cis = _make_freqs_cis(SEQ_LEN, HEAD_DIM) + inputs = (x, freqs_cis) + + batch = Dim("batch", min=1, max=4) + seq_len = Dim("seq_len", min=2, max=512) + # x: (batch, seq_len, n_heads, head_dim) + # freqs_cis: (seq_len, head_dim//2) complex – dim 0 is seq_len + dynamic_shapes = ( + {0: batch, 1: seq_len}, + {0: seq_len}, + ) + exp_program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + + compile_spec = { + "inputs": inputs, + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.dynamo.compile(exp_program, **compile_spec) + + py_out = model(*inputs) + trt_out = trt_model(*inputs) + + cos_sim = cosine_similarity(py_out, trt_out) + assert cos_sim > COSINE_THRESHOLD, ( + f"test_rope_complex_form_dynamic: outputs differ. " + f"Cosine sim: {cos_sim:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 5: RoPE embedded inside an attention block – static shapes +# --------------------------------------------------------------------------- + + +class AttentionWithRoPE(nn.Module): + """Minimal self-attention block with HF-style RoPE, as found in LLaMA/Qwen. 
+ + This exercises RoPE inside a larger graph—a common failure mode where + the shape inference for cos/sin unsqueeze interacts with the projection + output shapes. + """ + + def __init__(self, embed_dim: int = 64, n_heads: int = 4): + super().__init__() + self.n_heads = n_heads + self.head_dim = embed_dim // n_heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.o_proj = nn.Linear(embed_dim, embed_dim, bias=False) + + def rotate_half(self, x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb( + self, + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + cos = cos.unsqueeze(1) # add head dim + sin = sin.unsqueeze(1) + return (q * cos) + (self.rotate_half(q) * sin), (k * cos) + ( + self.rotate_half(k) * sin + ) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + batch, seq_len, _ = hidden_states.shape + q = ( + self.q_proj(hidden_states) + .view(batch, seq_len, self.n_heads, self.head_dim) + .transpose(1, 2) + ) + k = ( + self.k_proj(hidden_states) + .view(batch, seq_len, self.n_heads, self.head_dim) + .transpose(1, 2) + ) + v = ( + self.v_proj(hidden_states) + .view(batch, seq_len, self.n_heads, self.head_dim) + .transpose(1, 2) + ) + + q, k = self.apply_rotary_pos_emb(q, k, cos, sin) + + attn_out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=True + ) + attn_out = attn_out.transpose(1, 2).reshape(batch, seq_len, -1) + return self.o_proj(attn_out) + + +@pytest.mark.unit +def test_rope_in_attention_block_static(): + """RoPE inside a full attention block compiles correctly (static shapes).""" + EMBED_DIM, N_HEADS, BATCH, SEQ_LEN = 64, 4, 2, 16 + HEAD_DIM = EMBED_DIM // N_HEADS + + model = 
AttentionWithRoPE(EMBED_DIM, N_HEADS).eval().cuda() + + hidden = torch.randn(BATCH, SEQ_LEN, EMBED_DIM, dtype=torch.float32).cuda() + cos = torch.randn(BATCH, SEQ_LEN, HEAD_DIM, dtype=torch.float32).cuda() + sin = torch.randn(BATCH, SEQ_LEN, HEAD_DIM, dtype=torch.float32).cuda() + inputs = (hidden, cos, sin) + + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.compile(model, **compile_spec) + + py_out = model(*inputs) + trt_out = trt_model(*inputs) + + cos_sim = cosine_similarity(py_out, trt_out) + assert cos_sim > COSINE_THRESHOLD, ( + f"test_rope_in_attention_block_static: outputs differ. " + f"Cosine sim: {cos_sim:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 6: RoPE embedded inside an attention block – dynamic seq_len +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_rope_in_attention_block_dynamic(): + """RoPE inside a full attention block compiles correctly (dynamic seq_len).""" + EMBED_DIM, N_HEADS, BATCH, SEQ_LEN = 64, 4, 2, 16 + HEAD_DIM = EMBED_DIM // N_HEADS + + model = AttentionWithRoPE(EMBED_DIM, N_HEADS).eval().cuda() + + hidden = torch.randn(BATCH, SEQ_LEN, EMBED_DIM, dtype=torch.float32).cuda() + cos = torch.randn(BATCH, SEQ_LEN, HEAD_DIM, dtype=torch.float32).cuda() + sin = torch.randn(BATCH, SEQ_LEN, HEAD_DIM, dtype=torch.float32).cuda() + inputs = (hidden, cos, sin) + + seq_len = Dim("seq_len", min=2, max=2048) + # hidden: (batch, seq_len, embed_dim) – seq_len is dim 1 + # cos/sin: (batch, seq_len, head_dim) – seq_len is dim 1 + dynamic_shapes = ( + {1: seq_len}, + {1: seq_len}, + {1: seq_len}, + ) + exp_program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes) + + compile_spec = { + "inputs": inputs, + "min_block_size": 1, + "pass_through_build_failures": True, + } 
+ trt_model = torchtrt.dynamo.compile(exp_program, **compile_spec) + + py_out = model(*inputs) + trt_out = trt_model(*inputs) + + cos_sim = cosine_similarity(py_out, trt_out) + assert cos_sim > COSINE_THRESHOLD, ( + f"test_rope_in_attention_block_dynamic: outputs differ. " + f"Cosine sim: {cos_sim:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 7: Complex RoPE – serialization with retrace=True then inference +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_rope_complex_form_serialization_retrace(tmp_path): + """Complex RoPE survives save(retrace=True) + load + inference round-trip. + + When retrace=True, torch_tensorrt.save re-exports the compiled GraphModule + via torch.export.export (strict=False), inlining the view_as_real unpacking + ops that live in the Python runtime forward(). The reloaded ExportedProgram + must accept the original complex inputs and produce correct results. + """ + BATCH, SEQ_LEN, N_HEADS, HEAD_DIM = 2, 8, 4, 64 + model = ComplexRotaryEmbedding().eval().cuda() + + x = torch.randn(BATCH, SEQ_LEN, N_HEADS, HEAD_DIM, dtype=torch.float32).cuda() + freqs_cis = _make_freqs_cis(SEQ_LEN, HEAD_DIM) + inputs = (x, freqs_cis) + + # Step 1: compile + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.compile(model, **compile_spec) + + py_out = model(*inputs) + trt_out_before = trt_model(*inputs) + + cos_sim_before = cosine_similarity(py_out, trt_out_before) + assert cos_sim_before > COSINE_THRESHOLD, ( + f"test_rope_complex_form_serialization_retrace: pre-save TRT output wrong. 
" + f"Cosine sim: {cos_sim_before:.4f} < threshold {COSINE_THRESHOLD}" + ) + + # Step 2: save with retrace=True — re-exports the compiled GraphModule so + # the view_as_real input-unpacking is inlined into the exported graph. + ep_path = str(tmp_path / "rope_complex_trt.ep") + torchtrt.save( + trt_model, + ep_path, + output_format="exported_program", + arg_inputs=list(inputs), + retrace=True, + ) + assert os.path.exists(ep_path), "Serialized .ep file was not created" + + # Step 3: reload + loaded_ep = torchtrt.load(ep_path) + # torch_tensorrt.load returns ExportedProgram; call .module() to get the + # callable GraphModule. + loaded_module = loaded_ep.module() + + # Step 4: inference on reloaded model + trt_out_after = loaded_module(*inputs) + + cos_sim_after = cosine_similarity(py_out, trt_out_after) + assert cos_sim_after > COSINE_THRESHOLD, ( + f"test_rope_complex_form_serialization_retrace: post-load TRT output wrong. " + f"Cosine sim: {cos_sim_after:.4f} < threshold {COSINE_THRESHOLD}" + ) + torch._dynamo.reset() + + +# --------------------------------------------------------------------------- +# Test 8: Complex output – model whose output is a complex tensor +# --------------------------------------------------------------------------- + + +class ComplexOutputModel(nn.Module): + """A model that outputs a complex tensor. + + This exercises the post-partition complex output restoration pass: + complex_graph_detection rewrites the internal complex ops to real + arithmetic before partitioning, and the compiler must re-insert + view_as_complex at the output boundary when the tail block is TRT. 
+ """ + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + # x: (batch, seq_len, n_heads, head_dim) – real + # freqs_cis: (seq_len, head_dim // 2) – complex + # Returns: complex tensor of shape (batch, seq_len, n_heads, head_dim // 2) + x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis[None, :, None, :] + return x_ * freqs_cis # complex output – no view_as_real + + +@pytest.mark.unit +def test_complex_output_static(): + """Model with a complex tensor output compiles and produces correct results.""" + model = ComplexOutputModel().eval().cuda() + + x = torch.randn(1, 4, 8, 64, dtype=torch.float32).cuda() + freqs_cis = _make_freqs_cis(4, 64) # shape (4, 32), complex64 + inputs = (x, freqs_cis) + + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + "pass_through_build_failures": True, + } + trt_model = torchtrt.compile(model, **compile_spec) + + py_out = model(*inputs) + trt_out = trt_model(*inputs) + + assert ( + trt_out.is_complex() + ), f"test_complex_output_static: TRT output should be complex, got dtype {trt_out.dtype}" + assert ( + trt_out.shape == py_out.shape + ), f"test_complex_output_static: shape mismatch {trt_out.shape} vs {py_out.shape}" + # Compare real and imaginary parts via cosine similarity + cos_sim_real = cosine_similarity(py_out.real, trt_out.real) + cos_sim_imag = cosine_similarity(py_out.imag, trt_out.imag) + assert ( + cos_sim_real > COSINE_THRESHOLD + ), f"test_complex_output_static: real part cosine sim {cos_sim_real:.4f} < {COSINE_THRESHOLD}" + assert ( + cos_sim_imag > COSINE_THRESHOLD + ), f"test_complex_output_static: imag part cosine sim {cos_sim_imag:.4f} < {COSINE_THRESHOLD}" + torch._dynamo.reset() diff --git a/tests/py/dynamo/lowering/test_complex_rewrite.py b/tests/py/dynamo/lowering/test_complex_rewrite.py new file mode 100644 index 0000000000..3d82d7d87b --- /dev/null +++ b/tests/py/dynamo/lowering/test_complex_rewrite.py @@ -0,0 
+1,1232 @@ +"""Comprehensive numerical-equivalence tests for complex_graph_detection lowering pass. + +Each test verifies: + lowered_gm(view_as_real(z)) ≡ original_model(z) (numerically) + +The lowering pass rewrites complex-dtype ops to real arithmetic on a [..., 2] +layout (trailing dim encodes real/imag). After lowering, all inputs and outputs +are in that real layout; the test harness converts back to complex before +comparison. + +Organisation +------------ + 1. Infrastructure helpers + 2. Elementwise arithmetic (mul / div / add / sub variants) + 3. Complex-specific ops (real, imag, conj, abs, angle, polar) + 4. Transcendental functions (exp, log, pow, sin/cos/tan …) + 5. Shape manipulation (permute, reshape/view, flatten, squeeze/unsqueeze, + cat, stack, select, slice, split, chunk, expand, + transpose, t, clone, narrow, roll, flip) + 6. Matrix multiplication (mm, bmm, matmul) + 7. Elementwise-safe pass-through verification + 8. Reduction ops (sum / mean — positive dims pass, negative = xfail) + 9. Creation-op bugs (ones_like → xfail, zeros_like → pass, full_like → xfail) +10. Chain / composition tests + +xfail tests document known bugs or missing handlers. They are expected to fail. +If they start passing a handler was fixed — remove the xfail marker. +""" + +from __future__ import annotations + +from typing import Any, Tuple + +import pytest +import torch +import torch.nn as nn + +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes.complex_graph_rewrite import ( + complex_graph_detection, +) + +# --------------------------------------------------------------------------- +# 1. Infrastructure +# --------------------------------------------------------------------------- + +_RTOL = 1e-4 +_ATOL = 1e-4 + + +def _export_and_lower( + model: nn.Module, example_inputs: Tuple[Any, ...] 
+) -> torch.fx.GraphModule: + """Export *model* and apply the complex_graph_detection lowering pass.""" + with torch.no_grad(): + exp = torch.export.export(model.eval(), example_inputs) + gm = exp.module() + complex_graph_detection(gm, CompilationSettings()) + return gm + + +def _real_inputs(inputs: Tuple[Any, ...]) -> Tuple[Any, ...]: + """Convert complex tensors to [..., 2] real layout.""" + return tuple( + ( + torch.view_as_real(x).contiguous() + if isinstance(x, torch.Tensor) and x.is_complex() + else x + ) + for x in inputs + ) + + +def _to_complex_if_needed(out: torch.Tensor, ref: torch.Tensor) -> torch.Tensor: + """Reinterpret *out* as complex if *ref* is complex and *out* has trailing dim 2.""" + if ref.is_complex() and not out.is_complex() and out.shape[-1] == 2: + return torch.view_as_complex(out.contiguous()) + return out + + +def _assert_close(ref: torch.Tensor, got: torch.Tensor, tag: str) -> None: + if ref.dtype == torch.bool: + assert torch.equal(got, ref), f"{tag}: bool tensor mismatch" + return + if ref.is_complex(): + assert got.is_complex(), f"{tag}: expected complex output, got {got.dtype}" + assert ref.shape == got.shape, f"{tag}: shape {got.shape} != {ref.shape}" + torch.testing.assert_close( + got.real.float(), ref.real.float(), rtol=_RTOL, atol=_ATOL + ) + torch.testing.assert_close( + got.imag.float(), ref.imag.float(), rtol=_RTOL, atol=_ATOL + ) + else: + assert not got.is_complex(), f"{tag}: expected real output, got {got.dtype}" + assert ref.shape == got.shape, f"{tag}: shape {got.shape} != {ref.shape}" + torch.testing.assert_close(got.float(), ref.float(), rtol=_RTOL, atol=_ATOL) + + +def _check_op(model: nn.Module, inputs: Tuple[Any, ...], tag: str) -> None: + """Full pipeline: run model → export+lower → compare.""" + with torch.no_grad(): + ref = model(*inputs) + + gm = _export_and_lower(model, inputs) + raw = gm(*_real_inputs(inputs)) + + if isinstance(raw, (list, tuple)): + ref_list = list(ref) if isinstance(ref, (list, tuple)) else 
[ref] + for i, (r, o) in enumerate(zip(ref_list, raw)): + got = _to_complex_if_needed(o, r) + _assert_close(r, got, f"{tag}[{i}]") + else: + got = _to_complex_if_needed(raw, ref) + _assert_close(ref, got, tag) + + +# Convenience: 2-D complex inputs used by most tests +def _z(rows: int = 3, cols: int = 4) -> torch.Tensor: + return torch.randn(rows, cols, dtype=torch.complex64) + + +def _z3d(b: int = 2, m: int = 3, n: int = 4) -> torch.Tensor: + return torch.randn(b, m, n, dtype=torch.complex64) + + +# =========================================================================== +# 2. Elementwise arithmetic +# =========================================================================== + + +@pytest.mark.unit +def test_mul_complex_complex(): + class M(nn.Module): + def forward(self, x, y): + return x * y + + z1, z2 = _z(), _z() + _check_op(M(), (z1, z2), "mul_cc") + + +@pytest.mark.unit +def test_mul_complex_real(): + """Complex × real tensor — only the complex part gets the mul handler.""" + + class M(nn.Module): + def forward(self, z, r): + return z * r # z complex, r real — result is complex + + z = _z() + r = torch.randn(3, 4) + _check_op(M(), (z, r), "mul_cr") + + +@pytest.mark.unit +def test_mul_scalar(): + """z * scalar — both re/im scaled equally (elementwise-safe).""" + + class M(nn.Module): + def forward(self, z): + return z * 3.0 + + _check_op(M(), (_z(),), "mul_scalar") + + +@pytest.mark.unit +def test_div_complex_complex(): + class M(nn.Module): + def forward(self, x, y): + return x / y + + _check_op(M(), (_z(), _z() + 0.1), "div_cc") + + +@pytest.mark.unit +def test_div_complex_scalar(): + class M(nn.Module): + def forward(self, z): + return z / 2.0 + + _check_op(M(), (_z(),), "div_cscalar") + + +@pytest.mark.unit +def test_div_scalar_complex(): + """scalar / complex — s/(a+bi) = (sa - sbi)/(a²+b²).""" + + class M(nn.Module): + def forward(self, z): + return 4.0 / (z + 0.1) + + _check_op(M(), (_z(),), "div_scalar_c") + + +@pytest.mark.unit +def 
test_add_tensor(): + """z1 + z2 — both complex; elementwise-safe (component-wise).""" + + class M(nn.Module): + def forward(self, x, y): + return x + y + + _check_op(M(), (_z(), _z()), "add_tensor") + + +@pytest.mark.unit +def test_sub_tensor(): + class M(nn.Module): + def forward(self, x, y): + return x - y + + _check_op(M(), (_z(), _z()), "sub_tensor") + + +@pytest.mark.unit +def test_add_scalar(): + """(a+bi) + s = (a+s) + bi — scalar added to real part only.""" + + class M(nn.Module): + def forward(self, z): + return z + 2.5 + + _check_op(M(), (_z(),), "add_scalar") + + +@pytest.mark.unit +def test_sub_scalar(): + class M(nn.Module): + def forward(self, z): + return z - 1.0 + + _check_op(M(), (_z(),), "sub_scalar") + + +@pytest.mark.unit +def test_neg(): + """Negation is elementwise-safe (flips sign of both re/im).""" + + class M(nn.Module): + def forward(self, z): + return -z + + _check_op(M(), (_z(),), "neg") + + +# =========================================================================== +# 3. 
Complex-specific ops +# =========================================================================== + + +@pytest.mark.unit +def test_real(): + """z.real → real tensor (select re component).""" + + class M(nn.Module): + def forward(self, z): + return z.real + + _check_op(M(), (_z(3, 5),), "real") # shape (3,5) so last dim≠2 + + +@pytest.mark.unit +def test_imag(): + class M(nn.Module): + def forward(self, z): + return z.imag + + _check_op(M(), (_z(3, 5),), "imag") + + +@pytest.mark.unit +def test_conj(): + """conj(a+bi) = a - bi.""" + + class M(nn.Module): + def forward(self, z): + return torch.conj(z) + + _check_op(M(), (_z(),), "conj") + + +@pytest.mark.unit +def test_abs(): + """|a+bi| = sqrt(a²+b²) — real output.""" + + class M(nn.Module): + def forward(self, z): + return torch.abs(z) + + _check_op(M(), (_z(3, 5),), "abs") + + +@pytest.mark.unit +def test_angle(): + """angle(a+bi) = atan2(b, a) — real output.""" + + class M(nn.Module): + def forward(self, z): + return torch.angle(z) + + _check_op(M(), (_z(3, 5),), "angle") + + +@pytest.mark.unit +def test_polar(): + """polar(r, theta) = r*cos(theta) + i*r*sin(theta).""" + + class M(nn.Module): + def forward(self, r, theta): + return torch.polar(r, theta) + + r = torch.rand(3, 4) + 0.1 + theta = torch.randn(3, 4) + _check_op(M(), (r, theta), "polar") + + +# =========================================================================== +# 4. 
Transcendental functions +# =========================================================================== + + +@pytest.mark.unit +def test_exp(): + """exp(a+bi) = e^a*(cos(b) + i*sin(b)).""" + + class M(nn.Module): + def forward(self, z): + return torch.exp(z) + + _check_op(M(), (_z(),), "exp") + + +@pytest.mark.unit +def test_log(): + class M(nn.Module): + def forward(self, z): + return torch.log(z) + + _check_op(M(), (_z() + 0.1,), "log") + + +@pytest.mark.unit +def test_log2(): + class M(nn.Module): + def forward(self, z): + return torch.log2(z) + + _check_op(M(), (_z() + 0.1,), "log2") + + +@pytest.mark.unit +def test_log10(): + class M(nn.Module): + def forward(self, z): + return torch.log10(z) + + _check_op(M(), (_z() + 0.1,), "log10") + + +@pytest.mark.unit +def test_log1p(): + class M(nn.Module): + def forward(self, z): + return torch.log1p(z) + + _check_op(M(), (_z() + 0.1,), "log1p") + + +@pytest.mark.unit +def test_expm1(): + class M(nn.Module): + def forward(self, z): + return torch.expm1(z) + + # Use small values to keep numbers from overflowing + z = torch.randn(3, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z,), "expm1") + + +@pytest.mark.unit +def test_sqrt(): + class M(nn.Module): + def forward(self, z): + return torch.sqrt(z) + + _check_op(M(), (_z(),), "sqrt") + + +@pytest.mark.unit +def test_pow_scalar(): + """z**n via polar form.""" + + class M(nn.Module): + def forward(self, z): + return z**2.0 + + _check_op(M(), (_z() + 0.1,), "pow_scalar") + + +@pytest.mark.unit +def test_pow_tensor(): + """z1**z2 = exp(z2 * log(z1)).""" + + class M(nn.Module): + def forward(self, x, y): + return x**y + + z1 = _z() + 0.5 + z2 = torch.randn(3, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z1, z2), "pow_tensor") + + +@pytest.mark.unit +def test_sin(): + class M(nn.Module): + def forward(self, z): + return torch.sin(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "sin") + + +@pytest.mark.unit +def test_cos(): + class 
M(nn.Module): + def forward(self, z): + return torch.cos(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "cos") + + +@pytest.mark.unit +def test_tan(): + class M(nn.Module): + def forward(self, z): + return torch.tan(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z,), "tan") + + +@pytest.mark.unit +def test_sinh(): + class M(nn.Module): + def forward(self, z): + return torch.sinh(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "sinh") + + +@pytest.mark.unit +def test_cosh(): + class M(nn.Module): + def forward(self, z): + return torch.cosh(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "cosh") + + +@pytest.mark.unit +def test_tanh(): + class M(nn.Module): + def forward(self, z): + return torch.tanh(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "tanh") + + +@pytest.mark.unit +def test_asin(): + class M(nn.Module): + def forward(self, z): + return torch.asin(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "asin") + + +@pytest.mark.unit +def test_acos(): + class M(nn.Module): + def forward(self, z): + return torch.acos(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "acos") + + +@pytest.mark.unit +def test_atan(): + class M(nn.Module): + def forward(self, z): + return torch.atan(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "atan") + + +@pytest.mark.unit +def test_asinh(): + class M(nn.Module): + def forward(self, z): + return torch.asinh(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.5 + _check_op(M(), (z,), "asinh") + + +@pytest.mark.unit +def test_acosh(): + class M(nn.Module): + def forward(self, z): + return torch.acosh(z) + + # acosh needs |z| > 1 to avoid NaN + z = torch.randn(3, 4, dtype=torch.complex64) + 2.0 + _check_op(M(), (z,), "acosh") + + +@pytest.mark.unit +def test_atanh(): + class 
M(nn.Module): + def forward(self, z): + return torch.atanh(z) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z,), "atanh") + + +@pytest.mark.unit +def test_isnan(): + """isnan/isinf: boolean output, checks re|im.""" + + class M(nn.Module): + def forward(self, z): + return torch.isnan(z) + + _check_op(M(), (_z(3, 5),), "isnan") + + +@pytest.mark.unit +def test_isinf(): + class M(nn.Module): + def forward(self, z): + return torch.isinf(z) + + _check_op(M(), (_z(3, 5),), "isinf") + + +# =========================================================================== +# 5. Shape manipulation +# =========================================================================== + + +@pytest.mark.unit +def test_view_as_real_complex_bypass(): + """view_as_real → view_as_complex is a round-trip no-op after lowering.""" + + class M(nn.Module): + def forward(self, z): + r = torch.view_as_real(z) + return torch.view_as_complex(r) + + _check_op(M(), (_z(),), "var_vac_bypass") + + +@pytest.mark.unit +def test_permute(): + class M(nn.Module): + def forward(self, z): + return z.permute(1, 0) + + _check_op(M(), (_z(),), "permute_2d") + + +@pytest.mark.unit +def test_permute_3d(): + class M(nn.Module): + def forward(self, z): + return z.permute(0, 2, 1) + + _check_op(M(), (_z3d(),), "permute_3d") + + +@pytest.mark.unit +def test_reshape(): + class M(nn.Module): + def forward(self, z): + return z.reshape(12) + + _check_op(M(), (_z(),), "reshape") + + +@pytest.mark.unit +def test_reshape_batch(): + class M(nn.Module): + def forward(self, z): + return z.reshape(2, 6) + + _check_op(M(), (torch.randn(3, 4, dtype=torch.complex64),), "reshape_batch") + + +@pytest.mark.unit +def test_view(): + class M(nn.Module): + def forward(self, z): + return z.view(12) + + _check_op(M(), (torch.randn(3, 4, dtype=torch.complex64).contiguous(),), "view") + + +@pytest.mark.unit +def test_flatten_all(): + class M(nn.Module): + def forward(self, z): + return z.flatten() + + _check_op(M(), 
(_z3d(),), "flatten_all") + + +@pytest.mark.unit +def test_flatten_partial(): + class M(nn.Module): + def forward(self, z): + return z.flatten(1, -1) + + _check_op(M(), (_z3d(),), "flatten_partial") + + +@pytest.mark.unit +def test_flatten_start_neg(): + class M(nn.Module): + def forward(self, z): + return z.flatten(-2, -1) + + _check_op(M(), (_z3d(),), "flatten_neg_dims") + + +@pytest.mark.unit +def test_unsqueeze_pos(): + class M(nn.Module): + def forward(self, z): + return z.unsqueeze(0) + + _check_op(M(), (_z(),), "unsqueeze_pos") + + +@pytest.mark.unit +def test_unsqueeze_neg(): + class M(nn.Module): + def forward(self, z): + return z.unsqueeze(-1) + + _check_op(M(), (_z(),), "unsqueeze_neg") + + +@pytest.mark.unit +def test_unsqueeze_mid_neg(): + class M(nn.Module): + def forward(self, z): + return z.unsqueeze(-2) + + _check_op(M(), (_z3d(),), "unsqueeze_mid_neg") + + +@pytest.mark.unit +def test_squeeze_pos(): + class M(nn.Module): + def forward(self, z): + return z.squeeze(0) + + _check_op(M(), (torch.randn(1, 4, dtype=torch.complex64),), "squeeze_pos") + + +@pytest.mark.unit +def test_squeeze_neg(): + class M(nn.Module): + def forward(self, z): + return z.squeeze(-2) + + _check_op(M(), (torch.randn(3, 1, 4, dtype=torch.complex64),), "squeeze_neg") + + +@pytest.mark.unit +def test_squeeze_last_dim(): + """squeeze(dim=-1) removes the last *complex* dim (not real/imag encoding).""" + + class M(nn.Module): + def forward(self, z): + return z.squeeze(-1) + + _check_op(M(), (torch.randn(3, 1, dtype=torch.complex64),), "squeeze_last") + + +@pytest.mark.unit +def test_cat_dim0(): + class M(nn.Module): + def forward(self, x, y): + return torch.cat([x, y], dim=0) + + _check_op(M(), (_z(2, 4), _z(3, 4)), "cat_dim0") + + +@pytest.mark.unit +def test_cat_dim1(): + class M(nn.Module): + def forward(self, x, y): + return torch.cat([x, y], dim=1) + + _check_op(M(), (_z(3, 2), _z(3, 3)), "cat_dim1") + + +@pytest.mark.unit +def test_cat_neg_dim(): + """cat(tensors, dim=-1) 
on complex — must concat the last *complex* dim.""" + + class M(nn.Module): + def forward(self, x, y): + return torch.cat([x, y], dim=-1) + + _check_op(M(), (_z(3, 2), _z(3, 3)), "cat_neg_dim") + + +@pytest.mark.unit +def test_stack_dim0(): + class M(nn.Module): + def forward(self, x, y): + return torch.stack([x, y], dim=0) + + _check_op(M(), (_z(), _z()), "stack_dim0") + + +@pytest.mark.unit +def test_stack_neg_dim(): + """stack(tensors, dim=-1) — new dim must land before real/imag encoding.""" + + class M(nn.Module): + def forward(self, x, y): + return torch.stack([x, y], dim=-1) + + _check_op(M(), (_z(), _z()), "stack_neg_dim") + + +@pytest.mark.unit +def test_select_pos(): + class M(nn.Module): + def forward(self, z): + return z[1] + + _check_op(M(), (_z(),), "select_pos") + + +@pytest.mark.unit +def test_select_neg_dim(): + """select along the last complex dim (dim=-1).""" + + class M(nn.Module): + def forward(self, z): + return z.select(-1, 2) + + _check_op(M(), (_z(),), "select_neg_dim") + + +@pytest.mark.unit +def test_slice_pos(): + class M(nn.Module): + def forward(self, z): + return z[1:] + + _check_op(M(), (_z(),), "slice_pos") + + +@pytest.mark.unit +def test_slice_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z[..., 1:3] + + _check_op(M(), (torch.randn(3, 6, dtype=torch.complex64),), "slice_neg_dim") + + +@pytest.mark.unit +def test_split(): + class M(nn.Module): + def forward(self, z): + a, b = z.split(2, dim=0) + return a + b + + _check_op(M(), (torch.randn(4, 3, dtype=torch.complex64),), "split_pos") + + +@pytest.mark.unit +def test_split_neg_dim(): + class M(nn.Module): + def forward(self, z): + a, b = z.split(2, dim=-1) + return a + b + + _check_op(M(), (torch.randn(3, 4, dtype=torch.complex64),), "split_neg") + + +@pytest.mark.unit +def test_chunk(): + class M(nn.Module): + def forward(self, z): + a, b = z.chunk(2, dim=0) + return a * b + + _check_op(M(), (torch.randn(4, 3, dtype=torch.complex64),), "chunk_pos") + + 
+@pytest.mark.unit +def test_transpose_2d(): + class M(nn.Module): + def forward(self, z): + return z.transpose(0, 1) + + _check_op(M(), (_z(),), "transpose_2d") + + +@pytest.mark.unit +def test_transpose_neg(): + class M(nn.Module): + def forward(self, z): + return z.transpose(-2, -1) + + _check_op(M(), (_z3d(),), "transpose_neg") + + +@pytest.mark.unit +def test_t_default(): + """t.default (2D transpose) is elementwise-safe.""" + + class M(nn.Module): + def forward(self, z): + return z.t() + + _check_op(M(), (_z(),), "t_default") + + +@pytest.mark.unit +def test_expand(): + class M(nn.Module): + def forward(self, z): + return z.expand(3, 4) + + _check_op(M(), (torch.randn(1, 4, dtype=torch.complex64),), "expand") + + +@pytest.mark.unit +def test_narrow_pos(): + """narrow along a non-negative dim — pass-through is correct.""" + + class M(nn.Module): + def forward(self, z): + return z.narrow(0, 1, 2) + + _check_op(M(), (_z(),), "narrow_pos") + + +@pytest.mark.unit +def test_narrow_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z.narrow(-1, 1, 2) + + _check_op(M(), (torch.randn(3, 5, dtype=torch.complex64),), "narrow_neg_dim") + + +@pytest.mark.unit +def test_roll_pos(): + """roll along a positive dim — pass-through is correct.""" + + class M(nn.Module): + def forward(self, z): + return z.roll(1, 0) + + _check_op(M(), (_z(),), "roll_pos") + + +@pytest.mark.unit +def test_roll_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z.roll(1, -1) + + _check_op(M(), (_z(),), "roll_neg_dim") + + +@pytest.mark.unit +def test_flip_pos(): + """flip along a positive dim — pass-through is correct.""" + + class M(nn.Module): + def forward(self, z): + return z.flip([0]) + + _check_op(M(), (_z(),), "flip_pos") + + +@pytest.mark.unit +def test_flip_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z.flip([-1]) + + _check_op(M(), (_z(),), "flip_neg_dim") + + +@pytest.mark.unit +def test_repeat(): + class M(nn.Module): + def 
forward(self, z): + return z.repeat(2, 1) + + _check_op(M(), (_z(),), "repeat") + + +# =========================================================================== +# 6. Matrix multiplication +# =========================================================================== + + +@pytest.mark.unit +def test_mm(): + class M(nn.Module): + def forward(self, x, y): + return torch.mm(x, y) + + x = torch.randn(3, 4, dtype=torch.complex64) + y = torch.randn(4, 5, dtype=torch.complex64) + _check_op(M(), (x, y), "mm") + + +@pytest.mark.unit +def test_bmm(): + class M(nn.Module): + def forward(self, x, y): + return torch.bmm(x, y) + + x = torch.randn(2, 3, 4, dtype=torch.complex64) + y = torch.randn(2, 4, 5, dtype=torch.complex64) + _check_op(M(), (x, y), "bmm") + + +@pytest.mark.unit +def test_matmul_2d(): + class M(nn.Module): + def forward(self, x, y): + return x @ y + + x = torch.randn(3, 4, dtype=torch.complex64) + y = torch.randn(4, 5, dtype=torch.complex64) + _check_op(M(), (x, y), "matmul_2d") + + +@pytest.mark.unit +def test_matmul_3d(): + class M(nn.Module): + def forward(self, x, y): + return x @ y + + x = torch.randn(2, 3, 4, dtype=torch.complex64) + y = torch.randn(2, 4, 5, dtype=torch.complex64) + _check_op(M(), (x, y), "matmul_3d") + + +@pytest.mark.unit +def test_mm_self_multiply(): + """mm(z, z) — self-multiplication should use the same node twice correctly.""" + + class M(nn.Module): + def forward(self, z): + return torch.mm(z, z.t()) + + z = torch.randn(4, 4, dtype=torch.complex64) + _check_op(M(), (z,), "mm_self") + + +# =========================================================================== +# 7. 
Elementwise-safe pass-through verification +# =========================================================================== + + +@pytest.mark.unit +def test_clone(): + class M(nn.Module): + def forward(self, z): + return z.clone() + + _check_op(M(), (_z(),), "clone") + + +@pytest.mark.unit +def test_zeros_like(): + """zeros_like(z) → 0+0i (correct — all zeros in [..., 2] layout).""" + + class M(nn.Module): + def forward(self, z): + return torch.zeros_like(z) + + _check_op(M(), (_z(),), "zeros_like") + + +@pytest.mark.unit +def test_mul_scalar_elementwise(): + """mul.Scalar is elementwise-safe: scales both re and im.""" + + class M(nn.Module): + def forward(self, z): + return torch.ops.aten.mul.Scalar(z, 2.5) + + _check_op(M(), (_z(),), "mul_scalar_aten") + + +@pytest.mark.unit +def test_div_scalar_elementwise(): + class M(nn.Module): + def forward(self, z): + return z / 4.0 + + _check_op(M(), (_z(),), "div_scalar_elementwise") + + +# =========================================================================== +# 8. 
Reduction ops +# — positive dims: pass-through gives correct results +# — negative dims: no handler yet → xfail +# =========================================================================== + + +@pytest.mark.unit +def test_sum_pos_dim(): + """sum(z, dim=0) — positive dim, pass-through is correct.""" + + class M(nn.Module): + def forward(self, z): + return z.sum(dim=0) + + _check_op(M(), (_z(),), "sum_pos") + + +@pytest.mark.unit +def test_sum_pos_dim_keepdim(): + class M(nn.Module): + def forward(self, z): + return z.sum(dim=1, keepdim=True) + + _check_op(M(), (_z3d(),), "sum_pos_keepdim") + + +@pytest.mark.unit +def test_sum_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z.sum(dim=-1) + + _check_op(M(), (_z(),), "sum_neg") + + +@pytest.mark.unit +def test_mean_pos_dim(): + class M(nn.Module): + def forward(self, z): + return z.mean(dim=0) + + _check_op(M(), (_z(),), "mean_pos") + + +@pytest.mark.unit +def test_mean_neg_dim(): + class M(nn.Module): + def forward(self, z): + return z.mean(dim=-1) + + _check_op(M(), (_z(),), "mean_neg") + + +# =========================================================================== +# 9. Creation-op bugs (xfail = documented known failures) +# =========================================================================== + + +@pytest.mark.unit +def test_ones_like_bug(): + """ones_like(z) should give 1+0i everywhere, not 1+1i.""" + + class M(nn.Module): + def forward(self, z): + return torch.ones_like(z) + + _check_op(M(), (_z(),), "ones_like") + + +@pytest.mark.unit +def test_full_like_bug(): + """full_like(z, 3.0) should give 3+0i everywhere.""" + + class M(nn.Module): + def forward(self, z): + return torch.full_like(z, 3.0) + + _check_op(M(), (_z(),), "full_like") + + +# =========================================================================== +# 10. 
Chain / composition tests +# =========================================================================== + + +@pytest.mark.unit +def test_mul_then_exp(): + class M(nn.Module): + def forward(self, x, y): + return torch.exp(x * y) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z, z.clone()), "mul_then_exp") + + +@pytest.mark.unit +def test_reshape_then_mul(): + class M(nn.Module): + def forward(self, x, y): + return x.reshape(12) * y + + x = _z() + y = torch.randn(12, dtype=torch.complex64) + _check_op(M(), (x, y), "reshape_then_mul") + + +@pytest.mark.unit +def test_mm_then_reshape(): + class M(nn.Module): + def forward(self, x, y): + return (x @ y).reshape(15) + + x = torch.randn(3, 4, dtype=torch.complex64) + y = torch.randn(4, 5, dtype=torch.complex64) + _check_op(M(), (x, y), "mm_then_reshape") + + +@pytest.mark.unit +def test_cat_then_exp(): + class M(nn.Module): + def forward(self, x, y): + return torch.exp(torch.cat([x, y], dim=0)) + + z = torch.randn(2, 4, dtype=torch.complex64) * 0.3 + _check_op(M(), (z, z.clone()), "cat_then_exp") + + +@pytest.mark.unit +def test_unsqueeze_squeeze_round_trip(): + class M(nn.Module): + def forward(self, z): + return z.unsqueeze(1).squeeze(1) + + _check_op(M(), (_z(),), "unsqueeze_squeeze_rt") + + +@pytest.mark.unit +def test_permute_mul(): + class M(nn.Module): + def forward(self, x, y): + return x.permute(1, 0) * y.permute(1, 0) + + _check_op(M(), (_z(), _z()), "permute_mul") + + +@pytest.mark.unit +def test_transpose_then_mm(): + class M(nn.Module): + def forward(self, x, y): + return x @ y.transpose(-2, -1) + + x = torch.randn(3, 4, dtype=torch.complex64) + y = torch.randn(5, 4, dtype=torch.complex64) + _check_op(M(), (x, y), "transpose_mm") + + +@pytest.mark.unit +def test_rope_style_pattern(): + """RoPE-like pattern: split → mul with freqs → cat.""" + + class M(nn.Module): + def forward(self, q, freqs): + # q: [B, T, D] complex, freqs: [T, D] complex + return q * freqs.unsqueeze(0) + + q = 
_z3d(2, 8, 4) + freqs = _z(8, 4) + _check_op(M(), (q, freqs), "rope_style") + + +@pytest.mark.unit +def test_multiop_chain(): + """sin(exp(z) + conj(z)) — exercises several handlers in sequence.""" + + class M(nn.Module): + def forward(self, z): + return torch.sin(torch.exp(z * 0.1) + torch.conj(z)) + + z = torch.randn(3, 4, dtype=torch.complex64) * 0.2 + _check_op(M(), (z,), "multiop_chain") + + +@pytest.mark.unit +def test_abs_then_mul(): + """abs(z) is real; multiplying by a real scalar stays real.""" + + class M(nn.Module): + def forward(self, z): + return torch.abs(z) * 2.0 + + _check_op(M(), (_z(3, 5),), "abs_then_mul") + + +@pytest.mark.unit +def test_split_then_mul_then_cat(): + """split → element-wise mul → cat.""" + + class M(nn.Module): + def forward(self, z): + a, b = z.split(2, dim=1) # [3,2] each + return torch.cat([a * b, b * a], dim=1) + + _check_op(M(), (torch.randn(3, 4, dtype=torch.complex64),), "split_mul_cat") diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 1bdbd2dc60..3f8fafe7d2 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -522,6 +522,10 @@ def test_refit_one_engine_bert_with_weightmap(): torch._dynamo.reset() +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) @unittest.skipIf( not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", @@ -582,6 +586,10 @@ def test_refit_one_engine_inline_runtime_with_weightmap(tmpdir): torch._dynamo.reset() +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) @unittest.skipIf( not torch_trt.ENABLED_FEATURES.refit, "Refit feature is not supported in Python 3.13 or higher", @@ -773,6 +781,10 @@ def forward(self, x): torch._dynamo.reset() +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) 
@unittest.skipIf( not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", @@ -892,6 +904,10 @@ def test_refit_one_engine_bert_without_weightmap(): torch._dynamo.reset() +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) @unittest.skipIf( not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", @@ -949,6 +965,10 @@ def test_refit_one_engine_inline_runtime_without_weightmap(tmpdir): torch._dynamo.reset() +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) @unittest.skipIf( not torch_trt.ENABLED_FEATURES.refit, "Refit feature is not supported in Python 3.13 or higher", @@ -1128,3 +1148,220 @@ def forward(self, x): # Clean up model env torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) +@pytest.mark.unit +def test_complex_buffer_refit(): + """Refit a model whose weights include a complex-valued buffer (e.g. RoPE freqs). 
+ + Exercises the combined complex_graph_detection + refit_module_weights path: + - complex get_attr buffer is unpacked to real by the lowering pass + - complex placeholder input goes through view_as_real at the TRT boundary + - after refitting with new frequencies the output matches the new model + """ + + class ComplexFreqModel(nn.Module): + def __init__(self, freqs: torch.Tensor): + super().__init__() + self.register_buffer("freqs", freqs.cuda()) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # complex mul then expose as real so TRT can produce a real output + return torch.view_as_real(z * self.freqs) + + SEQ, DIM = 8, 32 + + def make_freqs() -> torch.Tensor: + angles = torch.rand(SEQ, DIM // 2) + return torch.polar(torch.ones_like(angles), angles).cuda() + + freqs1 = make_freqs() + freqs2 = make_freqs() + + model1 = ComplexFreqModel(freqs1).eval() + model2 = ComplexFreqModel(freqs2).eval() + + z = torch.randn(SEQ, DIM // 2, dtype=torch.complex64).cuda() + inputs = [z] + + exp_program1 = torch.export.export(model1, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program1, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + arg_inputs=inputs, + use_weight_map_cache=True, + verify_output=True, + ) + + expected_output = exp_program2.module()(*inputs) + refitted_output = new_trt_gm(*inputs) + + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, atol=1e-2, rtol=1e-2), + "Refit with complex buffer failed: output mismatch after refit", + ) + + torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 
3.13 or higher", +) +@pytest.mark.unit +def test_complex_buffer_with_real_param_refit(): + """Refit a model that mixes a complex buffer with a real nn.Linear weight. + + Verifies that Stage 3 slice-matching for complex buffer constants coexists + correctly with ordinary weight-name-map entries for real parameters. + After refitting both the frequencies and the projection matrix, the output + should match the new model exactly. + """ + + SEQ, DIM = 8, 32 + + class ComplexRotateAndProject(nn.Module): + def __init__(self, freqs: torch.Tensor): + super().__init__() + self.register_buffer("freqs", freqs.cuda()) + self.proj = nn.Linear(DIM, DIM, bias=False) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + rotated = z * self.freqs # complex mul, (SEQ, DIM//2) + r = torch.view_as_real(rotated) # (SEQ, DIM//2, 2) + flat = r.reshape(z.shape[0], -1) # (SEQ, DIM) + return self.proj(flat) # (SEQ, DIM) real output + + def make_freqs() -> torch.Tensor: + angles = torch.rand(SEQ, DIM // 2) + return torch.polar(torch.ones_like(angles), angles).cuda() + + model1 = ComplexRotateAndProject(make_freqs()).eval().cuda() + model2 = ComplexRotateAndProject(make_freqs()).eval().cuda() + + z = torch.randn(SEQ, DIM // 2, dtype=torch.complex64).cuda() + inputs = [z] + + exp_program1 = torch.export.export(model1, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program1, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + arg_inputs=inputs, + use_weight_map_cache=True, + verify_output=True, + ) + + expected_output = exp_program2.module()(*inputs) + refitted_output = new_trt_gm(*inputs) + + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, atol=1e-2, rtol=1e-2), + "Refit with complex buffer + real param failed: output 
mismatch", + ) + + torch._dynamo.reset() + + +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, + "TorchScript Frontend is not available", +) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) +@pytest.mark.unit +def test_dual_complex_buffer_refit(): + """Refit a model with two independent complex buffers. + + Ensures Stage 3 value-based matching correctly distinguishes the real and + imaginary slices of freqs_a from those of freqs_b when both are unpacked to + separate _unpacked_complex state-dict entries with the same shape. + """ + + SEQ, DIM = 8, 32 + + class DualComplexFreqModel(nn.Module): + def __init__(self, freqs_a: torch.Tensor, freqs_b: torch.Tensor): + super().__init__() + self.register_buffer("freqs_a", freqs_a.cuda()) + self.register_buffer("freqs_b", freqs_b.cuda()) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + ra = torch.view_as_real(z * self.freqs_a) # (SEQ, DIM//2, 2) + rb = torch.view_as_real(z * self.freqs_b) # (SEQ, DIM//2, 2) + return ra + rb # real output + + def make_freqs() -> torch.Tensor: + angles = torch.rand(SEQ, DIM // 2) + return torch.polar(torch.ones_like(angles), angles).cuda() + + model1 = DualComplexFreqModel(make_freqs(), make_freqs()).eval() + model2 = DualComplexFreqModel(make_freqs(), make_freqs()).eval() + + z = torch.randn(SEQ, DIM // 2, dtype=torch.complex64).cuda() + inputs = [z] + + exp_program1 = torch.export.export(model1, tuple(inputs)) + exp_program2 = torch.export.export(model2, tuple(inputs)) + + trt_gm = torchtrt.dynamo.compile( + exp_program1, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + min_block_size=1, + immutable_weights=False, + ) + + new_trt_gm = refit_module_weights( + compiled_module=trt_gm, + new_weight_module=exp_program2, + arg_inputs=inputs, + use_weight_map_cache=True, + verify_output=True, + ) + + expected_output = exp_program2.module()(*inputs) + 
refitted_output = new_trt_gm(*inputs) + + assertions.assertTrue( + torch.allclose(expected_output, refitted_output, atol=1e-2, rtol=1e-2), + "Refit with dual complex buffers failed: output mismatch", + ) + + torch._dynamo.reset() diff --git a/uv.lock b/uv.lock index 82459b609a..c320b1b50b 100644 --- a/uv.lock +++ b/uv.lock @@ -32,9 +32,6 @@ required-markers = [ "python_full_version < '3.14' and platform_machine == 'AMD64' and sys_platform == 'win32'", ] -[options] -prerelease-mode = "allow" - [[package]] name = "accelerate" version = "1.12.0" @@ -2900,6 +2897,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/6c/64cafaceea3f99927e84b38a362ec6a8f24f33061c90bda77dfe1cd4c3c6/pulp-3.3.0-py3-none-any.whl", hash = "sha256:dd6ad2d63f196d1254eddf9dcff5cd224912c1f046120cb7c143c5b0eda63fae", size = 16387700, upload-time = "2025-09-18T08:14:53.368Z" }, ] +[[package]] +name = "py" +version = "1.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/ff/fec109ceb715d2a6b4c4a85a61af3b40c723a961e8828319fbcb15b868dc/py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719", size = 207796, upload-time = "2021-11-04T17:17:01.377Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/f0/10642828a8dfb741e5f3fbaac830550a518a775c7fff6f04a007259b0548/py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378", size = 98708, upload-time = "2021-11-04T17:17:00.152Z" }, +] + [[package]] name = "py-cpuinfo" version = "9.0.0" @@ -3157,6 +3163,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-forked" +version = "1.6.0" 
+source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "py", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pytest", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8c/c9/93ad2ba2413057ee694884b88cf7467a46c50c438977720aeac26e73fdb7/pytest-forked-1.6.0.tar.gz", hash = "sha256:4dafd46a9a600f65d822b8f605133ecf5b3e1941ebb3588e943b4e3eb71a5a3f", size = 9977, upload-time = "2023-02-12T23:22:27.544Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/af/9c0bda43e486a3c9bf1e0f876d0f241bc3f229d7d65d09331a0868db9629/pytest_forked-1.6.0-py3-none-any.whl", hash = "sha256:810958f66a91afb1a1e2ae83089d8dc1cd2437ac96b12963042fbb9fb4d16af0", size = 4897, upload-time = "2023-02-12T23:22:26.022Z" }, +] + [[package]] name = "pytest-xdist" version = "3.8.0" @@ -4046,9 +4065,14 @@ debug = [ dev = [ { name = "black", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "clang-format", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "expecttest", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "isort", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "mypy", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "parameterized", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "pre-commit", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pytest", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pytest-forked", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pytest-xdist", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "pyyaml", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "ruff", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = 
"typos", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -4074,6 +4098,7 @@ test = [ { name = "expecttest", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "parameterized", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "pytest", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { name = "pytest-forked", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "pytest-xdist", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] test-ext = [ @@ -4108,9 +4133,14 @@ debug = [ dev = [ { name = "black", specifier = ">=24.0.0" }, { name = "clang-format", specifier = "==14.0.6" }, + { name = "expecttest", specifier = "==0.1.6" }, { name = "isort" }, { name = "mypy" }, + { name = "parameterized", specifier = ">=0.2.0" }, { name = "pre-commit", specifier = ">=2.20.0" }, + { name = "pytest" }, + { name = "pytest-forked", specifier = ">=1.6.0" }, + { name = "pytest-xdist" }, { name = "pyyaml" }, { name = "ruff" }, { name = "typos" }, @@ -4134,6 +4164,7 @@ test = [ { name = "expecttest", specifier = "==0.1.6" }, { name = "parameterized", specifier = ">=0.2.0" }, { name = "pytest" }, + { name = "pytest-forked", specifier = ">=1.6.0" }, { name = "pytest-xdist" }, ] test-ext = [