aten::index_put converter fails with "Dynamic shape in free dimensions is not supported" when a non-indexed dimension is dynamic
Component: dynamo / converter
Labels: bug, dynamo
Description
The aten::index_put converter rejects a model where the indexed tensor has a dynamic dimension that is not the dimension being indexed (a "free" dynamic dimension).
The pattern is a KV-cache scatter-write common in autoregressive transformer decoders:
```python
cache[..., idx, :] = values
# cache:  (2, N, max_ctx, H) — N is dynamic batch
# idx:    (L,) — indexes dim-2 (the cache/time dim)
# values: (2, N, L, H)
```

torch.export succeeds. The error is raised during torch_tensorrt.dynamo.compile.
Error message
```
Dynamic shape in free dimensions is not supported
While executing %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](
    args = (%cache, [None, None, %idx], %values), kwargs = {}
)
```
Minimum reproducer
```python
import torch
import torch_tensorrt
from torch.export import Dim, export


class IndexPutModel(torch.nn.Module):
    def forward(
        self,
        cache: torch.Tensor,   # (2, N, max_ctx, H)
        values: torch.Tensor,  # (2, N, L, H)
        idx: torch.Tensor,     # (L,)
    ) -> torch.Tensor:
        cache[..., idx, :] = values
        return cache


N = 4
max_ctx = 256
L = 1
H = 512

cache = torch.zeros(2, N, max_ctx, H, dtype=torch.float16, device="cuda")
values = torch.randn(2, N, L, H, dtype=torch.float16, device="cuda")
idx = torch.tensor([3], dtype=torch.long, device="cuda")

model = IndexPutModel().eval().cuda()
batch_dim = Dim("batch", min=1, max=64)
ep = export(
    model,
    args=(cache, values, idx),
    dynamic_shapes={
        "cache": {1: batch_dim},
        "values": {1: batch_dim},
        "idx": {},
    },
)

# Export succeeds. Compile fails:
trt_model = torch_tensorrt.dynamo.compile(
    ep,
    arg_inputs=[
        torch_tensorrt.Input(min_shape=(2, 1, max_ctx, H),
                             opt_shape=(2, N, max_ctx, H),
                             max_shape=(2, 64, max_ctx, H),
                             dtype=torch.float16),
        torch_tensorrt.Input(min_shape=(2, 1, L, H),
                             opt_shape=(2, N, L, H),
                             max_shape=(2, 64, L, H),
                             dtype=torch.float16),
        torch_tensorrt.Input(min_shape=(L,), opt_shape=(L,), max_shape=(L,),
                             dtype=torch.long),
    ],
    use_explicit_typing=True,
    min_block_size=1,
)
```

Expected behavior
Compilation succeeds. The indexed dimension (dim=2) is static (max_ctx=256). The dynamic dimension (dim=1, batch N) is a free dimension that is fully selected by the None entries in the index list — it should not prevent lowering.
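For reference, the write itself has no dependence on the batch size: a quick eager-mode sketch (CPU, same shapes as the reproducer) shows the identical index_put pattern succeeding for several batch sizes, since the dynamic dim-1 is fully selected by the implicit None entries:

```python
import torch

max_ctx, L, H = 256, 1, 512
idx = torch.tensor([3])

# The same advanced-indexing write works for any batch size N;
# only the statically-sized dim-2 is actually indexed.
for N in (1, 4, 64):
    cache = torch.zeros(2, N, max_ctx, H)
    values = torch.randn(2, N, L, H)
    cache[..., idx, :] = values
    assert torch.equal(cache[:, :, 3, :], values[:, :, 0, :])
```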
Workaround
Replace index_put with scatter_ on the indexed dimension. This lowers to IScatterLayer, which does not restrict free dynamic dimensions:

```python
idx_scatter = idx.view(1, 1, -1, 1).expand_as(values)  # (2, N, L, H)
cache.scatter_(2, idx_scatter, values)
```
Environment

| Component | Version |
| --- | --- |
| torch | 2.12.0.dev20260318+cu130 |
| torch_tensorrt | 2.12.0.dev0+fb4fc99f4 |
| tensorrt | 10.15.1.29 |
| Python | 3.12.3 |
| GPU | NVIDIA GeForce RTX 5090 (SM 12.0) |
| CUDA | 13.0 |