
Conversation

@sudhakarsingh27
Collaborator

Description

FusedAttention has supported "right"-side sliding window attention for some time now. This PR adds support for SWA with (left, right) windows to the FusedAttention backend in TE.
(changes cherry-picked from original PR: #1369)
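For background, the core of the feature is diagonal alignment: a causal (or sliding-window) mask can anchor its diagonal to the top-left or the bottom-right corner of the attention matrix, which only matters when seqlen_q != seqlen_k. A minimal illustration (our own sketch, not TE code):

```python
def causal_mask(seqlen_q, seqlen_k, bottom_right_diagonal):
    """Return a seqlen_q x seqlen_k boolean matrix; True = may attend.

    With bottom-right alignment the diagonal ends at the bottom-right
    corner, so query i may see keys j <= i + (seqlen_k - seqlen_q).
    With top-left alignment query i may see keys j <= i.
    """
    offset = seqlen_k - seqlen_q if bottom_right_diagonal else 0
    return [[j <= i + offset for j in range(seqlen_k)]
            for i in range(seqlen_q)]

# For seqlen_q=2, seqlen_k=4 the two alignments differ:
# bottom-right: [[T, T, T, F], [T, T, T, T]]
# top-left:     [[T, F, F, F], [T, T, F, F]]
```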

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

transformer_engine

  • common

    • fused_attn
      • fused_attn.cpp
        • add bottom_right_diagonal parameter to the API
        • edit the filters to allow sliding window configs to pick the arbitrary-seqlen fused attention backend
      • fused_attn_f16_arbitrary_seqlen.cu: add bottom_right_diagonal parameter to the API
      • fused_attn_fp8.cu: add bottom_right_diagonal parameter to the FADescriptor_v1 API
      • utils.h: add bottom_right_diagonal parameter to FADescriptor_v1 API
  • pytorch

    • transformer.py
      • plumb bottom_right_diagonal through the call stack: TransformerLayer --> SelfAttention/CrossAttention
    • attention
      • dot_product_attention
        • backends.py:
          • UnfusedDotProductAttention
            • add bottom_right_diagonal parameter to the forward API
              • why is it not used in the forward?
                • bottom_right_alignment is used in the ALiBi call instead; perhaps this should be corrected
          • FusedAttn custom module
            • add bottom_right_diagonal parameter to the forward API
          • FusedAttention module
            • plumb bottom_right_diagonal through the call stack
        • dot_product_attention.py
          • DotProductAttention
            • Plumb bottom_right_diagonal through the call stack
            • Add calculation of bottom_right_diagonal if it's None
        • utils.py
          • AttentionParams
            • add bottom_right_diagonal field
          • get_attention_backend
            • update sliding window filter section
            • update attention bias filter section
      • multi_head_attention.py
        • Add bottom_right_diagonal to forward API and call
        • Add calculation of bottom_right_diagonal if it's None
    • cpp_extensions
      • fused_attn.py
        • plumb bottom_right_diagonal in fused_attn_fwd/fused_attn_bwd
    • csrc
      • extension
        • attention.cpp
          • plumb bottom_right_diagonal through the call stack: fused_attn_fwd --> nvte_fused_attn_fwd
          • same as above for bwd
      • extensions.h
        • add bottom_right_diagonal to fused_attn_fwd and fused_attn_bwd API definitions

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…IA#1369

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@greptile-apps
Contributor

greptile-apps bot commented Dec 4, 2025

Greptile Summary

This PR adds support for sliding window attention (SWA) with configurable left and right window sizes to the FusedAttention backend. The implementation plumbs a new bottom_right_diagonal parameter through the entire stack from PyTorch and JAX frontends to the C++/CUDA backend, enabling control over diagonal alignment for sliding window masks.

Critical Issues Found:

  • Variable name bugs in utils.py (lines 911, 938): Sets use_flash_attention = False instead of use_flash_attention_2 = False, incorrectly disabling all FlashAttention backends instead of just FlashAttention 2 for specific cross-attention scenarios
  • FP8 implementation incomplete: The fused_attn_fp8_fwd_impl_v1 and fused_attn_fp8_bwd_impl_v1 functions hardcode bottom_right_diagonal to true instead of accepting it as a parameter, preventing FP8 users from configuring this feature

Positive aspects:

  • F16 arbitrary sequence length backend properly implements the parameter
  • JAX implementation correctly plumbs the parameter through FFI bindings
  • Backend selection filters updated to handle new sliding window configurations
  • Test coverage added for sliding window attention

Confidence Score: 2/5

  • This PR has critical bugs that will cause incorrect behavior in production
  • Two critical variable name bugs in utils.py will disable all FlashAttention backends when only FlashAttention 2 should be disabled. FP8 attention implementation hardcodes bottom_right_diagonal instead of making it configurable, limiting the feature's usefulness for FP8 workloads.
  • Pay close attention to transformer_engine/pytorch/attention/dot_product_attention/utils.py (variable name bugs) and transformer_engine/common/fused_attn/fused_attn_fp8.cu (hardcoded parameters)

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/utils.py Variable name bug causes incorrect backend selection for sliding window and ALiBi attention
transformer_engine/common/fused_attn/fused_attn_fp8.cu Hardcoded bottom_right_diagonal values bypass parameter passing mechanism for FP8 attention
transformer_engine/common/fused_attn/fused_attn.cpp Added bottom_right_diagonal parameter plumbing through API, updated backend selection filters

Sequence Diagram

sequenceDiagram
    participant User
    participant PyTorch/JAX Frontend
    participant CPP Extensions
    participant Fused Attn Backend
    participant cuDNN

    User->>PyTorch/JAX Frontend: Call attention with window_size_left/right
    PyTorch/JAX Frontend->>PyTorch/JAX Frontend: Calculate bottom_right_diagonal from mask_type
    PyTorch/JAX Frontend->>CPP Extensions: Pass bottom_right_diagonal parameter
    CPP Extensions->>Fused Attn Backend: Forward to nvte_fused_attn_fwd/bwd
    
    alt F16 Arbitrary SeqLen Backend
        Fused Attn Backend->>Fused Attn Backend: Use bottom_right_diagonal parameter
        Fused Attn Backend->>cuDNN: Pass to FADescriptor_v1
        cuDNN-->>Fused Attn Backend: Execute with correct alignment
    else FP8 Backend
        Fused Attn Backend->>Fused Attn Backend: Hardcode bottom_right_diagonal=true
        Note over Fused Attn Backend: BUG: Ignores parameter
        Fused Attn Backend->>cuDNN: Pass hardcoded true to FADescriptor_v1
        cuDNN-->>Fused Attn Backend: Always uses bottom-right alignment
    end
    
    Fused Attn Backend-->>CPP Extensions: Return attention output
    CPP Extensions-->>PyTorch/JAX Frontend: Return result
    PyTorch/JAX Frontend-->>User: Return attention output

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (2)

  1. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1281 (link)

    logic: Trailing comma creates single-element tuple instead of boolean - should this be just bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]?

  2. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1482 (link)

    style: Uses hardcoded mask type check instead of the new bottom_right_diagonal parameter for ALiBi alignment. Should this use bottom_right_diagonal parameter for consistency?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
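The trailing-comma pitfall flagged in comment 1 can be reproduced in isolation:

```python
# A trailing comma turns the assignment into a one-element tuple.
flag = "causal" not in ["causal", "padding_causal"],  # note the comma

assert isinstance(flag, tuple)
assert flag == (False,)
assert bool(flag) is True  # a non-empty tuple is truthy, masking the bug

fixed = "causal" not in ["causal", "padding_causal"]  # comma removed
assert fixed is False
```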

15 files reviewed, 8 comments


Comment on lines +777 to +783
if self_attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True
Contributor

logic: Logic overrides the instance variable even when explicitly set in forward call - should preserve user's explicit choice. Should the mask type check override an explicitly passed bottom_right_diagonal parameter, or only apply when it's None?
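One way to restructure so an explicit value survives — a sketch with a helper name of our own choosing, not the PR's code:

```python
def resolve_bottom_right_diagonal(bottom_right_diagonal, attn_mask_type):
    """Derive the diagonal alignment only when the caller left it None."""
    if bottom_right_diagonal is not None:
        return bottom_right_diagonal  # preserve the user's explicit choice
    # Default by mask type: top-left diagonal for causal/padding_causal,
    # bottom-right for everything else (including *_bottom_right masks).
    return attn_mask_type not in {"causal", "padding_causal"}
```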

Comment on lines +787 to +793
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
    enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    enc_dec_bottom_right_diagonal = True
Contributor

logic: Same logic issue as above - mask type check overrides explicit parameter values

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

15 files reviewed, no comments


Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (4)

  1. transformer_engine/pytorch/transformer.py, line 777-783 (link)

    logic: Conditional logic overwrites explicitly passed parameter. When self_attn_mask_type in {"causal", "padding_causal"} at line 777, this unconditionally sets bottom_right_diagonal = False, even if the user explicitly passed a different value in the forward call. The check at line 779 (if bottom_right_diagonal is None) becomes unreachable in those cases.

    Consider restructuring to only apply defaults when bottom_right_diagonal is None:

  2. transformer_engine/pytorch/transformer.py, line 787-793 (link)

    logic: Same conditional logic issue for encoder-decoder attention. The mask type check at line 787 unconditionally overwrites enc_dec_bottom_right_diagonal, ignoring explicitly passed values.

  3. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 920 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

  4. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 947 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

15 files reviewed, 4 comments


@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

1 similar comment
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1283 (link)

    syntax: Tuple has trailing comma causing it to be a single-element tuple instead of boolean

15 files reviewed, 1 comment


@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

Contributor

@greptile-apps greptile-apps bot left a comment

15 files reviewed, no comments


Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

@greptile-apps
Contributor

greptile-apps bot commented Jan 13, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

Contributor

@greptile-apps greptile-apps bot left a comment

17 files reviewed, 3 comments


in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically.
bottom_right_diagonal: bool, default = True
Contributor

style: Documentation inconsistency: forward function documents default as None but backward documents True

Suggested change
bottom_right_diagonal: bool, default = True
bottom_right_diagonal: bool, default = None
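The window rule quoted in the docstring above can be written out as a small predicate (a purely illustrative helper of our own, not TE code):

```python
def in_window(i, j, seqlen_q, seqlen_k, window_size):
    """True iff query position i may attend key position j."""
    w0, w1 = window_size
    if (w0, w1) == (-1, -1):   # special case: no sliding window at all
        return True
    if (w0, w1) == (-1, 0):    # special case: plain causal mask
        return j <= i + seqlen_k - seqlen_q
    # General case: j must fall in the inclusive band around the
    # bottom-right-aligned diagonal i + seqlen_k - seqlen_q.
    lo = i + seqlen_k - seqlen_q - w0
    hi = i + seqlen_k - seqlen_q + w1
    return lo <= j <= hi
```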

0,
0,
true,
true,
Contributor

logic: Critical: Hardcoded parameter instead of proper API parameter

This hardcoded true value should be replaced with a proper bottom_right_diagonal parameter passed from the function signature. The current implementation bypasses the intended parameter passing mechanism.

Issue: The function fused_attn_fp8_fwd_impl_v1 should accept bottom_right_diagonal as a parameter and pass it here, rather than hardcoding true. This breaks the configurability that the PR is supposed to add. Should this hardcoded true be replaced with a bottom_right_diagonal parameter that gets passed through the function signature?

0,
0,
true,
false,
Contributor

logic: Critical: Hardcoded parameter instead of proper API parameter

Similar to the forward pass, this hardcoded false value should be replaced with a proper bottom_right_diagonal parameter passed from the function signature.

Issue: The function fused_attn_fp8_bwd_impl_v1 should accept bottom_right_diagonal as a parameter and pass it here. The hardcoded false suggests different behavior for backward pass, but this should still be configurable. Should this hardcoded false be replaced with a bottom_right_diagonal parameter, and why is it different from the forward pass value?

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

Contributor

@greptile-apps greptile-apps bot left a comment

17 files reviewed, 4 comments


"Disabling FlashAttention as it only supports sliding window with bottom right"
" diagonal alignment for cross-attention"
)
use_flash_attention = False
Contributor

logic: Wrong variable name - should be use_flash_attention_2 = False not use_flash_attention = False. This causes all FlashAttention backends to be incorrectly disabled for sliding window with cross-attention when only FlashAttention 2 should be disabled.

Suggested change
use_flash_attention = False
use_flash_attention_2 = False

"Disabling FlashAttention as it only supports ALiBi with bottom right diagonal"
" alignment for cross-attention"
)
use_flash_attention = False
Contributor

logic: Wrong variable name - should be use_flash_attention_2 = False not use_flash_attention = False. This causes all FlashAttention backends to be incorrectly disabled for ALiBi with cross-attention when only FlashAttention 2 should be disabled.

Suggested change
use_flash_attention = False
use_flash_attention_2 = False

NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
Contributor

logic: Hardcoded bottom_right_diagonal value bypasses the parameter passing mechanism. The FP8 forward function needs to accept bottom_right_diagonal as a parameter and pass it here instead of hardcoding true. Should FP8 attention support configurable bottom_right_diagonal, or is there a technical reason it must always be true?

NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
Contributor

logic: Hardcoded bottom_right_diagonal value in backward pass. Should accept this as a parameter and pass it through instead of hardcoding true.


@cyanguwa
Collaborator

You may need to add the extra argument bottom_right_diagonal to the nvte_xxx C API calls in Jax too, otherwise TE-Jax won't build. Have a look at PR 2584 maybe, where the deterministic arg was added, so files such as jax/csrc/extensions/attention.cpp need to be changed as well. Thanks!

@sudhakarsingh27
Collaborator Author

/te-ci L1

Contributor

@greptile-apps greptile-apps bot left a comment

20 files reviewed, 4 comments


"Disabling FlashAttention as it only supports sliding window with bottom right"
" diagonal alignment for cross-attention"
)
use_flash_attention = False
Contributor

syntax: Wrong variable - should be use_flash_attention_2 = False not use_flash_attention = False. This disables all FlashAttention backends when only FlashAttention 2 should be disabled for sliding window with cross-attention.

Suggested change
use_flash_attention = False
use_flash_attention_2 = False

"Disabling FlashAttention as it only supports ALiBi with bottom right diagonal"
" alignment for cross-attention"
)
use_flash_attention = False
Contributor

syntax: Wrong variable - should be use_flash_attention_2 = False not use_flash_attention = False. This disables all FlashAttention backends when only FlashAttention 2 should be disabled for ALiBi with cross-attention.

Suggested change
use_flash_attention = False
use_flash_attention_2 = False

Comment on lines 1709 to +1710
true,
true,
Contributor

logic: Hardcoded bottom_right_diagonal to true bypasses the configurable parameter mechanism. The fused_attn_fp8_fwd_impl_v1 function should accept bottom_right_diagonal as a parameter and pass it here instead of hardcoding. Is there a technical reason FP8 attention must always use bottom_right_diagonal=true, or should this be configurable?

NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
Contributor

logic: Hardcoded bottom_right_diagonal to true in backward pass. Should accept this as a parameter for consistency with the forward pass configuration.


@sudhakarsingh27
Collaborator Author

/te-ci L1
