From a88200fe30cdc2cd8b903b90e559bf38c25620e5 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Tue, 17 Mar 2026 00:23:59 -0700 Subject: [PATCH 1/5] support attn_bias for efficient sdpa --- .../compilation/compilation_settings.rst | 6 +- .../dynamo/conversion/impl/attention.py | 125 ++++++++++++------ .../dynamo/conversion/test_attention_aten.py | 89 +++++++++++++ 3 files changed, 180 insertions(+), 40 deletions(-) diff --git a/docsrc/user_guide/compilation/compilation_settings.rst b/docsrc/user_guide/compilation/compilation_settings.rst index 2c32bb81c0..e820cda917 100644 --- a/docsrc/user_guide/compilation/compilation_settings.rst +++ b/docsrc/user_guide/compilation/compilation_settings.rst @@ -350,6 +350,10 @@ Graph Partitioning - ``False`` - Use ``aot_autograd`` for tracing instead of the default path. Required when the model contains ``DTensor`` or other distributed tensors. + * - ``decompose_attention`` + - ``False`` + - Decompose attention layers into smaller ops. We have converters for handling attention ops, + but if you want to decompose them into smaller ops, you can set this to True. ---- @@ -372,7 +376,7 @@ Compilation Workflow - ``False`` - Defer TRT engine deserialization until all engines have been built. Works around resource contraints and builder overhad but engines - may be less well tuned to their deployment resource availablity + may be less well tuned to their deployment resource availability * - ``debug`` - ``False`` - Enable verbose TRT builder logs at ``DEBUG`` level. diff --git a/py/torch_tensorrt/dynamo/conversion/impl/attention.py b/py/torch_tensorrt/dynamo/conversion/impl/attention.py index 202ffe830b..e56a85d3c4 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/attention.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/attention.py @@ -10,6 +10,7 @@ from torch_tensorrt.dynamo.conversion.converter_utils import ( cast_trt_tensor, get_trt_tensor, + prepend_ones, ) _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -193,18 +194,17 @@ def scaled_dot_product_attention( if attn_mask is not None: if attn_mask.dtype == trt.DataType.BOOL: mask_tensor = attn_mask + elif attn_mask.dtype != query.dtype: + mask_tensor = cast_trt_tensor( + ctx, + attn_mask, + query.dtype, + name + "_cast_attn_mask", + target, + source_ir, + ) else: - if attn_mask.dtype != query.dtype: - mask_tensor = cast_trt_tensor( - ctx, - attn_mask, - query.dtype, - name + "_cast_attn_mask", - target, - source_ir, - ) - else: - mask_tensor = attn_mask + mask_tensor = attn_mask scaled_query = impl.elementwise.mul( ctx, target, source_ir, f"{name}_scaled_query", query, scale_factor @@ -329,41 +329,88 @@ def scaled_dot_product_efficient_attention( source_ir, ) + mask_tensor = None if attn_bias is not None: - attn_weight = impl.matmul.matrix_multiply( - ctx, - target, - source_ir, - name + "_mm", - scaled_query, - key, - other_matrix_op=trt.MatrixOperation.TRANSPOSE, - ) - attn_weight = impl.elementwise.add( - ctx, target, source_ir, name + "_attn_bias_add", attn_weight, attn_bias - ) - attn_weight = impl.normalization.softmax( - ctx, target, source_ir, name + "_softmax", attn_weight, -1, False - ) - out = impl.matmul.matrix_multiply( + if attn_bias.dtype == trt.DataType.BOOL: + mask_tensor = attn_bias + elif attn_bias.dtype != query.dtype: + mask_tensor = cast_trt_tensor( + ctx, + attn_bias, + query.dtype, + name + "_cast_attn_bias", + target, + source_ir, + ) + else: + mask_tensor = attn_bias + + # TensorRT IAttention does not allow setting both causal=True and mask. + # If both are requested, fold causal into mask and disable causal flag. + use_causal = is_causal + if mask_tensor is not None and is_causal: + L = impl.shape.shape(ctx, target, source_ir, name + "_L", query, -2) + S = impl.shape.shape(ctx, target, source_ir, name + "_S", key, -2) + causal_mask = tril( ctx, target, source_ir, - name + "_out", - attn_weight, - value, - ) - return out, None, None, None - else: - attention_layer = ctx.net.add_attention( - scaled_query, key, value, trt.AttentionNormalizationOp.SOFTMAX, is_causal + name + "_tril", + L, + S, ) - assert attention_layer is not None, "attention layer is None" + diff = len(query.shape) - len(causal_mask.shape) + causal_mask = prepend_ones(ctx, causal_mask, name + "_prepend_ones", diff) + + if mask_tensor.dtype == trt.DataType.BOOL: + mask_tensor = impl.elementwise.logical_and( + ctx, + target, + source_ir, + name + "_causal_attn_bias_and", + causal_mask, + mask_tensor, + ) + else: + # Convert causal bool mask to additive bias mask: + # True -> 0.0 (keep), False -> -inf (block) + zero_bias = get_trt_tensor( + ctx, 0.0, name + "_causal_additive_bias_zero", query.dtype + ) + neg_inf_bias = get_trt_tensor( + ctx, float("-inf"), name + "_causal_additive_bias_neg_inf", query.dtype + ) + causal_additive_bias = impl.condition.where( + ctx, + target, + source_ir, + name + "_causal_additive_bias", + zero_bias, + neg_inf_bias, + causal_mask, + ) + mask_tensor = impl.elementwise.add( + ctx, + target, + source_ir, + name + "_attn_bias_add_causal", + mask_tensor, + causal_additive_bias, + ) + use_causal = False - attention_layer.decomposable = True + attention_layer = ctx.net.add_attention( + scaled_query, key, value, trt.AttentionNormalizationOp.SOFTMAX, use_causal + ) + assert attention_layer is not None, "attention layer is None" - attention_output = attention_layer.get_output(0) - return attention_output, None, None, None + if mask_tensor is not None: + attention_layer.mask = mask_tensor + + attention_layer.decomposable = True + + attention_output = attention_layer.get_output(0) + return attention_output, None, None, None def scaled_dot_product_cudnn_attention( diff --git a/tests/py/dynamo/conversion/test_attention_aten.py b/tests/py/dynamo/conversion/test_attention_aten.py index f63c17b93c..835a4ac00e 100644 --- a/tests/py/dynamo/conversion/test_attention_aten.py +++ b/tests/py/dynamo/conversion/test_attention_aten.py @@ -260,5 +260,94 @@ def forward(self, query, key, value, attn_mask=None): ) +class TestScaledDotProductEfficientAttention(DispatchTestCase): + @parameterized.expand( + [ + ( + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 32), + False, + None, + torch.float16, + 0.0, + ), + ( + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 32), + True, + None, + torch.float16, + 0.0, + ), + ( + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 32), + True, + 2.0, + torch.float32, + 0.0, + ), + ( + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 32), + False, + 2.0, + torch.float32, + 0.0, + ), + ] + ) + def test_efficient_sdpa( + self, + q_shape, + k_shape, + v_shape, + attn_bias_shape, + is_causal, + scale, + dtype, + dropout_p=0.0, + ): + class EfficientSDPA(nn.Module): + def forward(self, query, key, value, attn_bias=None): + attn = torch.ops.aten._scaled_dot_product_efficient_attention.default( + query, + key, + value, + attn_bias, + False, + dropout_p, + is_causal, + scale=scale, + ) + return attn[0] + + inputs = [] + query = torch.randn(q_shape, dtype=dtype) + key = torch.rand(k_shape, dtype=dtype) + value = torch.rand(v_shape, dtype=dtype) + inputs.extend([query, key, value]) + if attn_bias_shape is not None: + attn_bias = torch.randn(attn_bias_shape, dtype=dtype) + inputs.append(attn_bias) + self.run_test( + EfficientSDPA(), + inputs, + rtol=1e-2, + atol=1e-2, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + ) + + if __name__ == "__main__": run_tests() From e7c0cdda58994acd2558d9df72a71697de28a01f Mon Sep 17 00:00:00 2001 From: Evan Li Date: Fri, 20 Mar 2026 10:04:11 -0700 Subject: [PATCH 2/5] add compile arg attn_bias_is_causal for attn_bias from HF models --- .../compilation/compilation_settings.rst | 7 +- py/torch_tensorrt/dynamo/_compiler.py | 9 ++ py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 4 + py/torch_tensorrt/dynamo/backend/backends.py | 5 + .../lowering/passes/_aten_lowering_pass.py | 2 + .../dynamo/conversion/test_attention_aten.py | 92 ++++++++++++++++ .../lowering/test_aten_lowering_passes.py | 104 ++++++++++++++++++ 8 files changed, 223 insertions(+), 1 deletion(-) diff --git a/docsrc/user_guide/compilation/compilation_settings.rst b/docsrc/user_guide/compilation/compilation_settings.rst index e820cda917..63211daa4b 100644 --- a/docsrc/user_guide/compilation/compilation_settings.rst +++ b/docsrc/user_guide/compilation/compilation_settings.rst @@ -123,6 +123,11 @@ Optimization Tuning - Disable TensorFloat-32 (TF32) accumulation. TF32 is enabled by default on Ampere and newer GPUs and provides FP32-range with FP16-speed for matmul/conv. Disable only when you need strict IEEE FP32 semantics. + * - ``attn_bias_is_causal`` + - ``True`` + - Whether the attn_bias in efficient SDPA is causal. Default is True. This can + accelerate models from HF because attn_bias is always a causal mask in HF. + If you want to use non-causal attn_bias, you can set this to False. ---- @@ -448,7 +453,7 @@ rebuilt from scratch: ``l2_limit_for_tiling``, ``enable_autocast``, ``autocast_low_precision_type``, ``autocast_excluded_nodes``, ``autocast_excluded_ops``, ``autocast_max_output_threshold``, ``autocast_max_depth_of_reduction``, -``autocast_calibration_dataloader``. +``autocast_calibration_dataloader``, ``decompose_attention``, ``attn_bias_is_causal``. Settings not in this list (e.g., ``debug``, ``dryrun``, ``pass_through_build_failures``) can be changed without invalidating the cache. diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index bc3cdc5721..4bca7fcf8e 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -112,6 +112,7 @@ def cross_compile_for_windows( cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET, dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, + attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows @@ -190,6 +191,7 @@ def cross_compile_for_windows( cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution. decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True. + attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -349,6 +351,7 @@ def cross_compile_for_windows( "cpu_memory_budget": cpu_memory_budget, "dynamically_allocate_resources": dynamically_allocate_resources, "decompose_attention": decompose_attention, + "attn_bias_is_causal": attn_bias_is_causal, } # disable the following settings is not supported for cross compilation for windows feature @@ -471,6 +474,7 @@ def compile( enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING, dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, + attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -559,6 +563,7 @@ def compile( cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail. dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution. decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True. + attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -763,6 +768,7 @@ def compile( "cpu_memory_budget": cpu_memory_budget, "dynamically_allocate_resources": dynamically_allocate_resources, "decompose_attention": decompose_attention, + "attn_bias_is_causal": attn_bias_is_causal, } logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") settings = CompilationSettings(**compilation_options) @@ -1167,6 +1173,7 @@ def convert_exported_program_to_serialized_trt_engine( offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION, + attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL, **kwargs: Any, ) -> bytes: """Convert an ExportedProgram to a serialized TensorRT engine @@ -1242,6 +1249,7 @@ def convert_exported_program_to_serialized_trt_engine( offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model. decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True. + attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False. **kwargs: Any, Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs @@ -1412,6 +1420,7 @@ def convert_exported_program_to_serialized_trt_engine( "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, "decompose_attention": decompose_attention, + "attn_bias_is_causal": attn_bias_is_causal, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 2e838cd28c..0165f91086 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -68,6 +68,7 @@ CPU_MEMORY_BUDGET = None DYNAMICALLY_ALLOCATE_RESOURCES = False DECOMPOSE_ATTENTION = False +ATTN_BIAS_IS_CAUSAL = True if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index e3f2f1bc37..4b9acdf19e 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -8,6 +8,7 @@ from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, + ATTN_BIAS_IS_CAUSAL, AUTOCAST_CALIBRATION_DATALOADER, AUTOCAST_EXCLUDED_NODES, AUTOCAST_EXCLUDED_OPS, @@ -120,6 +121,7 @@ class CompilationSettings: offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True. + attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False. """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -180,6 +182,7 @@ class CompilationSettings: cpu_memory_budget: Optional[int] = CPU_MEMORY_BUDGET dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES decompose_attention: bool = DECOMPOSE_ATTENTION + attn_bias_is_causal: bool = ATTN_BIAS_IS_CAUSAL def __getstate__(self) -> dict[str, Any]: from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( @@ -220,6 +223,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: "autocast_max_depth_of_reduction", "autocast_calibration_dataloader", "decompose_attention", + "attn_bias_is_causal", } diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 00fb6977e8..5ce62133ff 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -21,6 +21,9 @@ remove_sym_nodes, repair_input_aliasing, ) +from torch_tensorrt.dynamo.lowering.passes.force_causal_efficient_attention import ( + force_causal_efficient_attention, +) from torch_tensorrt.dynamo.utils import ( parse_dynamo_kwargs, prepare_inputs, @@ -148,6 +151,8 @@ def _pretraced_backend( logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) + # Keep the behavior of efficient attention consistent with pre_export_lowering path. + gm = force_causal_efficient_attention(gm, settings) gm = post_lowering(gm, settings) logger.debug("Lowered Input graph:\n " + str(gm.graph)) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 8ad5f2fcae..3b7a8ded83 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -11,6 +11,7 @@ from .complex_graph_rewrite import complex_graph_detection from .constant_folding import constant_fold +from .force_causal_efficient_attention import force_causal_efficient_attention from .fuse_prims_broadcast import fuse_prims_broadcast from .pass_manager import DynamoPassManager from .remove_assert_nodes import remove_assert_nodes @@ -25,6 +26,7 @@ remove_detach, remove_assert_nodes, rule_based_autocast, + force_causal_efficient_attention, ] post_lowering_pass_list = [ diff --git a/tests/py/dynamo/conversion/test_attention_aten.py b/tests/py/dynamo/conversion/test_attention_aten.py index 835a4ac00e..babeae0215 100644 --- a/tests/py/dynamo/conversion/test_attention_aten.py +++ b/tests/py/dynamo/conversion/test_attention_aten.py @@ -330,6 +330,98 @@ def forward(self, query, key, value, attn_bias=None): ) return attn[0] + inputs = [] + query = torch.randn(q_shape, dtype=dtype) + key = torch.rand(k_shape, dtype=dtype) + value = torch.rand(v_shape, dtype=dtype) + inputs.extend([query, key, value]) + if attn_bias_shape is not None: + # create a lower triangular mask that is 0 for lower and -inf for upper + attn_bias = torch.zeros(attn_bias_shape, dtype=dtype) + upper = torch.triu( + torch.ones(attn_bias_shape, dtype=torch.bool), diagonal=1 + ) + attn_bias = attn_bias.masked_fill(upper, float("-inf")) + inputs.append(attn_bias) + self.run_test( + EfficientSDPA(), + inputs, + rtol=1e-2, + atol=1e-2, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + ) + + @parameterized.expand( + [ + ( + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 32), + False, + None, + torch.float16, + 0.0, + ), + ( + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 32), + True, + None, + torch.float16, + 0.0, + ), + ( + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 32), + True, + 2.0, + torch.float32, + 0.0, + ), + ( + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 16), + (4, 8, 32, 32), + False, + 2.0, + torch.float32, + 0.0, + ), + ] + ) + def test_efficient_sdpa_random_attn_bias( + self, + q_shape, + k_shape, + v_shape, + attn_bias_shape, + is_causal, + scale, + dtype, + dropout_p=0.0, + ): + class EfficientSDPA(nn.Module): + def forward(self, query, key, value, attn_bias=None): + attn = torch.ops.aten._scaled_dot_product_efficient_attention.default( + query, + key, + value, + attn_bias, + False, + dropout_p, + is_causal, + scale=scale, + ) + return attn[0] + inputs = [] query = torch.randn(q_shape, dtype=dtype) key = torch.rand(k_shape, dtype=dtype) diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index ccfbf06268..424cf145fc 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -278,5 +278,109 @@ def forward(self, x: torch.Tensor): self.assertTrue(True) +class TestRewriteEfficientAttention(TestCase): + def test_force_causal_efficient_attention(self): + class RewriteEfficientAttention(torch.nn.Module): + def forward( + self, + query, + key, + value, + attn_bias=None, + compute_log_sumexp=False, + dropout_p=0.0, + is_causal=False, + scale=None, + ): + out = torch.ops.aten._scaled_dot_product_efficient_attention.default( + query, + key, + value, + attn_bias, + compute_log_sumexp, + dropout_p, + is_causal, + scale=scale, + ) + return out[0] + + attn_bias = torch.zeros(4, 8, 32, 32, device="cuda") + upper = torch.triu( + torch.ones((32, 32), dtype=torch.bool, device="cuda"), diagonal=1 + ) + attn_bias = attn_bias.masked_fill(upper, float("-inf")) + + inputs = [ + torch.randn(4, 8, 32, 16).cuda(), + torch.randn(4, 8, 32, 16).cuda(), + torch.randn(4, 8, 32, 16).cuda(), + attn_bias, + True, + 0.0, + True, + ] + model = RewriteEfficientAttention().cuda() + pytorch_out = model(*inputs) + ep = torch.export.export(model, tuple(inputs)) + trt_module = torch_tensorrt.dynamo.compile( + ep, + inputs, + min_block_size=1, + decompose_attention=False, + attn_bias_is_causal=True, + ) + trt_out = trt_module(*inputs) + torch.testing.assert_close(pytorch_out, trt_out, rtol=1e-2, atol=1e-2) + + def test_force_causal_efficient_attention_with_non_causal_attn_bias(self): + class RewriteEfficientAttention(torch.nn.Module): + def forward( + self, + query, + key, + value, + attn_bias=None, + compute_log_sumexp=False, + dropout_p=0.0, + is_causal=False, + scale=None, + ): + out = torch.ops.aten._scaled_dot_product_efficient_attention.default( + query, + key, + value, + attn_bias, + compute_log_sumexp, + dropout_p, + is_causal, + scale=scale, + ) + return out[0] + + attn_bias = torch.randn(4, 8, 32, 32).cuda() + + inputs = [ + torch.randn(4, 8, 32, 16).cuda(), + torch.randn(4, 8, 32, 16).cuda(), + torch.randn(4, 8, 32, 16).cuda(), + attn_bias, + True, + 0.0, + False, + ] + model = RewriteEfficientAttention().cuda() + pytorch_out = model(*inputs) + ep = torch.export.export(model, tuple(inputs)) + trt_module = torch_tensorrt.dynamo.compile( + ep, + inputs, + min_block_size=1, + decompose_attention=False, + attn_bias_is_causal=False, + ) + trt_out = trt_module(*inputs) + torch.testing.assert_close(pytorch_out, trt_out, rtol=1e-2, atol=1e-2) + + if __name__ == "__main__": run_tests() From 8783ecfa02f6da75a8f93710b72ec1b2a2bff310 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Sat, 21 Mar 2026 01:09:27 -0700 Subject: [PATCH 3/5] fix --- py/torch_tensorrt/dynamo/backend/backends.py | 5 --- .../lowering/passes/_aten_lowering_pass.py | 2 +- .../force_causal_efficient_attention.py | 43 +++++++++++++++++++ 3 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/force_causal_efficient_attention.py diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 5ce62133ff..00fb6977e8 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -21,9 +21,6 @@ remove_sym_nodes, repair_input_aliasing, ) -from torch_tensorrt.dynamo.lowering.passes.force_causal_efficient_attention import ( - force_causal_efficient_attention, -) from torch_tensorrt.dynamo.utils import ( parse_dynamo_kwargs, prepare_inputs, @@ -151,8 +148,6 @@ def _pretraced_backend( logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) - # Keep the behavior of efficient attention consistent with pre_export_lowering path. - gm = force_causal_efficient_attention(gm, settings) gm = post_lowering(gm, settings) logger.debug("Lowered Input graph:\n " + str(gm.graph)) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 3b7a8ded83..02be2e98b0 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -26,7 +26,6 @@ remove_detach, remove_assert_nodes, rule_based_autocast, - force_causal_efficient_attention, ] post_lowering_pass_list = [ @@ -38,6 +37,7 @@ remove_assert_nodes, remove_num_users_is_0_nodes, complex_graph_detection, + force_causal_efficient_attention, ] if not is_tegra_platform(): diff --git a/py/torch_tensorrt/dynamo/lowering/passes/force_causal_efficient_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/force_causal_efficient_attention.py new file mode 100644 index 0000000000..71f09576b2 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/force_causal_efficient_attention.py @@ -0,0 +1,43 @@ +import logging + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +def force_causal_efficient_attention( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Force efficient-attention calls to causal mode when enabled in settings.""" + if not settings.attn_bias_is_causal: + return gm + + changed = False + for node in gm.graph.nodes: + if ( + node.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + ): + node.args = ( + node.args[0], + node.args[1], + node.args[2], + None, + False, + 0.0, + True, + ) + changed = True + logger.debug( + f"The args of node {node} was changed to causal mode. Now the node's arguments are: {node.args}" + ) + + if changed: + gm = clean_up_graph_after_modifications(gm) + + logger.debug(f"After forcing causal efficient attention pass:\n{gm.graph}") + return gm From 9d8d7b599d888302dc6d75636454987302d3f467 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Sat, 21 Mar 2026 20:23:49 -0700 Subject: [PATCH 4/5] set attn_bias_is_causal=False to BERT test --- tests/py/dynamo/models/test_models.py | 2 ++ tests/py/dynamo/models/test_models_export.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index b95968809d..17fc0ec488 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -315,6 +315,7 @@ def test_bert_base_uncased(ir, dtype): "cache_built_engines": False, "reuse_cached_engines": False, "use_explicit_typing": True, + "attn_bias_is_causal": False, # BERT uses bidirectional self-attention instead of causal } trt_mod = torchtrt.compile(model, **compile_spec) @@ -365,6 +366,7 @@ def test_bert_base_uncased_cpu_offload(ir): "cache_built_engines": False, "reuse_cached_engines": False, "offload_module_to_cpu": True, + "attn_bias_is_causal": False, # BERT uses bidirectional self-attention instead of causal } trt_mod = torchtrt.compile(model, **compile_spec) if ir == "dynamo": diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 8e378ada4b..06b427702f 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -157,6 +157,7 @@ def test_bert_base_uncased(ir): "min_block_size": 10, "cache_built_engines": False, "reuse_cached_engines": False, + "attn_bias_is_causal": False, # BERT uses bidirectional self-attention instead of causal } trt_mod = torchtrt.compile(model, **compile_spec) model_outputs = model(input, input2) From baf0a01d419e005666e0b7cb0cce3c0f7920a10a Mon Sep 17 00:00:00 2001 From: Evan Li Date: Sat, 21 Mar 2026 21:24:47 -0700 Subject: [PATCH 5/5] add unit tests of dynamic bs --- .../dynamo/conversion/test_attention_aten.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/tests/py/dynamo/conversion/test_attention_aten.py b/tests/py/dynamo/conversion/test_attention_aten.py index babeae0215..d5a7e4cbdb 100644 --- a/tests/py/dynamo/conversion/test_attention_aten.py +++ b/tests/py/dynamo/conversion/test_attention_aten.py @@ -4,6 +4,7 @@ import torch.nn as nn from parameterized import parameterized from torch.testing._internal.common_utils import run_tests +from torch_tensorrt import Input from .harness import DispatchTestCase @@ -259,6 +260,110 @@ def forward(self, query, key, value, attn_mask=None): use_explicit_typing=True, ) + @parameterized.expand( + [ + ( + [(2, 8, 32, 16), (4, 8, 32, 16), (16, 8, 32, 16)], + [(2, 8, 32, 16), (4, 8, 32, 16), (16, 8, 32, 16)], + [(2, 8, 32, 16), (4, 8, 32, 16), (16, 8, 32, 16)], + None, + False, + None, + torch.float16, + 0.0, + False, + ), + ( + [(2, 8, 32, 16), (4, 8, 32, 16), (16, 8, 32, 16)], + [(2, 8, 32, 16), (4, 8, 32, 16), (16, 8, 32, 16)], + [(2, 8, 32, 16), (4, 8, 32, 16), (16, 8, 32, 16)], + None, + True, + None, + torch.float32, + 0.0, + False, + ), + ( + [(2, 8, 32, 16), (4, 8, 32, 16), (16, 8, 32, 16)], + [(2, 8, 32, 16), (4, 8, 32, 16), (16, 8, 32, 16)], + [(2, 8, 32, 16), (4, 8, 32, 16), (16, 8, 32, 16)], + [(2, 8, 32, 32), (4, 8, 32, 32), (16, 8, 32, 32)], + False, + None, + torch.float16, + 0.0, + False, + ), + ( + [(2, 4, 128, 64), (4, 4, 128, 64), (8, 4, 128, 64)], + [(2, 4, 128, 64), (4, 4, 128, 64), (8, 4, 128, 64)], + [(2, 4, 128, 64), (4, 4, 128, 64), (8, 4, 128, 64)], + [(2, 4, 128, 128), (4, 4, 128, 128), (8, 4, 128, 128)], + True, + 2.0, + torch.float32, + 0.0, + False, + ), + ] + ) + def test_dynamic_sdpa_fp_mask( + self, + q_shape, + k_shape, + v_shape, + attn_mask_shape, + is_causal, + scale, + dtype, + dropout_p=0.0, + enable_gqa=False, + ): + class SDPA(nn.Module): + def forward(self, query, key, value, attn_mask=None): + return torch.ops.aten.scaled_dot_product_attention.default( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + input_specs = [ + Input( + min_shape=q_shape[0], + opt_shape=q_shape[1], + max_shape=q_shape[2], + dtype=dtype, + ), + Input( + min_shape=k_shape[0], + opt_shape=k_shape[1], + max_shape=k_shape[2], + dtype=dtype, + ), + Input( + min_shape=v_shape[0], + opt_shape=v_shape[1], + max_shape=v_shape[2], + dtype=dtype, + ), + ] + if attn_mask_shape is not None: + input_specs.append( + Input( + min_shape=attn_mask_shape[0], + opt_shape=attn_mask_shape[1], + max_shape=attn_mask_shape[2], + dtype=dtype, + ), + ) + self.run_test_with_dynamic_shape(SDPA(), input_specs, output_dtypes=[dtype]) + class TestScaledDotProductEfficientAttention(DispatchTestCase): @parameterized.expand(