### Bug description
`convert_exported_program_to_serialized_trt_engine` produces a TRT engine with all-zero logits when `use_explicit_typing=False` and `enabled_precisions={torch.float16}`. The same exported program converts correctly with `use_explicit_typing=True`.
No error or warning is emitted; the engine silently produces incorrect results.
### To reproduce

```python
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.float16, device_map="cuda")
model.eval()

# Export with stateless KV cache wrapper
# (Standard torch.export flow: model has fp16 weights, fp32 I/O)
# ... export model to ExportedProgram ...

# This produces all-zero logits:
engine_bad = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exported_program,
    inputs=example_inputs,
    use_explicit_typing=False,
    enabled_precisions={torch.float16},
    min_block_size=1,
    truncate_double=True,
)

# This produces correct logits:
engine_good = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exported_program,
    inputs=example_inputs,
    use_explicit_typing=True,
    min_block_size=1,
    truncate_double=True,
)
```

### Observed behavior
- `use_explicit_typing=False`: engine builds successfully (~1139 MB), but inference produces logits where every value is exactly 0.0
- `use_explicit_typing=True`: engine builds successfully (~1139 MB) with correct non-zero logits
Both engines are the same size, suggesting TRT builds a valid-looking engine in both cases, but the `use_explicit_typing=False` path silently corrupts computation, likely due to incorrect type casting or layer fusion.
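To illustrate why an implicit-cast bug could plausibly zero out logits without any error, here is a small torch-only sketch (this is not the converter's actual code path, just one candidate mechanism). Under weak typing the builder is free to down-cast fp32 tensors to fp16, and fp16 has a max finite value of ~65504 and a smallest positive subnormal of ~6e-8, so large values overflow to inf and tiny values underflow to exactly 0.0 with no warning:

```python
import torch

# fp32 values at the edges of the fp16 representable range
x = torch.tensor([1.0, 70000.0, -70000.0, 1e-8], dtype=torch.float32)

# Silent down-cast: no error or warning is raised
y = x.to(torch.float16)

print(y)  # the 7e4 entries become +/-inf, the 1e-8 entry becomes 0.0
print(torch.finfo(torch.float16).max)  # 65504.0
```

If a cast like this lands in the wrong place in the graph (e.g. before a normalization or the final LM head), downstream ops can collapse to degenerate all-zero or NaN outputs while the engine itself builds cleanly.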
### Expected behavior
Both paths should produce correct (non-zero) logits, or `use_explicit_typing=False` should raise an error/warning if it cannot handle the exported program correctly.
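Until the converter flags this itself, a caller-side sanity check can catch silently degenerate engines right after conversion. The helper below is a minimal sketch (the function name and tolerance are ours, not part of torch_tensorrt); it only inspects an output tensor, so it works regardless of how the engine is executed:

```python
import torch

def looks_degenerate(logits: torch.Tensor, atol: float = 0.0) -> bool:
    """Return True if every element is (near) zero or any element is non-finite."""
    if not torch.isfinite(logits).all():
        return True
    return bool((logits.abs() <= atol).all())

# Synthetic tensors standing in for engine outputs:
bad = torch.zeros(1, 8, 151936)   # all-zero logits, as in this report
good = torch.randn(1, 8, 151936)  # healthy logits
print(looks_degenerate(bad), looks_degenerate(good))  # True False
```

Running this against the engine output on a single example input would have surfaced the bug immediately instead of it appearing downstream as garbage generations.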
### Environment
- torch: 2.10.0+cu128
- torch_tensorrt: 2.10.0+cu130
- tensorrt: 10.14.1.48
- transformers: 5.2.0
- Model: Qwen/Qwen3-0.6B (fp16, 28 layers, GQA with 16 heads / 2 KV heads)
- GPU: NVIDIA GeForce RTX 4090 (24GB)
- Driver: 590.48.01
- OS: Linux (Ubuntu 24.04)