
🐛 [Bug] aten::index_put converter fails with "Dynamic shape in free dimensions is not supported" when a non-indexed dimension is dynamic #4139

@narendasan


Component: dynamo / converter
Labels: bug, dynamo


Description

The aten::index_put converter rejects a model where the indexed tensor has a dynamic dimension that is not the dimension being indexed (a "free" dynamic dimension).

The pattern is a KV-cache scatter-write common in autoregressive transformer decoders:

cache[..., idx, :] = values
# cache:  (2, N, max_ctx, H)  — N is dynamic batch
# idx:    (L,)                — indexes dim-2 (the cache/time dim)
# values: (2, N, L, H)

torch.export succeeds. The error is raised during torch_tensorrt.dynamo.compile.


Error message

Dynamic shape in free dimensions is not supported
While executing %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](
    args = (%cache, [None, None, %idx], %values), kwargs = {}
)

Minimal reproducer

import torch
import torch_tensorrt
from torch.export import Dim, export


class IndexPutModel(torch.nn.Module):
    def forward(
        self,
        cache: torch.Tensor,   # (2, N, max_ctx, H)
        values: torch.Tensor,  # (2, N, L, H)
        idx: torch.Tensor,     # (L,)
    ) -> torch.Tensor:
        cache[..., idx, :] = values
        return cache


N = 4
max_ctx = 256
L = 1
H = 512

cache  = torch.zeros(2, N, max_ctx, H, dtype=torch.float16, device="cuda")
values = torch.randn(2, N, L,       H, dtype=torch.float16, device="cuda")
idx    = torch.tensor([3], dtype=torch.long, device="cuda")

model = IndexPutModel().eval().cuda()

batch_dim = Dim("batch", min=1, max=64)
ep = export(
    model,
    args=(cache, values, idx),
    dynamic_shapes={
        "cache":  {1: batch_dim},
        "values": {1: batch_dim},
        "idx":    {},
    },
)
# Export succeeds. Compile fails:
trt_model = torch_tensorrt.dynamo.compile(
    ep,
    arg_inputs=[
        torch_tensorrt.Input(min_shape=(2, 1,  max_ctx, H),
                             opt_shape=(2, N,  max_ctx, H),
                             max_shape=(2, 64, max_ctx, H),
                             dtype=torch.float16),
        torch_tensorrt.Input(min_shape=(2, 1,  L, H),
                             opt_shape=(2, N,  L, H),
                             max_shape=(2, 64, L, H),
                             dtype=torch.float16),
        torch_tensorrt.Input(min_shape=(L,), opt_shape=(L,), max_shape=(L,),
                             dtype=torch.long),
    ],
    use_explicit_typing=True,
    min_block_size=1,
)

Expected behavior

Compilation succeeds. The indexed dimension (dim=2) is static (max_ctx=256). The dynamic dimension (dim=1, batch N) is a free dimension that is fully selected by the None entries in the index list — it should not prevent lowering.
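The claim above can be checked in eager mode: the identical `index_put` pattern works for any batch size, precisely because the dynamic dim-1 is fully selected by the `None` entries in the index list. A minimal sketch (shapes taken from the reproducer, run on CPU for simplicity):

```python
import torch

# Eager sanity check: the same scatter-write succeeds for every batch
# size in the declared Dim("batch", min=1, max=64) range, so the dynamic
# free dimension is not a semantic obstacle -- only the converter rejects it.
max_ctx, L, H = 256, 1, 512
idx = torch.tensor([3], dtype=torch.long)

for n in (1, 4, 64):
    cache = torch.zeros(2, n, max_ctx, H)
    values = torch.randn(2, n, L, H)
    cache[..., idx, :] = values  # same pattern as the reproducer
    # Only position idx=3 along dim-2 is written, for every batch entry.
    assert torch.equal(cache[:, :, 3, :], values[:, :, 0, :])
```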


Workaround

Replace index_put with scatter_ on the indexed dimension. This lowers to IScatterLayer, which does not restrict free dynamic dimensions:

idx_scatter = idx.view(1, 1, -1, 1).expand_as(values)  # (2, N, L, H)
cache.scatter_(2, idx_scatter, values)
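For reference, a quick eager-mode check (using the reproducer's shapes) that the `scatter_` rewrite is numerically equivalent to the original `index_put` form:

```python
import torch

# Equivalence sketch: original index_put pattern vs. the scatter_ workaround.
N, max_ctx, L, H = 4, 256, 1, 512
idx = torch.tensor([3], dtype=torch.long)
values = torch.randn(2, N, L, H)

ref = torch.zeros(2, N, max_ctx, H)
ref[..., idx, :] = values                       # original index_put form

out = torch.zeros(2, N, max_ctx, H)
idx_scatter = idx.view(1, 1, -1, 1).expand_as(values)  # (2, N, L, H)
out.scatter_(2, idx_scatter, values)            # workaround form

assert torch.equal(ref, out)
```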

Environment

torch 2.12.0.dev20260318+cu130
torch_tensorrt 2.12.0.dev0+fb4fc99f4
tensorrt 10.15.1.29
Python 3.12.3
GPU NVIDIA GeForce RTX 5090 (SM 12.0)
CUDA 13.0
