Bug description
HuggingFace transformers.cache_utils.StaticLayer.update() uses index_copy_() for KV cache updates. When a model using StaticCache is exported via torch.export and converted to TRT, index_copy_ decomposes to aten.index_put which triggers a broadcast shape error in Torch-TRT's index_put_converter:
```
ValueError: Cannot broadcast (1, 8, 1, 128) to (1, 1, 8, 128)
```
This completely prevents TRT conversion of any model using StaticCache with index_copy_.
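The broadcast failure itself follows from standard right-aligned broadcasting rules: dims 1 and 2 of the two shapes are swapped, and a size-8 dim cannot broadcast into a size-1 dim. A minimal pure-Python sketch (my own helper for illustration, not Torch-TRT code):

```python
def right_aligned_broadcastable(src, dst):
    """NumPy/PyTorch-style check: align shapes on the right; every
    source dim must equal the destination dim or be 1."""
    src = (1,) * (len(dst) - len(src)) + tuple(src)
    return all(s == d or s == 1 for s, d in zip(src, dst))

# The shapes from the error: 8 cannot broadcast into 1, so the check fails.
assert not right_aligned_broadcastable((1, 8, 1, 128), (1, 1, 8, 128))

# The same dims in matching order would broadcast fine.
assert right_aligned_broadcastable((1, 1, 1, 128), (1, 1, 8, 128))
```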
To reproduce
```python
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B", dtype=torch.float16, device_map="cuda"
)
model.eval()

# Create a wrapper that uses StaticCache internally
# Export with torch.export (strict=False for nn.Module with state)
# ...

# This fails with ValueError: Cannot broadcast (1, 8, 1, 128) to (1, 1, 8, 128)
engine = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exported_program,
    inputs=example_inputs,
    use_explicit_typing=True,
    min_block_size=1,
    truncate_double=True,
)
```
Root cause
StaticLayer.update() in transformers 5.2.0 uses:

```python
self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)
```

After torch.export, index_copy_ decomposes to aten.index_put.default. Torch-TRT's index_put_converter (torch_tensorrt/dynamo/conversion/impl/select.py:838) fails with a broadcast shape mismatch between the cache tensor shape (1, 8, 1, 128) and the value tensor shape (1, 1, 8, 128).
Workaround
Replacing index_copy_() with functional torch.scatter() before export produces a clean TRT graph:
```python
# Monkey-patch StaticLayer.update to use scatter instead of index_copy_
# Before: self.keys.index_copy_(2, cache_position, key_states)
# After:  self.keys = self.keys.scatter(2, idx.expand_as(key_states), key_states)
```

- With index_copy_ (unpatched): 3078 graph nodes, 56 index_copy_ ops → TRT conversion fails
- With scatter (patched): 3134 graph nodes, 0 index_copy_ ops, 57 scatter ops → TRT conversion succeeds, 1139 MB engine
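The scatter rewrite is semantically equivalent to index_copy_ along the sequence dim. A toy numpy sketch (numpy standing in for torch, toy shapes; `idx` plays the role of cache_position expanded to the source shape) showing the two writes agree:

```python
import numpy as np

B, H, S, D = 1, 2, 4, 3            # toy cache: batch, kv_heads, seq, head_dim
cache = np.zeros((B, H, S, D))
src = np.arange(B * H * 1 * D, dtype=float).reshape(B, H, 1, D)
cache_position = np.array([2])      # write one new token at seq index 2

# index_copy_ semantics along dim 2: copy whole slices at the given indices
out_index_copy = cache.copy()
out_index_copy[:, :, cache_position, :] = src

# scatter semantics along dim 2 with the index expanded to src's shape:
# out[b, h, idx[b, h, k, d], d] = src[b, h, k, d]
idx = np.broadcast_to(cache_position.reshape(1, 1, 1, 1), src.shape)
out_scatter = cache.copy()
b, h, k, d = np.indices(src.shape)
out_scatter[b, h, idx, d] = src

assert np.array_equal(out_index_copy, out_scatter)
```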
Error trace
```
ERROR:torch_tensorrt.dynamo._compiler:While interpreting the module got an error:
Cannot broadcast (1, 8, 1, 128) to (1, 1, 8, 128)
While executing %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default]
...
  File "transformers/models/qwen3/modeling_qwen3.py", line 274, in forward
    key_states, value_states = past_key_values.update(...)
```
Expected behavior
Either aten.index_put (and hence index_copy_) should be supported by the TRT converter with correct broadcast handling, or Torch-TRT should raise an explicit unsupported-op error instead of a cryptic broadcast failure.
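Note that the two shapes in the error are a pure permutation of one another, so one possible converter-side direction (my assumption, not the actual Torch-TRT code) would be to transpose the value tensor into the indexed layout before broadcasting:

```python
import numpy as np

# Hypothetical sketch: the failing value tensor (1, 8, 1, 128) is just a
# dim-1/dim-2 transpose of the (1, 1, 8, 128) slot the converter expects.
value = np.zeros((1, 8, 1, 128), dtype=np.float16)
target_shape = (1, 1, 8, 128)

fixed = value.transpose(0, 2, 1, 3)  # swap dims 1 and 2
assert fixed.shape == target_shape

# After the transpose, plain broadcasting succeeds.
assert np.broadcast_to(fixed, target_shape).shape == target_shape
```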
Note
This also affects any user following the HuggingFace static caching export guide who then attempts TRT conversion — StaticCache is the recommended caching implementation for export.
Environment
- torch: 2.10.0+cu128
- torch_tensorrt: 2.10.0+cu130
- tensorrt: 10.14.1.48
- transformers: 5.2.0
- Model: Qwen/Qwen3-0.6B (fp16, 28 layers, GQA)
- GPU: NVIDIA GeForce RTX 4090 (24GB)
- Driver: 590.48.01
- OS: Linux (Ubuntu 24.04)