Add support for SWA (left, right) with FusedAttention #2477
base: main
Conversation
…IA#1369 Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci pytorch L0
Greptile Summary
This PR adds support for sliding window attention (SWA) with configurable left and right window sizes to the FusedAttention backend. The implementation plumbs a new `bottom_right_diagonal` parameter through the attention call stack.
Confidence Score: 2/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant PyTorch/JAX Frontend
    participant CPP Extensions
    participant Fused Attn Backend
    participant cuDNN
    User->>PyTorch/JAX Frontend: Call attention with window_size_left/right
    PyTorch/JAX Frontend->>PyTorch/JAX Frontend: Calculate bottom_right_diagonal from mask_type
    PyTorch/JAX Frontend->>CPP Extensions: Pass bottom_right_diagonal parameter
    CPP Extensions->>Fused Attn Backend: Forward to nvte_fused_attn_fwd/bwd
    alt F16 Arbitrary SeqLen Backend
        Fused Attn Backend->>Fused Attn Backend: Use bottom_right_diagonal parameter
        Fused Attn Backend->>cuDNN: Pass to FADescriptor_v1
        cuDNN-->>Fused Attn Backend: Execute with correct alignment
    else FP8 Backend
        Fused Attn Backend->>Fused Attn Backend: Hardcode bottom_right_diagonal=true
        Note over Fused Attn Backend: BUG: Ignores parameter
        Fused Attn Backend->>cuDNN: Pass hardcoded true to FADescriptor_v1
        cuDNN-->>Fused Attn Backend: Always uses bottom-right alignment
    end
    Fused Attn Backend-->>CPP Extensions: Return attention output
    CPP Extensions-->>PyTorch/JAX Frontend: Return result
    PyTorch/JAX Frontend-->>User: Return attention output
```
Additional Comments (2)
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1281 (link) - logic: Trailing comma creates a single-element tuple instead of a boolean; should this be just `bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]`? (See the snippet after this list for the pitfall this causes.)
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1482 (link) - style: Uses a hardcoded mask type check instead of the new `bottom_right_diagonal` parameter for ALiBi alignment. Should this use the `bottom_right_diagonal` parameter for consistency?

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
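For context on the trailing-comma issue flagged above, here is a minimal standalone illustration of the Python pitfall (not the actual TE code):

```python
attn_mask_type = "causal"

# A trailing comma turns the right-hand side into a 1-element tuple,
# which is always truthy, regardless of the boolean inside it.
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"],
print(bottom_right_alignment)        # (False,) -- a tuple
print(bool(bottom_right_alignment))  # True, even though the inner value is False

# Without the trailing comma the expression is a plain bool, as intended.
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]
print(bottom_right_alignment)        # False
```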
15 files reviewed, 8 comments
transformer_engine/pytorch/attention/dot_product_attention/backends.py (outdated, resolved)
transformer_engine/pytorch/attention/dot_product_attention/utils.py (outdated, resolved)
```python
if self_attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True
```
logic: Logic overrides the instance variable even when explicitly set in forward call - should preserve user's explicit choice. Should the mask type check override an explicitly passed bottom_right_diagonal parameter, or only apply when it's None?
```python
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
    enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    enc_dec_bottom_right_diagonal = True
```
logic: Same logic issue as above - mask type check overrides explicit parameter values
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (outdated, resolved)
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (outdated, resolved)
transformer_engine/pytorch/attention/dot_product_attention/backends.py (outdated, resolved)
transformer_engine/pytorch/attention/dot_product_attention/utils.py (outdated, resolved)
transformer_engine/pytorch/attention/dot_product_attention/utils.py (outdated, resolved)
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
15 files reviewed, no comments
Additional Comments (4)
- transformer_engine/pytorch/transformer.py, line 777-783 (link) - logic: Conditional logic overwrites an explicitly passed parameter. When `self_attn_mask_type in {"causal", "padding_causal"}` at line 777, this unconditionally sets `bottom_right_diagonal = False`, even if the user explicitly passed a different value in the forward call. The check at line 779 (`if bottom_right_diagonal is None`) becomes unreachable in those cases. Consider restructuring to only apply defaults when `bottom_right_diagonal is None` (see the sketch after this list).
- transformer_engine/pytorch/transformer.py, line 787-793 (link) - logic: Same conditional logic issue for encoder-decoder attention. The mask type check at line 787 unconditionally overwrites `enc_dec_bottom_right_diagonal`, ignoring explicitly passed values.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 920 (link) - logic: Variable name mismatch: sets `use_flash_attention` instead of `use_flash_attention_2`.
- transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 947 (link) - logic: Variable name mismatch: sets `use_flash_attention` instead of `use_flash_attention_2`.
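A minimal sketch of the restructuring suggested in the first comment above, assuming the variable names from the diff; the helper name is hypothetical and the real code in transformer.py inlines similar checks:

```python
from typing import Optional

def resolve_bottom_right_diagonal(
    bottom_right_diagonal: Optional[bool],
    attn_mask_type: str,
) -> bool:
    """Apply the mask-type default only when no explicit value was passed."""
    if bottom_right_diagonal is not None:
        # Preserve the caller's explicit choice.
        return bottom_right_diagonal
    # Default: top-left alignment for plain causal masks, bottom-right otherwise.
    return attn_mask_type not in {"causal", "padding_causal"}

# Example: an explicit False is kept even for a bottom-right mask type.
assert resolve_bottom_right_diagonal(False, "causal_bottom_right") is False
assert resolve_bottom_right_diagonal(None, "padding_causal") is False
assert resolve_bottom_right_diagonal(None, "no_mask") is True
```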
15 files reviewed, 4 comments
/te-ci pytorch L0
1 similar comment
/te-ci pytorch L0
Additional Comments (1)
- transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1283 (link) - syntax: Tuple has a trailing comma, causing it to be a single-element tuple instead of a boolean.
15 files reviewed, 1 comment
/te-ci pytorch L0
15 files reviewed, no comments
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci pytorch L0
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
a56c4d2 to a545ebf (Compare)
/te-ci pytorch L1
/te-ci pytorch L1
Greptile's behavior is changing! From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".
/te-ci pytorch L1
transformer_engine/common/include/transformer_engine/fused_attn.h (outdated, resolved)
transformer_engine/pytorch/attention/dot_product_attention/backends.py (outdated, resolved)
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci pytorch L1
17 files reviewed, 3 comments
```
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically.
bottom_right_diagonal: bool, default = True
```
style: Documentation inconsistency: forward function documents default as None but backward documents True
```diff
- bottom_right_diagonal: bool, default = True
+ bottom_right_diagonal: bool, default = None
```
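As a reference for the window semantics documented in the hunk above, a small standalone sketch (plain PyTorch, not the TE kernel) that materializes which keys each query may attend to for a (left, right) window with the bottom-right-aligned diagonal:

```python
import torch

def swa_allowed_mask(seqlen_q: int, seqlen_k: int, window_size) -> torch.Tensor:
    """Boolean [seqlen_q, seqlen_k] mask: True where query i may attend to key j.

    Follows the documented rule: j lies in
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]],
    with -1 meaning "unbounded" on that side.
    """
    left, right = window_size
    i = torch.arange(seqlen_q).unsqueeze(1)   # [seqlen_q, 1]
    j = torch.arange(seqlen_k).unsqueeze(0)   # [1, seqlen_k]
    diag = i + seqlen_k - seqlen_q            # bottom-right aligned diagonal position
    lo = diag - left if left >= 0 else torch.full_like(diag, -(10**9))
    hi = diag + right if right >= 0 else torch.full_like(diag, 10**9)
    return (j >= lo) & (j <= hi)

# (-1, 0) reproduces the causal mask; (2, 0) additionally limits each query
# to the 2 keys before its diagonal position.
print(swa_allowed_mask(4, 6, (-1, 0)).int())
print(swa_allowed_mask(4, 6, (2, 0)).int())
```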
```cpp
0,
0,
true,
true,
```
logic: Critical: Hardcoded parameter instead of proper API parameter
This hardcoded true value should be replaced with a proper bottom_right_diagonal parameter passed from the function signature. The current implementation bypasses the intended parameter passing mechanism.
Issue: The function fused_attn_fp8_fwd_impl_v1 should accept bottom_right_diagonal as a parameter and pass it here, rather than hardcoding true. This breaks the configurability that the PR is supposed to add. Should this hardcoded true be replaced with a bottom_right_diagonal parameter that gets passed through the function signature?
```cpp
0,
0,
true,
false,
```
logic: Critical: Hardcoded parameter instead of proper API parameter
Similar to the forward pass, this hardcoded false value should be replaced with a proper bottom_right_diagonal parameter passed from the function signature.
Issue: The function fused_attn_fp8_bwd_impl_v1 should accept bottom_right_diagonal as a parameter and pass it here. The hardcoded false suggests different behavior for backward pass, but this should still be configurable. Should this hardcoded false be replaced with a bottom_right_diagonal parameter, and why is it different from the forward pass value?
/te-ci pytorch L1
17 files reviewed, 4 comments
| "Disabling FlashAttention as it only supports sliding window with bottom right" | ||
| " diagonal alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
logic: Wrong variable name - should be use_flash_attention_2 = False not use_flash_attention = False. This causes all FlashAttention backends to be incorrectly disabled for sliding window with cross-attention when only FlashAttention 2 should be disabled.
```diff
- use_flash_attention = False
+ use_flash_attention_2 = False
```
| "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" | ||
| " alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
logic: Wrong variable name - should be use_flash_attention_2 = False not use_flash_attention = False. This causes all FlashAttention backends to be incorrectly disabled for ALiBi with cross-attention when only FlashAttention 2 should be disabled.
```diff
- use_flash_attention = False
+ use_flash_attention_2 = False
```
```cpp
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
```
logic: Hardcoded bottom_right_diagonal value bypasses the parameter passing mechanism. The FP8 forward function needs to accept bottom_right_diagonal as a parameter and pass it here instead of hardcoding true. Should FP8 attention support configurable bottom_right_diagonal, or is there a technical reason it must always be true?
```cpp
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
```
logic: Hardcoded bottom_right_diagonal value in backward pass. Should accept this as a parameter and pass it through instead of hardcoding true.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
You may need to add the extra argument
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci L1
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…ansformerEngine into swa_padding_brcm_try2
for more information, see https://pre-commit.ci
20 files reviewed, 4 comments
| "Disabling FlashAttention as it only supports sliding window with bottom right" | ||
| " diagonal alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
syntax: Wrong variable - should be use_flash_attention_2 = False not use_flash_attention = False. This disables all FlashAttention backends when only FlashAttention 2 should be disabled for sliding window with cross-attention.
```diff
- use_flash_attention = False
+ use_flash_attention_2 = False
```
| "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal" | ||
| " alignment for cross-attention" | ||
| ) | ||
| use_flash_attention = False |
syntax: Wrong variable - should be use_flash_attention_2 = False not use_flash_attention = False. This disables all FlashAttention backends when only FlashAttention 2 should be disabled for ALiBi with cross-attention.
```diff
- use_flash_attention = False
+ use_flash_attention_2 = False
```
```cpp
true,
true,
```
logic: Hardcoded bottom_right_diagonal to true bypasses the configurable parameter mechanism. The fused_attn_fp8_fwd_impl_v1 function should accept bottom_right_diagonal as a parameter and pass it here instead of hardcoding. Is there a technical reason FP8 attention must always use bottom_right_diagonal=true, or should this be configurable?
```cpp
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
```
logic: Hardcoded bottom_right_diagonal to true in backward pass. Should accept this as a parameter for consistency with the forward pass configuration.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
/te-ci L1
Description
FusedAttention has supported "right"-side sliding window attention for some time now. This PR adds support for SWA (left, right) with the FusedAttention backend in TE.
(changes cherry-picked from original PR: #1369)
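For illustration only, a sketch of what calling the PyTorch frontend with a two-sided window could look like with this change. `DotProductAttention`, `attn_mask_type`, and `window_size` already exist in TE; passing `bottom_right_diagonal` explicitly is the behavior added by this PR, and the exact argument names and placement are assumptions based on the diff rather than a verified API:

```python
import torch
from transformer_engine.pytorch import DotProductAttention

# Cross-attention shapes in TE's default "sbhd" layout: (seq, batch, heads, head_dim).
q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(256, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(256, 2, 16, 64, dtype=torch.bfloat16, device="cuda")

attn = DotProductAttention(num_attention_heads=16, kv_channels=64)

# window_size=(left, right): each query attends to up to 64 keys before and
# 16 keys after its diagonal position. bottom_right_diagonal controls whether
# that diagonal is aligned to the bottom-right corner when seqlen_q != seqlen_kv.
out = attn(
    q, k, v,
    attn_mask_type="no_mask",
    window_size=(64, 16),
    bottom_right_diagonal=True,  # parameter added by this PR (name assumed from the diff)
)
```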
Type of change
Changes
Please list the changes introduced in this PR:
- transformer_engine
  - common
    - fused_attn
      - fused_attn.cpp: add `bottom_right_diagonal` parameter to the API
      - fused_attn_f16_arbitrary_seqlen.cu: add `bottom_right_diagonal` parameter to the API
      - fused_attn_fp8.cu: add `bottom_right_diagonal` parameter to the `FADescriptor_v1` API
      - utils.h: add `bottom_right_diagonal` parameter to the `FADescriptor_v1` API
  - pytorch
    - transformer.py: pass `bottom_right_diagonal` through the call stack: `TransformerLayer` --> `SelfAttention`/`CrossAttention`
    - attention
      - dot_product_attention
        - backends.py:
          - `UnfusedDotProductAttention`: add `bottom_right_diagonal` parameter to the `forward` API (`bottom_right_alignment` is being used in the ALiBi call in `forward`, perhaps this should be corrected)
          - `FusedAttn` custom module: add `bottom_right_diagonal` parameter to the `forward` API
          - `FusedAttention` module: pass `bottom_right_diagonal` through the call stack
        - dot_product_attention.py: `DotProductAttention`: pass `bottom_right_diagonal` through the call stack; set `bottom_right_diagonal` if it's `None`
        - utils.py: `AttentionParams`, `get_attention_backend`
      - multi_head_attention.py: add `bottom_right_diagonal` to the forward API and call; set `bottom_right_diagonal` if it's `None`
    - cpp_extensions
      - fused_attn.py: pass `bottom_right_diagonal` in `fused_attn_fwd`/`fused_attn_bwd`
    - csrc
      - extension
        - attention.cpp: pass `bottom_right_diagonal` through the call stack: `fused_attn_fwd` --> `nvte_fused_attn_fwd`
      - extensions.h: add `bottom_right_diagonal` to `fused_attn_fwd` and `fused_attn_bwd` API definitions

Checklist: