From ec0e0bbc04bb2dd31e25f286599e34d7fbc84008 Mon Sep 17 00:00:00 2001
From: Naren Dasan
Date: Wed, 18 Feb 2026 17:23:48 +0000
Subject: [PATCH] docs: a Claude demo of using torch-tensorrt to compile
 qwen3-reranker

---
 .../dynamo/torch_export_qwen3_reranker.py | 338 ++++++++++++++++++
 1 file changed, 338 insertions(+)
 create mode 100644 examples/dynamo/torch_export_qwen3_reranker.py

diff --git a/examples/dynamo/torch_export_qwen3_reranker.py b/examples/dynamo/torch_export_qwen3_reranker.py
new file mode 100644
index 0000000000..27f54b9543
--- /dev/null
+++ b/examples/dynamo/torch_export_qwen3_reranker.py
@@ -0,0 +1,338 @@
+"""
+Qwen3 Reranker compiled with Torch-TensorRT
+
+Qwen3-Reranker is a causal LM used for reranking: it takes a formatted
+(instruction, query, document) prompt and returns logits for the last
+token. The relevance score is softmax([logit_no, logit_yes])[:, 1].
+
+This example:
+  1. Loads Qwen/Qwen3-Reranker-0.6B (or a larger variant via --model).
+  2. Runs a baseline PyTorch forward pass.
+  3. Exports the model via torch.export.export() with a dynamic seq_len dimension.
+  4. Compiles to TensorRT via torch_tensorrt.dynamo.compile().
+  5. Verifies that the last-token logits match between PyTorch and TRT.
+  6. (Optional) Benchmarks latency of both backends.
+
+Usage
+-----
+# Basic run (quality check)
+python examples/dynamo/torch_export_qwen3_reranker.py
+
+# With a larger model and BF16
+python examples/dynamo/torch_export_qwen3_reranker.py --model Qwen/Qwen3-Reranker-4B --precision BF16
+
+# Benchmark mode
+python examples/dynamo/torch_export_qwen3_reranker.py --benchmark --iterations 20
+"""
+
+import argparse
+import sys
+import timeit
+from contextlib import nullcontext
+from pathlib import Path
+
+import torch
+import torch_tensorrt
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Make tools/llm importable so we can reuse export_llm / register_sdpa
+_TOOLS_LLM = Path(__file__).resolve().parent.parent.parent / "tools" / "llm"
+if str(_TOOLS_LLM) not in sys.path:
+    sys.path.insert(0, str(_TOOLS_LLM))
+
+from torchtrt_ext import register_sdpa  # noqa: E402
+from utils import export_llm  # noqa: E402
+
+DEVICE = torch.device("cuda:0")
+
+# ---------------------------------------------------------------------------
+# Prompt helpers
+# ---------------------------------------------------------------------------
+
+_PREFIX = (
+    "<|im_start|>system\n"
+    "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
+    'Note that the answer can only be "yes" or "no".<|im_end|>\n'
+    "<|im_start|>user\n"
+)
+_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+
+
+def format_pair(instruction: str, query: str, document: str) -> str:
+    return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"
+
+
+def build_inputs(
+    tokenizer,
+    queries: list[str],
+    documents: list[str],
+    instruction: str,
+    max_length: int = 8192,
+) -> dict:
+    """Tokenize (query, document) pairs with the reranker prefix/suffix."""
+    prefix_ids = tokenizer.encode(_PREFIX, add_special_tokens=False)
+    suffix_ids = tokenizer.encode(_SUFFIX, add_special_tokens=False)
+    body_max = max_length - len(prefix_ids) - len(suffix_ids)
+
+    pairs = [format_pair(instruction, q, d) for q, d in zip(queries, documents)]
+    encoded = tokenizer(
+        pairs,
+        padding=False,
+        truncation="longest_first",
+        return_attention_mask=False,
+        max_length=body_max,
+    )
+    for i, ids in enumerate(encoded["input_ids"]):
+        encoded["input_ids"][i] = prefix_ids + ids + suffix_ids
+
+    batch = tokenizer.pad(
+        encoded, padding=True, return_tensors="pt", max_length=max_length
+    )
+    return {k: v.to(DEVICE) for k, v in batch.items()}
+
+
+def compute_scores(
+    logits: torch.Tensor, token_true_id: int, token_false_id: int
+) -> list[float]:
+    """Convert last-token logits to yes-probability scores."""
+    last = logits[:, -1, :]
+    yes_logit = last[:, token_true_id]
+    no_logit = last[:, token_false_id]
+    stacked = torch.stack([no_logit, yes_logit], dim=1)
+    return torch.nn.functional.softmax(stacked, dim=1)[:, 1].tolist()
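+
+
+# A quick sanity check of the scoring math on toy logits (illustrative only;
+# the vocab size of 8 and token ids 3 = "yes", 5 = "no" are made up):
+#
+#     logits = torch.zeros(1, 1, 8)
+#     logits[0, -1, 3] = 2.0  # strong "yes" evidence
+#     compute_scores(logits, token_true_id=3, token_false_id=5)
+#     # -> [0.8808]  (softmax over two logits reduces to sigmoid(2.0))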
+
+
+# ---------------------------------------------------------------------------
+# Model wrapper: drop attention_mask, add position_ids
+# ---------------------------------------------------------------------------
+
+
+class RerankerForExport(torch.nn.Module):
+    """
+    Thin wrapper around the causal LM so that torch.export sees a clean
+    (input_ids, position_ids) signature, matching what export_llm and the
+    existing TRT pipeline expect.
+
+    The Qwen3 model internally builds position_ids from the sequence length
+    when they are not supplied, but passing them explicitly lets us mark the
+    sequence-length dimension as dynamic during export.
+
+    NOTE: We drop attention_mask to keep the exported signature minimal. With
+    right padding and causal attention, hidden states at the real token
+    positions are unaffected by the trailing pads, so reading logits at
+    position -1 is exact for the longest sequence in the batch; shorter
+    sequences are scored at a pad position, which is an approximation. For
+    exact per-row scores, pass attention_mask through (adjusting the export
+    dynamic_shapes accordingly) and gather each row's logits at its true last
+    token, as sketched below.
+    """
+
+    def __init__(self, model: torch.nn.Module):
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self, input_ids: torch.Tensor, position_ids: torch.Tensor
+    ) -> torch.Tensor:
+        out = self.model(input_ids=input_ids, position_ids=position_ids)
+        return out.logits
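+
+
+# If you do pass attention_mask through, exact per-sequence scores come from
+# each row's final real token rather than position -1. A minimal sketch of
+# that gather (illustrative only; not wired into main() below):
+
+
+def last_valid_logits(
+    logits: torch.Tensor, attention_mask: torch.Tensor
+) -> torch.Tensor:
+    """Gather logits at the last non-padding position of each sequence."""
+    last_idx = attention_mask.sum(dim=1) - 1  # [batch] index of each last real token
+    batch_idx = torch.arange(logits.shape[0], device=logits.device)
+    return logits[batch_idx, last_idx]  # [batch, vocab]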
+ """ + + def __init__(self, model: torch.nn.Module): + super().__init__() + self.model = model + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor: + out = self.model(input_ids=input_ids, position_ids=position_ids) + return out.logits + + +# --------------------------------------------------------------------------- +# Export + compile +# --------------------------------------------------------------------------- + + +def export_reranker(model_wrapper: RerankerForExport, input_ids: torch.Tensor): + """Export the wrapped reranker with a dynamic sequence-length dimension.""" + max_seq_len = input_ids.shape[1] + ep = export_llm(model_wrapper, input_ids, max_seq_len=max_seq_len) + return ep + + +def compile_torchtrt( + ep, input_ids: torch.Tensor, precision: str, debug: bool, min_block_size: int +): + position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) + + if precision == "FP16": + enabled_precisions = {torch.float32} + use_fp32_acc = True + elif precision == "BF16": + enabled_precisions = {torch.bfloat16} + use_fp32_acc = False + else: # FP32 + enabled_precisions = {torch.float32} + use_fp32_acc = False + + with torch_tensorrt.logging.debug() if debug else nullcontext(): + trt_model = torch_tensorrt.dynamo.compile( + ep, + inputs=[input_ids, position_ids], + enabled_precisions=enabled_precisions, + use_explicit_typing=True, + use_fp32_acc=use_fp32_acc, + device=DEVICE, + disable_tf32=True, + use_python_runtime=True, + debug=debug, + offload_module_to_cpu=True, + min_block_size=min_block_size, + ) + return trt_model + + +# --------------------------------------------------------------------------- +# Benchmark helper +# --------------------------------------------------------------------------- + + +def benchmark(fn, *args, iterations: int = 10, label: str = "") -> float: + # Warmup + fn(*args) + torch.cuda.synchronize() + + total = 0.0 + for _ in range(iterations): + t0 = timeit.default_timer() + fn(*args) + torch.cuda.synchronize() + total += timeit.default_timer() - t0 + + avg_ms = total / iterations * 1000 + print(f"[{label}] avg latency over {iterations} iters: {avg_ms:.2f} ms") + return avg_ms + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def parse_args(): + p = argparse.ArgumentParser(description="Compile Qwen3-Reranker with Torch-TensorRT") + p.add_argument("--model", default="Qwen/Qwen3-Reranker-0.6B", help="HF model name") + p.add_argument("--precision", default="FP16", choices=["FP16", "BF16", "FP32"]) + p.add_argument( + "--max_length", + type=int, + default=512, + help="Max token length for test inputs (keep small for quick iteration)", + ) + p.add_argument("--benchmark", action="store_true") + p.add_argument("--iterations", type=int, default=10) + p.add_argument("--min_block_size", type=int, default=1) + p.add_argument("--debug", action="store_true") + return p.parse_args() + + +def main(): + args = parse_args() + + # ------------------------------------------------------------------ + # 1. 
+
+
+def compile_torchtrt(
+    ep, input_ids: torch.Tensor, precision: str, debug: bool, min_block_size: int
+):
+    position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+
+    # Precision handling mirrors tools/llm/run_llm.py: with
+    # use_explicit_typing=True, layer dtypes follow the module weights (the
+    # model is cast to the requested dtype in main()), and FP16 additionally
+    # requests FP32 accumulation for matmuls to limit error growth.
+    if precision == "FP16":
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = True
+    elif precision == "BF16":
+        enabled_precisions = {torch.bfloat16}
+        use_fp32_acc = False
+    else:  # FP32
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = False
+
+    with torch_tensorrt.logging.debug() if debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[input_ids, position_ids],
+            enabled_precisions=enabled_precisions,
+            use_explicit_typing=True,
+            use_fp32_acc=use_fp32_acc,
+            device=DEVICE,
+            disable_tf32=True,
+            use_python_runtime=True,
+            debug=debug,
+            offload_module_to_cpu=True,
+            min_block_size=min_block_size,
+        )
+    return trt_model
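+
+
+# The compiled module can also be serialized so later runs skip recompilation.
+# A minimal sketch using torch_tensorrt.save (the path is illustrative, and
+# serialization has runtime-specific caveats; see the Torch-TensorRT
+# "saving models" docs):
+#
+#     torch_tensorrt.save(trt_model, "qwen3_reranker_trt.ep", inputs=[input_ids, position_ids])
+#     reloaded = torch.export.load("qwen3_reranker_trt.ep").module()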
+
+
+# ---------------------------------------------------------------------------
+# Benchmark helper
+# ---------------------------------------------------------------------------
+
+
+def benchmark(fn, *args, iterations: int = 10, label: str = "") -> float:
+    # Warmup
+    fn(*args)
+    torch.cuda.synchronize()
+
+    total = 0.0
+    for _ in range(iterations):
+        t0 = timeit.default_timer()
+        fn(*args)
+        torch.cuda.synchronize()
+        total += timeit.default_timer() - t0
+
+    avg_ms = total / iterations * 1000
+    print(f"[{label}] avg latency over {iterations} iters: {avg_ms:.2f} ms")
+    return avg_ms
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+
+def parse_args():
+    p = argparse.ArgumentParser(
+        description="Compile Qwen3-Reranker with Torch-TensorRT"
+    )
+    p.add_argument("--model", default="Qwen/Qwen3-Reranker-0.6B", help="HF model name")
+    p.add_argument("--precision", default="FP16", choices=["FP16", "BF16", "FP32"])
+    p.add_argument(
+        "--max_length",
+        type=int,
+        default=512,
+        help="Max token length for test inputs (keep small for quick iteration)",
+    )
+    p.add_argument("--benchmark", action="store_true")
+    p.add_argument("--iterations", type=int, default=10)
+    p.add_argument("--min_block_size", type=int, default=1)
+    p.add_argument("--debug", action="store_true")
+    return p.parse_args()
+
+
+def main():
+    args = parse_args()
+
+    # ------------------------------------------------------------------
+    # 1. Load tokenizer & model
+    # ------------------------------------------------------------------
+    print(f"Loading {args.model} ...")
+    tokenizer = AutoTokenizer.from_pretrained(args.model, padding_side="right")
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    token_true_id = tokenizer.convert_tokens_to_ids("yes")
+    token_false_id = tokenizer.convert_tokens_to_ids("no")
+    print(
+        f"  token_true_id (yes) = {token_true_id}, "
+        f"token_false_id (no) = {token_false_id}"
+    )
+
+    base_model = (
+        AutoModelForCausalLM.from_pretrained(
+            args.model,
+            use_cache=False,
+            attn_implementation="sdpa",
+        )
+        .eval()
+        .cuda()
+    )
+
+    # Register the custom SDPA converter (handles Qwen3 like the other models
+    # in run_llm.py)
+    register_sdpa.enable_sdpa_converter(args.model, base_model.config)
+
+    dtype_map = {"FP16": torch.float16, "BF16": torch.bfloat16, "FP32": torch.float32}
+    base_model = base_model.to(dtype_map[args.precision])
+
+    # ------------------------------------------------------------------
+    # 2. Build test inputs
+    # ------------------------------------------------------------------
+    instruction = (
+        "Given a web search query, retrieve relevant passages that answer the query"
+    )
+    queries = [
+        "What is the capital of China?",
+        "How does photosynthesis work?",
+    ]
+    documents = [
+        "The capital of China is Beijing.",
+        "Photosynthesis is the process by which plants convert sunlight into glucose.",
+    ]
+
+    print("Tokenizing inputs ...")
+    inputs = build_inputs(
+        tokenizer, queries, documents, instruction, max_length=args.max_length
+    )
+    input_ids = inputs["input_ids"]
+    print(f"  input_ids shape: {input_ids.shape}")
+
+    # ------------------------------------------------------------------
+    # 3. PyTorch baseline
+    # ------------------------------------------------------------------
+    position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+    model_wrapper = RerankerForExport(base_model)
+
+    with torch.inference_mode():
+        pyt_logits = model_wrapper(input_ids, position_ids)
+    pyt_scores = compute_scores(pyt_logits, token_true_id, token_false_id)
+    print("\n--- PyTorch baseline ---")
+    for q, d, s in zip(queries, documents, pyt_scores):
+        print(f"  Q: {q!r}  D: {d!r}  score={s:.4f}")
+
+    # ------------------------------------------------------------------
+    # 4. Export & compile
+    # ------------------------------------------------------------------
+    print("\nExporting model ...")
+    with torch.inference_mode():
+        ep = export_reranker(model_wrapper, input_ids)
+
+    print("Compiling with Torch-TensorRT ...")
+    with torch.inference_mode():
+        trt_model = compile_torchtrt(
+            ep, input_ids, args.precision, args.debug, args.min_block_size
+        )
+
+    # ------------------------------------------------------------------
+    # 5. TRT inference & score comparison
+    # ------------------------------------------------------------------
+    with torch.inference_mode():
+        trt_logits = trt_model(input_ids, position_ids)
+    trt_scores = compute_scores(trt_logits, token_true_id, token_false_id)
+
+    print("\n--- TensorRT ---")
+    for q, d, s in zip(queries, documents, trt_scores):
+        print(f"  Q: {q!r}  D: {d!r}  score={s:.4f}")
+
+    print("\n--- Score comparison ---")
+    for i, (ps, ts) in enumerate(zip(pyt_scores, trt_scores)):
+        diff = abs(ps - ts)
+        print(f"  pair {i}: PyTorch={ps:.6f}  TRT={ts:.6f}  |diff|={diff:.2e}")
+
+    last_pyt = pyt_logits[:, -1, :].float()
+    last_trt = trt_logits[:, -1, :].float()
+    max_diff = (last_pyt - last_trt).abs().max().item()
+    print(f"\nMax absolute difference in last-token logits: {max_diff:.4e}")
+
+    # ------------------------------------------------------------------
+    # 6. (Optional) benchmark
+    # ------------------------------------------------------------------
+    if args.benchmark:
+        print("\n--- Benchmarking ---")
+
+        def pyt_fwd():
+            return model_wrapper(input_ids, position_ids)
+
+        def trt_fwd():
+            return trt_model(input_ids, position_ids)
+
+        with torch.inference_mode():
+            benchmark(pyt_fwd, iterations=args.iterations, label="PyTorch")
+            benchmark(trt_fwd, iterations=args.iterations, label="TensorRT")
+
+
+if __name__ == "__main__":
+    main()
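+
+
+# ---------------------------------------------------------------------------
+# Note on batch size: only seq_len is exported as dynamic above, so the engine
+# is specialized to the example batch size (2). To rerank an arbitrary number
+# of documents, chunk the pairs to the compiled batch size. A sketch, assuming
+# trt_model, tokenizer, instruction, and the yes/no token ids from main() are
+# in scope:
+#
+#     def rerank(query: str, docs: list[str], batch: int = 2) -> list[float]:
+#         scores: list[float] = []
+#         for i in range(0, len(docs), batch):
+#             chunk = docs[i : i + batch]
+#             chunk += [chunk[-1]] * (batch - len(chunk))  # repeat-pad the batch
+#             inp = build_inputs(tokenizer, [query] * batch, chunk, instruction)
+#             pos = torch.arange(inp["input_ids"].shape[1]).unsqueeze(0).to(DEVICE)
+#             logits = trt_model(inp["input_ids"], pos)
+#             scores.extend(compute_scores(logits, token_true_id, token_false_id))
+#         return scores[: len(docs)]
+# ---------------------------------------------------------------------------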