[Bug] index_copy_ in transformers StaticCache causes TRT conversion failure (index_put broadcast error) #4142

@Mgluhovskoi

Description

Bug description

HuggingFace's transformers.cache_utils.StaticLayer.update() uses index_copy_() for KV cache updates. When a model using StaticCache is exported via torch.export and converted to TRT, index_copy_ decomposes to aten.index_put, which triggers a broadcast shape error in Torch-TRT's index_put_converter:

ValueError: Cannot broadcast (1, 8, 1, 128) to (1, 1, 8, 128)

This completely prevents TRT conversion of any model using StaticCache with index_copy_.

To reproduce

import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.float16, device_map="cuda")
model.eval()

# Create a wrapper that uses StaticCache internally
# Export with torch.export (strict=False for nn.Module with state)
# ...

# This fails with ValueError: Cannot broadcast (1, 8, 1, 128) to (1, 1, 8, 128)
engine = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exported_program,
    inputs=example_inputs,
    use_explicit_typing=True,
    min_block_size=1,
    truncate_double=True,
)

Root cause

StaticLayer.update() in transformers 5.2.0 uses:

self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)

After torch.export, index_copy_ decomposes to aten.index_put.default. Torch-TRT's index_put_converter (torch_tensorrt/dynamo/conversion/impl/select.py:838) fails with a broadcast shape mismatch between the cache tensor shape (1, 8, 1, 128) and the value tensor shape (1, 1, 8, 128).
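The error message ("Cannot broadcast A to B") reads like a one-directional expand check (source broadcast to destination). Under such a check the two shapes are incompatible even though they are mutually broadcastable in the two-sided NumPy sense. A minimal sketch of a one-directional check (hypothetical helper for illustration, not Torch-TRT's actual code):

```python
def can_expand(src, dst):
    """One-directional, right-aligned broadcast (torch.Tensor.expand
    semantics): each source dim must equal the target dim or be 1."""
    if len(src) > len(dst):
        return False
    for s, d in zip(reversed(src), reversed(dst)):
        if s != d and s != 1:
            return False
    return True

# Two-sided broadcasting would accept these shapes, but expanding the
# value shape (1, 8, 1, 128) to (1, 1, 8, 128) fails on dim 1: 8
# cannot shrink to 1 — matching the converter's error.
print(can_expand((1, 8, 1, 128), (1, 1, 8, 128)))  # False
print(can_expand((1, 1, 8, 128), (1, 8, 8, 128)))  # True
```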

Workaround

Replacing the in-place index_copy_() with the out-of-place torch.scatter() before export produces a clean TRT graph:

# Monkey-patch StaticLayer.update to use scatter instead of index_copy_
# Before: self.keys.index_copy_(2, cache_position, key_states)
# After:  idx = cache_position.view(1, 1, -1, 1)
#         self.keys = self.keys.scatter(2, idx.expand_as(key_states), key_states)

With index_copy_ (unpatched): 3078 graph nodes, 56 index_copy_ ops → TRT conversion fails
With scatter (patched): 3134 graph nodes, 0 index_copy_ ops, 57 scatter ops → TRT conversion succeeds, 1139 MB engine
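For reference, the scatter form is semantically equivalent to index_copy_ along the cache dimension once the 1-D cache_position index is expanded to the source shape. A torch-only sketch with illustrative (small) shapes, no transformers or TRT required:

```python
import torch

# Illustrative shapes: (batch, kv_heads, max_seq, head_dim)
keys = torch.zeros(1, 2, 4, 3)
key_states = torch.arange(12, dtype=torch.float32).reshape(1, 2, 2, 3)
cache_position = torch.tensor([1, 3])  # cache slots written this step

# In-place update, as StaticLayer.update does today
ref = keys.clone()
ref.index_copy_(2, cache_position, key_states)

# Functional equivalent: expand the index to the source shape, scatter on dim 2
idx = cache_position.view(1, 1, -1, 1).expand_as(key_states)
out = keys.scatter(2, idx, key_states)

assert torch.equal(ref, out)
```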

Error trace

ERROR:torch_tensorrt.dynamo._compiler:While interpreting the module got an error:
Cannot broadcast (1, 8, 1, 128) to (1, 1, 8, 128)
While executing %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default]
  ...
  File "transformers/models/qwen3/modeling_qwen3.py", line 274, in forward
    key_states, value_states = past_key_values.update(...)

Expected behavior

index_copy_ / index_put should be supported in the TRT converter with correct broadcast handling, or an explicit unsupported-op error should be raised (not a cryptic broadcast error).

Note

This also affects any user following the HuggingFace static caching export guide who then attempts TRT conversion — StaticCache is the recommended caching implementation for export.

Environment

  • torch: 2.10.0+cu128
  • torch_tensorrt: 2.10.0+cu130
  • tensorrt: 10.14.1.48
  • transformers: 5.2.0
  • Model: Qwen/Qwen3-0.6B (fp16, 28 layers, GQA)
  • GPU: NVIDIA GeForce RTX 4090 (24GB)
  • Driver: 590.48.01
  • OS: Linux (Ubuntu 24.04)
