Support attention data parallelism #2955
Conversation
NicoGrande
left a comment
LGTM! Awesome work!
NuojCheng
left a comment
It is surprising to see
attentions.py:1014] bfloat16[96,1,4096]............................................................. (None, 'attn_dp', 'model').
with 'attn_dp' = 4. A size-1 dimension cannot be split 4 ways, so this should produce an error instead. Is it a typo?
I just re-ran Llama3.1-8B with 2 KV heads and here are the updated shardings. You are right, this should have errored out; I'm not sure what happened earlier. I remember that in the earlier run we were missing an env variable.
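For reference, here is a minimal JAX-only sketch (not MaxText code) of why a spec like (None, 'attn_dp', 'model') on a bfloat16[96,1,4096] activation must fail when attn_dp=4: the size-1 middle dimension cannot be split four ways. The 8-device mesh below is an assumption for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 8 devices are visible, e.g. run on CPU with
# XLA_FLAGS=--xla_force_host_platform_device_count=8 for a quick check.
devices = np.array(jax.devices()[:8]).reshape(4, 2)
mesh = Mesh(devices, axis_names=("attn_dp", "model"))

x = jnp.zeros((96, 1, 4096), dtype=jnp.bfloat16)
try:
    jax.device_put(x, NamedSharding(mesh, P(None, "attn_dp", "model")))
except ValueError as e:
    # JAX rejects this: dimension 1 (size 1) is not divisible by the
    # 4-way 'attn_dp' mesh axis.
    print("sharding rejected:", e)
```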
NuojCheng
left a comment
Thank you Mohit! LGTM
Description
This work was done in collaboration with @richjames0 and @NicoGrande.
This PR introduces attention data parallelism (attn_dp) to optimize memory efficiency when the number of KV heads is less than tensor parallelism. The attention DP degree is auto-calculated based on the ratio of tensor parallelism to KV heads, ensuring optimal sharding without manual configuration.
New logical axes (attn_activation_length, attn_activation_embed) and corresponding sharding rules have been added to support attention-specific tensor partitioning separate from the rest of the model.
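As a rough sketch of the auto-calculation described above (not the actual MaxText code; the function name and signature here are hypothetical):

```python
def compute_attn_dp(ici_tensor_parallelism: int, num_kv_heads: int) -> int:
    """Hypothetical sketch: derive the attention-DP degree from the ratio of
    tensor parallelism to KV heads, so attn_dp * num_kv_heads covers the full
    tensor-parallel slice while each KV head lands on exactly one shard."""
    if num_kv_heads >= ici_tensor_parallelism:
        return 1  # enough KV heads to shard across every tensor-parallel device
    assert ici_tensor_parallelism % num_kv_heads == 0, (
        "tensor parallelism must be a multiple of num_kv_heads"
    )
    return ici_tensor_parallelism // num_kv_heads


# Matches the test runs below:
print(compute_attn_dp(8, 2))  # Llama3.1-8B with 2 KV heads -> attn_dp = 4
print(compute_attn_dp(8, 4))  # Qwen3-30B MoE with 4 KV heads -> attn_dp = 2
```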
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.
Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
Tests
Ran Llama3.1-8B with num_kv_heads reduced to 2 and ici_tensor_parallelism=8. This auto-calculates attn_dp and sets it to 4.
vllm mesh:
mesh=Mesh('data': 1, 'attn_dp': 4, 'expert': 1, 'model': 2, axis_types=(Auto, Auto, Auto, Auto))
Attention shardings:
Weight shardings:
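For context, a minimal JAX sketch (not MaxText code) of building a mesh with these axis sizes and applying one plausible attention-activation sharding; the shapes, device layout, and PartitionSpec are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 8 devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=8.
devices = np.array(jax.devices()[:8]).reshape(1, 4, 1, 2)
mesh = Mesh(devices, axis_names=("data", "attn_dp", "expert", "model"))

# A [batch, length, embed] attention activation: batch split over attn_dp,
# embed split over model, length replicated.
x = jnp.zeros((8, 1024, 4096), dtype=jnp.bfloat16)
sharded = jax.device_put(x, NamedSharding(mesh, P("attn_dp", None, "model")))
print(sharded.sharding)
```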
Ran Qwen3-30b-moe with tp=8; it has 4 KV heads, so model=4 and attn_dp=2.
Command:
Output: https://paste.googleplex.com/5980793354715136
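By the same factoring, the Qwen3 run's 8-device slice splits as data=1, attn_dp=2, expert=1, model=4. A tiny hedged sketch (device layout assumed; the actual mesh dump is in the linked output):

```python
import jax
import numpy as np
from jax.sharding import Mesh

# Hypothetical factoring for the Qwen3 run: model matches the 4 KV heads,
# and the remaining tensor-parallel degree becomes attn_dp = 8 // 4 = 2.
devices = np.array(jax.devices()[:8]).reshape(1, 2, 1, 4)
mesh = Mesh(devices, axis_names=("data", "attn_dp", "expert", "model"))
print(mesh.shape)  # sizes per axis: data=1, attn_dp=2, expert=1, model=4
```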
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.