[Metal] Add Flash Attention VJP for training #2995
Conversation
Notes: I'm working on https://github.com/mlx-node/mlx-node and trying to port some features from trl.
Force-pushed from 568ff36 to 26b5857.
Force-pushed from dd8daf1 to 5c78507.
zcbenz left a comment:
Can you share some benchmarking numbers?
Force-pushed from 5c78507 to 9008d2e.
Benchmark Summary for PR - key observations:
I ran the benchmark script on an M4 Air and it seems that the fused VJP is slower for most cases:
$ python benchmarks/python/sdpa_vjp_bench.py
SDPA VJP Benchmark - dtype=float16
=====================================================================================
[Forward + Backward (VJP)]
B H_q H_kv L D | unfused fused speedup path
-------------------------------------------------------------------------------------
2 8 8 1 64 | 0.07ms 0.07ms 1.09x vector
2 8 8 4 64 | 0.07ms 0.07ms 1.03x vector
2 8 8 8 64 | 0.07ms 0.07ms 1.05x vector
2 8 8 8 128 | 0.08ms 0.08ms 0.99x vector
2 8 8 32 64 | 0.10ms 0.11ms 0.89x STEEL
2 8 8 64 64 | 0.18ms 0.21ms 0.84x STEEL
2 8 8 128 64 | 0.45ms 0.54ms 0.82x STEEL
2 8 8 128 128 | 0.59ms 0.69ms 0.85x STEEL
2 8 8 256 128 | 1.93ms 2.32ms 0.83x STEEL
1 32 8 512 64 | 11.45ms 14.91ms 0.77x STEEL
1 32 8 512 128 | 14.26ms 17.73ms 0.80x STEEL
1 32 8 1024 64 | 50.49ms 64.39ms 0.78x STEEL
1 32 8 1024 128 | 56.06ms 69.96ms 0.80x STEEL
1 32 8 2048 128 | 244.84ms 299.90ms 0.82x STEEL
2 32 8 256 64 | 6.16ms 8.06ms 0.76x STEEL
2 32 4 256 64 | 6.23ms 8.02ms 0.78x STEEL
Interesting, there is a slight difference in the results between the M4 Air and my MacBook Pro 16 with the M3 Max and 128 GB of RAM.
The primary benefit of Flash Attention VJP is memory, not speed: for long sequences (4K+), the memory savings can be critical even if the speed is similar or slightly slower.
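To put a rough number on that, here is a minimal sketch (not the PR's benchmark script) that estimates the size of the materialized attention matrix and measures peak memory for a backward pass through the fused op versus an unfused reference. It assumes the peak-memory helpers live at mx.metal.get_peak_memory / mx.metal.reset_peak_memory; newer MLX releases expose them as mx.get_peak_memory / mx.reset_peak_memory.

```python
# Hypothetical memory sketch, not the PR's benchmark script. Estimates the
# materialized attention matrix of unfused attention and measures peak memory
# of a backward pass through the fused op vs. an unfused reference.
import mlx.core as mx

B, H, L, D = 1, 32, 4096, 128                 # reduce L on low-memory machines
scale = D ** -0.5

# Unfused attention keeps a (B, H, L, L) float16 score matrix alive:
print(f"score matrix: {B * H * L * L * 2 / 2**20:.0f} MiB")   # 1024 MiB at L=4096

def unfused(q, k, v):
    s = (q * scale) @ k.transpose(0, 1, 3, 2)
    return mx.softmax(s, axis=-1) @ v

def fused(q, k, v):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

q, k, v = (mx.random.normal((B, H, L, D), dtype=mx.float16) for _ in range(3))

for name, fn in (("unfused", unfused), ("fused", fused)):
    mx.metal.reset_peak_memory()              # mx.reset_peak_memory() on newer MLX
    grads = mx.grad(lambda q, k, v: fn(q, k, v).sum(), argnums=(0, 1, 2))(q, k, v)
    mx.eval(grads)
    print(name, f"peak: {mx.metal.get_peak_memory() / 2**20:.0f} MiB")
```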
Force-pushed from 5bba7f2 to 21caa99.
I tested on an M3 Max and the results are still mixed:
$ python benchmarks/python/sdpa_vjp_bench.py
SDPA VJP Benchmark - dtype=float16
=====================================================================================
[Forward + Backward (VJP)]
B H_q H_kv L D | unfused fused speedup path
-------------------------------------------------------------------------------------
2 8 8 1 64 | 0.30ms 0.26ms 1.17x vector
2 8 8 4 64 | 0.34ms 0.24ms 1.43x vector
2 8 8 8 64 | 0.28ms 0.25ms 1.14x vector
2 8 8 8 128 | 0.29ms 0.33ms 0.88x vector
2 8 8 32 64 | 0.33ms 0.36ms 0.92x STEEL
2 8 8 64 64 | 0.43ms 0.28ms 1.55x STEEL
2 8 8 128 64 | 0.29ms 0.33ms 0.86x STEEL
2 8 8 128 128 | 0.31ms 0.31ms 0.98x STEEL
2 8 8 256 128 | 0.39ms 0.40ms 0.97x STEEL
1 32 8 512 64 | 1.02ms 1.01ms 1.02x STEEL
1 32 8 512 128 | 1.30ms 1.34ms 0.97x STEEL
1 32 8 1024 64 | 3.75ms 3.75ms 1.00x STEEL
1 32 8 1024 128 | 4.81ms 4.81ms 1.00x STEEL
1 32 8 2048 128 | 16.65ms 16.66ms 1.00x STEEL
2 32 8 256 64 | 0.63ms 0.66ms 0.95x STEEL
2 32 4 256 64 | 0.65ms 0.67ms 0.97x STEEL

I'm not really familiar with how VJP works for flash attention, but generally speaking reduced memory usage should make the op faster when it is memory-bound, which is usually the case on Mac, so the fused op being slower likely means something is off. Anyway, the result on the vector path looks promising; maybe separate that part into an independent PR first? It would also make reviewing much easier.
Implement fused backward pass for scaled_dot_product_attention on short sequences (L ≤ 8) using the vector kernel approach. This eliminates the O(N²) memory requirement of unfused attention by recomputing the attention matrix on the fly during backpropagation.

Key changes:
- Add sdpa_vector_vjp.h with GPU kernels for computing dQ, dK, dV
- Extend the forward pass to output logsumexp (LSE) when needed for VJP
- Add comprehensive Python tests for gradient correctness
- Fix CUDA cuDNN backward to handle masks via set_bias() (removes an unnecessary fallback)

Performance (M3 Max, L ≤ 8):
- 1.1-1.4x faster than unfused attention for the backward pass
- Memory: O(N) instead of O(N²) for the attention matrix

The STEEL VJP for longer sequences (L > 8) will be added in a follow-up PR.
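For context, the quantities the VJP recomputes from the cached LSE are the standard SDPA gradients. The following is a plain MLX reference sketch of that math (assuming H_q == H_kv and no mask), not the kernel code in this PR:

```python
# Reference SDPA VJP recomputing softmax from the cached logsumexp (LSE).
# A sketch of the math only, not the Metal kernel. Shapes: q, k, v, o, do
# are (B, H, L, D); lse is (B, H, L). Assumes H_q == H_kv and no mask.
import mlx.core as mx

def sdpa_vjp_reference(q, k, v, o, do, lse, scale):
    s = (q @ k.transpose(0, 1, 3, 2)) * scale            # scores, (B, H, L, L)
    p = mx.exp(s - mx.expand_dims(lse, -1))               # softmax rebuilt from LSE
    dv = p.transpose(0, 1, 3, 2) @ do                     # dV = P^T dO
    dp = do @ v.transpose(0, 1, 3, 2)                     # dP = dO V^T
    delta = (do * o).sum(axis=-1, keepdims=True)          # rowsum(dO * O)
    ds = p * (dp - delta)                                 # softmax backward
    dq = (ds @ k) * scale                                 # dQ = dS K * scale
    dk = (ds.transpose(0, 1, 3, 2) @ q) * scale           # dK = dS^T Q * scale
    return dq, dk, dv
```

Because P can be rebuilt from the LSE, the fused kernels never need to keep the (B, H, L, L) matrix between forward and backward; it is rematerialized tile by tile.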
Force-pushed from 21caa99 to 4875a9d.
Done, I only preserved the vector path.
zcbenz left a comment:
I'm not familiar with the Metal implementation so we will need other maintainers to review the code.
int /* n_kv_heads */) {
  // Force unfused attention when masks/sinks present
  if (has_mask || has_sinks) {
    return true;
This can be removed.
Summary

Implements fused backward pass (VJP) for scaled_dot_product_attention on Metal GPU. This enables efficient gradient computation during training without falling back to unfused (decomposed) attention operations.

Changes

New Files
- mlx/backend/metal/kernels/sdpa_vector_vjp.h - Vector VJP kernel for short sequences (L ≤ 8)
- mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dq.h - STEEL dQ gradient kernel
- mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_vjp_dkv.h - STEEL dK/dV gradient kernel

Modified Files
- mlx/backend/metal/scaled_dot_product_attention.cpp - VJP dispatch logic (+840 lines)
- mlx/fast.cpp / mlx/fast_primitives.h - Logsumexp caching, VJP routing
- python/tests/test_fast_sdpa.py - Comprehensive VJP tests (+220 lines)

Implementation Notes
Uses a two-kernel approach to avoid atomic operations (a sketch of the general pattern follows the list):
- dQ kernel (steel_attention_vjp_dq.h)
- dK/dV kernel (steel_attention_vjp_dkv.h)
- Vector VJP (sdpa_vector_vjp.h)
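For reviewers, here is why the split avoids atomics in the usual flash-attention-backward pattern: each output tile has exactly one writer. The dQ pass parallelizes over query blocks and loops over key blocks; the dK/dV pass parallelizes over key blocks and loops over query blocks. A Python-level sketch of that structure (not the actual STEEL kernels):

```python
# Illustrative two-pass blocked SDPA VJP; a sketch of the general pattern,
# not the STEEL kernels in this PR. Single head, no GQA or mask.
# Shapes: q, k, v, o, do are (L, D); lse is (L,).
import mlx.core as mx

def blocked_sdpa_vjp(q, k, v, o, do, lse, scale, bq=32, bk=32):
    L, D = q.shape
    lse = lse.reshape(-1, 1)                               # (L, 1)
    delta = (do * o).sum(axis=-1, keepdims=True)           # rowsum(dO * O), (L, 1)

    # Pass 1 (dQ kernel): each "threadgroup" owns one query block and loops
    # over all key blocks, accumulating into its private dQ tile.
    dq_blocks = []
    for i in range(0, L, bq):
        qi, doi, lsei, di = q[i:i+bq], do[i:i+bq], lse[i:i+bq], delta[i:i+bq]
        acc = mx.zeros((qi.shape[0], D))
        for j in range(0, L, bk):
            kj, vj = k[j:j+bk], v[j:j+bk]
            p = mx.exp(qi @ kj.T * scale - lsei)           # recomputed softmax tile
            ds = p * (doi @ vj.T - di)
            acc = acc + ds @ kj * scale
        dq_blocks.append(acc)                              # sole writer: no atomics

    # Pass 2 (dK/dV kernel): each "threadgroup" owns one key block and loops
    # over all query blocks, accumulating into its private dK and dV tiles.
    dk_blocks, dv_blocks = [], []
    for j in range(0, L, bk):
        kj, vj = k[j:j+bk], v[j:j+bk]
        acc_k = mx.zeros((kj.shape[0], D))
        acc_v = mx.zeros((vj.shape[0], D))
        for i in range(0, L, bq):
            qi, doi, lsei, di = q[i:i+bq], do[i:i+bq], lse[i:i+bq], delta[i:i+bq]
            p = mx.exp(qi @ kj.T * scale - lsei)
            acc_v = acc_v + p.T @ doi
            ds = p * (doi @ vj.T - di)
            acc_k = acc_k + ds.T @ qi * scale
        dk_blocks.append(acc_k)                            # sole writer: no atomics
        dv_blocks.append(acc_v)
    return (mx.concatenate(dq_blocks), mx.concatenate(dk_blocks),
            mx.concatenate(dv_blocks))
```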
Key Features

Limitations
Test Plan
- test_sdpa_grad passes
- test_sdpa_grad_vector_path - short sequences (L=1, 4, 7, 8)
- test_sdpa_grad_steel_path - longer sequences (L=16, 32, 128, 256)
- test_sdpa_grad_head_dims - head dimensions (D=32, 64, 96, 128)
- test_sdpa_grad_gqa - GQA configurations (4:1, 8:1, 16:1, MHA)
- test_sdpa_grad_dtypes - float16, bfloat16, float32
- test_sdpa_grad_edge_cases - L=1, non-power-of-2, large batch, qL≠kvL

All 21 SDPA tests pass (1 skipped for an unrelated disabled feature).
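For illustration, a gradient check in the spirit of these tests compares mx.grad through the fused op against an unfused reference; a minimal sketch (not the actual test code in python/tests/test_fast_sdpa.py):

```python
# Minimal VJP correctness check: fused vs. unfused SDPA gradients.
# A sketch only, not the tests added in this PR.
import mlx.core as mx

B, H, L, D = 2, 8, 8, 64                     # vector-path shape (L <= 8), H_q == H_kv
scale = D ** -0.5
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
cot = mx.random.normal((B, H, L, D))         # cotangent used to form a scalar loss

def fused(q, k, v):
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

def unfused(q, k, v):
    s = (q * scale) @ k.transpose(0, 1, 3, 2)
    return mx.softmax(s, axis=-1) @ v

g_fused = mx.grad(lambda q, k, v: (fused(q, k, v) * cot).sum(), argnums=(0, 1, 2))(q, k, v)
g_ref = mx.grad(lambda q, k, v: (unfused(q, k, v) * cot).sum(), argnums=(0, 1, 2))(q, k, v)

for name, gf, gr in zip(("dq", "dk", "dv"), g_fused, g_ref):
    assert mx.allclose(gf, gr, rtol=1e-4, atol=1e-4), name
```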