Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions docsrc/user_guide/compilation/compilation_settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ Optimization Tuning
- Disable TensorFloat-32 (TF32) accumulation. TF32 is enabled by default on
Ampere and newer GPUs and provides FP32-range with FP16-speed for matmul/conv.
Disable only when you need strict IEEE FP32 semantics.
* - ``attn_bias_is_causal``
- ``True``
- Whether the attn_bias in efficient SDPA is causal. Default is True. This can
accelerate models from HF because attn_bias is always a causal mask in HF.
If you want to use non-causal attn_bias, you can set this to False.

----

Expand Down Expand Up @@ -350,6 +355,10 @@ Graph Partitioning
- ``False``
- Use ``aot_autograd`` for tracing instead of the default path. Required when the
model contains ``DTensor`` or other distributed tensors.
* - ``decompose_attention``
- ``False``
- Decompose attention layers into smaller ops. We have converters for handling attention ops,
but if you want to decompose them into smaller ops, you can set this to True.

----

Expand All @@ -372,7 +381,7 @@ Compilation Workflow
- ``False``
- Defer TRT engine deserialization until all engines have been built.
Works around resource constraints and builder overhead but engines
may be less well tuned to their deployment resource availablity
may be less well tuned to their deployment resource availability
* - ``debug``
- ``False``
- Enable verbose TRT builder logs at ``DEBUG`` level.
Expand Down Expand Up @@ -444,7 +453,7 @@ rebuilt from scratch:
``l2_limit_for_tiling``, ``enable_autocast``, ``autocast_low_precision_type``,
``autocast_excluded_nodes``, ``autocast_excluded_ops``,
``autocast_max_output_threshold``, ``autocast_max_depth_of_reduction``,
``autocast_calibration_dataloader``.
``autocast_calibration_dataloader``, ``decompose_attention``, ``attn_bias_is_causal``.

Settings not in this list (e.g., ``debug``, ``dryrun``, ``pass_through_build_failures``)
can be changed without invalidating the cache.
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def cross_compile_for_windows(
cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION,
attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
Expand Down Expand Up @@ -190,6 +191,7 @@ def cross_compile_for_windows(
cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
Expand Down Expand Up @@ -349,6 +351,7 @@ def cross_compile_for_windows(
"cpu_memory_budget": cpu_memory_budget,
"dynamically_allocate_resources": dynamically_allocate_resources,
"decompose_attention": decompose_attention,
"attn_bias_is_causal": attn_bias_is_causal,
}

# disable the following settings is not supported for cross compilation for windows feature
Expand Down Expand Up @@ -471,6 +474,7 @@ def compile(
enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
dynamically_allocate_resources: bool = _defaults.DYNAMICALLY_ALLOCATE_RESOURCES,
decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION,
attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
Expand Down Expand Up @@ -559,6 +563,7 @@ def compile(
cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
dynamically_allocate_resources (bool): Dynamically allocate resources during engine execution.
decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
Expand Down Expand Up @@ -763,6 +768,7 @@ def compile(
"cpu_memory_budget": cpu_memory_budget,
"dynamically_allocate_resources": dynamically_allocate_resources,
"decompose_attention": decompose_attention,
"attn_bias_is_causal": attn_bias_is_causal,
}
logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
settings = CompilationSettings(**compilation_options)
Expand Down Expand Up @@ -1167,6 +1173,7 @@ def convert_exported_program_to_serialized_trt_engine(
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
decompose_attention: bool = _defaults.DECOMPOSE_ATTENTION,
attn_bias_is_causal: bool = _defaults.ATTN_BIAS_IS_CAUSAL,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
Expand Down Expand Up @@ -1242,6 +1249,7 @@ def convert_exported_program_to_serialized_trt_engine(
offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage.
use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model.
decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False.
**kwargs: Any,
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
Expand Down Expand Up @@ -1412,6 +1420,7 @@ def convert_exported_program_to_serialized_trt_engine(
"offload_module_to_cpu": offload_module_to_cpu,
"use_distributed_mode_trace": use_distributed_mode_trace,
"decompose_attention": decompose_attention,
"attn_bias_is_causal": attn_bias_is_causal,
}

settings = CompilationSettings(**compilation_options)
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
CPU_MEMORY_BUDGET = None
DYNAMICALLY_ALLOCATE_RESOURCES = False
DECOMPOSE_ATTENTION = False
ATTN_BIAS_IS_CAUSAL = True

if platform.system() == "Linux":
import pwd
Expand Down
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
ATTN_BIAS_IS_CAUSAL,
AUTOCAST_CALIBRATION_DATALOADER,
AUTOCAST_EXCLUDED_NODES,
AUTOCAST_EXCLUDED_OPS,
Expand Down Expand Up @@ -120,6 +121,7 @@ class CompilationSettings:
offload_module_to_cpu (bool): Offload the model to CPU to reduce memory footprint during compilation
dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines
decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False.
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
Expand Down Expand Up @@ -180,6 +182,7 @@ class CompilationSettings:
cpu_memory_budget: Optional[int] = CPU_MEMORY_BUDGET
dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES
decompose_attention: bool = DECOMPOSE_ATTENTION
attn_bias_is_causal: bool = ATTN_BIAS_IS_CAUSAL

def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
Expand Down Expand Up @@ -220,6 +223,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
"autocast_max_depth_of_reduction",
"autocast_calibration_dataloader",
"decompose_attention",
"attn_bias_is_causal",
}


Expand Down
125 changes: 86 additions & 39 deletions py/torch_tensorrt/dynamo/conversion/impl/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_trt_tensor,
get_trt_tensor,
prepend_ones,
)

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -193,18 +194,17 @@ def scaled_dot_product_attention(
if attn_mask is not None:
if attn_mask.dtype == trt.DataType.BOOL:
mask_tensor = attn_mask
elif attn_mask.dtype != query.dtype:
mask_tensor = cast_trt_tensor(
ctx,
attn_mask,
query.dtype,
name + "_cast_attn_mask",
target,
source_ir,
)
else:
if attn_mask.dtype != query.dtype:
mask_tensor = cast_trt_tensor(
ctx,
attn_mask,
query.dtype,
name + "_cast_attn_mask",
target,
source_ir,
)
else:
mask_tensor = attn_mask
mask_tensor = attn_mask

scaled_query = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_scaled_query", query, scale_factor
Expand Down Expand Up @@ -329,41 +329,88 @@ def scaled_dot_product_efficient_attention(
source_ir,
)

mask_tensor = None
if attn_bias is not None:
attn_weight = impl.matmul.matrix_multiply(
ctx,
target,
source_ir,
name + "_mm",
scaled_query,
key,
other_matrix_op=trt.MatrixOperation.TRANSPOSE,
)
attn_weight = impl.elementwise.add(
ctx, target, source_ir, name + "_attn_bias_add", attn_weight, attn_bias
)
attn_weight = impl.normalization.softmax(
ctx, target, source_ir, name + "_softmax", attn_weight, -1, False
)
out = impl.matmul.matrix_multiply(
if attn_bias.dtype == trt.DataType.BOOL:
mask_tensor = attn_bias
elif attn_bias.dtype != query.dtype:
mask_tensor = cast_trt_tensor(
ctx,
attn_bias,
query.dtype,
name + "_cast_attn_bias",
target,
source_ir,
)
else:
mask_tensor = attn_bias

# TensorRT IAttention does not allow setting both causal=True and mask.
# If both are requested, fold causal into mask and disable causal flag.
use_causal = is_causal
if mask_tensor is not None and is_causal:
L = impl.shape.shape(ctx, target, source_ir, name + "_L", query, -2)
S = impl.shape.shape(ctx, target, source_ir, name + "_S", key, -2)
causal_mask = tril(
ctx,
target,
source_ir,
name + "_out",
attn_weight,
value,
)
return out, None, None, None
else:
attention_layer = ctx.net.add_attention(
scaled_query, key, value, trt.AttentionNormalizationOp.SOFTMAX, is_causal
name + "_tril",
L,
S,
)
assert attention_layer is not None, "attention layer is None"
diff = len(query.shape) - len(causal_mask.shape)
causal_mask = prepend_ones(ctx, causal_mask, name + "_prepend_ones", diff)

if mask_tensor.dtype == trt.DataType.BOOL:
mask_tensor = impl.elementwise.logical_and(
ctx,
target,
source_ir,
name + "_causal_attn_bias_and",
causal_mask,
mask_tensor,
)
else:
# Convert causal bool mask to additive bias mask:
# True -> 0.0 (keep), False -> -inf (block)
zero_bias = get_trt_tensor(
ctx, 0.0, name + "_causal_additive_bias_zero", query.dtype
)
neg_inf_bias = get_trt_tensor(
ctx, float("-inf"), name + "_causal_additive_bias_neg_inf", query.dtype
)
causal_additive_bias = impl.condition.where(
ctx,
target,
source_ir,
name + "_causal_additive_bias",
zero_bias,
neg_inf_bias,
causal_mask,
)
mask_tensor = impl.elementwise.add(
ctx,
target,
source_ir,
name + "_attn_bias_add_causal",
mask_tensor,
causal_additive_bias,
)
use_causal = False

attention_layer.decomposable = True
attention_layer = ctx.net.add_attention(
scaled_query, key, value, trt.AttentionNormalizationOp.SOFTMAX, use_causal
)
assert attention_layer is not None, "attention layer is None"

attention_output = attention_layer.get_output(0)
return attention_output, None, None, None
if mask_tensor is not None:
attention_layer.mask = mask_tensor

attention_layer.decomposable = True

attention_output = attention_layer.get_output(0)
return attention_output, None, None, None


def scaled_dot_product_cudnn_attention(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .complex_graph_rewrite import complex_graph_detection
from .constant_folding import constant_fold
from .force_causal_efficient_attention import force_causal_efficient_attention
from .fuse_prims_broadcast import fuse_prims_broadcast
from .pass_manager import DynamoPassManager
from .remove_assert_nodes import remove_assert_nodes
Expand All @@ -36,6 +37,7 @@
remove_assert_nodes,
remove_num_users_is_0_nodes,
complex_graph_detection,
force_causal_efficient_attention,
]

if not is_tegra_platform():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import logging

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def force_causal_efficient_attention(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Force efficient-attention calls to causal mode when enabled in settings.

    When ``settings.attn_bias_is_causal`` is True, every
    ``aten._scaled_dot_product_efficient_attention`` call in the graph is
    rewritten to drop its ``attn_bias`` argument and run with
    ``is_causal=True`` instead. This can accelerate HF models, whose
    ``attn_bias`` is always a causal mask.

    Args:
        gm: FX graph module to rewrite in place.
        settings: Compilation settings; the pass is a no-op unless
            ``settings.attn_bias_is_causal`` is set.

    Returns:
        The (possibly modified) graph module.
    """
    if not settings.attn_bias_is_causal:
        return gm

    changed = False
    for node in gm.graph.nodes:
        if (
            node.target
            == torch.ops.aten._scaled_dot_product_efficient_attention.default
        ):
            # Op signature: (query, key, value, attn_bias, compute_log_sumexp,
            # dropout_p=0.0, is_causal=False, *, scale=None).
            # Only attn_bias and is_causal relate to causality — preserve the
            # caller's compute_log_sumexp and dropout_p rather than clobbering
            # them (forcing them would silently change semantics for nodes
            # whose logsumexp output is consumed or whose dropout_p != 0).
            query, key, value = node.args[0], node.args[1], node.args[2]
            compute_log_sumexp = node.args[4] if len(node.args) > 4 else False
            dropout_p = node.args[5] if len(node.args) > 5 else 0.0
            node.args = (
                query,
                key,
                value,
                None,  # attn_bias: folded into is_causal below
                compute_log_sumexp,
                dropout_p,
                True,  # is_causal
            )
            changed = True
            logger.debug(
                f"The args of node {node} were changed to causal mode. Now the node's arguments are: {node.args}"
            )

    if changed:
        gm = clean_up_graph_after_modifications(gm)

    logger.debug(f"After forcing causal efficient attention pass:\n{gm.graph}")
    return gm
return gm
Loading
Loading