Conversation

@khatwanimohit (Collaborator) commented Jan 15, 2026

Description

This work was done in collaboration with @richjames0 and @NicoGrande

This PR introduces attention data parallelism (attn_dp) to improve memory efficiency when the number of KV heads is smaller than the tensor-parallelism degree. The attention DP degree is auto-calculated from the ratio of tensor parallelism to KV heads, so optimal sharding requires no manual configuration.
New logical axes (attn_activation_length, attn_activation_embed) and corresponding sharding rules have been added to support attention-specific tensor partitioning separate from the rest of the model.
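
As a reference for how the auto-calculation above behaves, here is a minimal sketch (the function name and signature are illustrative, not the actual MaxText API):

```python
def compute_attn_dp_degree(ici_tensor_parallelism: int, num_kv_heads: int) -> int:
  """Illustrative sketch: attn_dp absorbs the part of tensor parallelism exceeding the KV heads.

  E.g. Llama3.1-8B with num_kv_heads=2 and ici_tensor_parallelism=8 gives attn_dp=4, model=2.
  """
  if ici_tensor_parallelism <= num_kv_heads:
    # Enough KV heads to shard across the full TP degree; no attention DP needed.
    return 1
  assert ici_tensor_parallelism % num_kv_heads == 0, "TP degree must be divisible by num_kv_heads"
  return ici_tensor_parallelism // num_kv_heads
```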

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Ran Llama3.1-8B with num_kv_heads reduced to 2 and ici_tensor_parallelism=8. This auto-calculates attn_dp and sets it to 4.
vllm mesh: mesh=Mesh('data': 1, 'attn_dp': 4, 'expert': 1, 'model': 2, axis_types=(Auto, Auto, Auto, Auto))
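
For orientation, a mesh with these axis sizes could be constructed in JAX roughly as follows (an illustrative sketch; this is not the code path vLLM/MaxText actually uses to build the mesh):

```python
import numpy as np
import jax
from jax.sharding import Mesh

# 1 (data) x 4 (attn_dp) x 1 (expert) x 2 (model) = 8 devices, matching the logged axis sizes.
devices = np.asarray(jax.devices()[:8]).reshape(1, 4, 1, 2)
mesh = Mesh(devices, axis_names=("data", "attn_dp", "expert", "model"))
```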

Attention shardings:

llama2.py:158] bfloat16[96,1,4096]............................................................. (None, None, ('model', 'attn_dp')).
attentions.py:1014] bfloat16[96,1,4096]............................................................. (None, 'attn_dp', 'model').
attentions.py:1078] bfloat16[96,1,32,128]........................................................... (None, 'attn_dp', 'model', None).
attentions.py:1079] bfloat16[96,1,8,128]............................................................ (None, 'attn_dp', 'model', None).
linears.py:527] bfloat16[96,1,14336]............................................................ (None, None, ('model', 'attn_dp')).

Weight shardings:

maxtext_utils.py:1197] decoder/decoder_norm/scale/value................................................ bfloat16[4096] (('model', 'attn_dp'),)
maxtext_utils.py:1197] decoder/layers_0/mlp/wi_0/kernel/value.......................................... bfloat16[4096,14336] (None, ('model', 'attn_dp'))
maxtext_utils.py:1197] decoder/layers_0/mlp/wi_1/kernel/value.......................................... bfloat16[4096,14336] (None, ('model', 'attn_dp'))
maxtext_utils.py:1197] decoder/layers_0/mlp/wo/kernel/value............................................ bfloat16[14336,4096] (('model', 'attn_dp'), None)
maxtext_utils.py:1197] decoder/layers_0/post_self_attention_layer_norm/scale/value..................... bfloat16[4096] (('model', 'attn_dp'),)
maxtext_utils.py:1197] decoder/layers_0/pre_self_attention_layer_norm/scale/value...................... bfloat16[4096] (('model', 'attn_dp'),)
maxtext_utils.py:1197] decoder/layers_0/self_attention/key/kernel/value................................ bfloat16[4096,8,128] (None, 'model', None)
maxtext_utils.py:1197] decoder/layers_0/self_attention/out/kernel/value................................ bfloat16[32,128,4096] ('model', None, None)
maxtext_utils.py:1197] decoder/layers_0/self_attention/query/kernel/value.............................. bfloat16[4096,32,128] (None, 'model', None)
maxtext_utils.py:1197] decoder/layers_0/self_attention/value/kernel/value.............................. bfloat16[4096,8,128] (None, 'model', None)
decoder/logits_dense/kernel/value............................................... bfloat16[4096,128256] (None, ('model', 'attn_dp'))
token_embedder/embedding/value.................................................. bfloat16[128256,4096] (('model', 'attn_dp'), None)
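
Read as JAX partition specs on a mesh like the one sketched above, these logged shardings correspond to something like the following (illustrative only; the variable names are not MaxText's):

```python
from jax.sharding import NamedSharding, PartitionSpec as P

# MLP and embedding weights shard their embed dimension over both 'model' and 'attn_dp',
# e.g. token_embedder/embedding bfloat16[128256,4096] with spec (('model', 'attn_dp'), None).
embedding_sharding = NamedSharding(mesh, P(("model", "attn_dp"), None))

# Attention K/V projection kernels shard their KV-head dimension over 'model' only,
# e.g. self_attention/key/kernel bfloat16[4096,8,128] with spec (None, 'model', None).
kv_kernel_sharding = NamedSharding(mesh, P(None, "model", None))
```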

Ran Qwen3-30B MoE with tp=8; it has 4 KV heads, so model=4 and attn_dp=2.
Command:

NEW_MODEL_DESIGN=1 python3 -m MaxText.vllm_decode     --model_name qwen3-30b-a3b  --hf_model_name Qwen/Qwen3-30B-A3B     --hf_config_path src/MaxText/integration/vllm/maxtext_vllm_adapter  --ici_tensor_parallelism 8  --gpu_memory_utilization 0.5  --prompt "Suggest some famous landmarks in London" --debug_sharding true --enable_dp_attention true --load_parameters_path gs://parambole-qwen3-moe-verification/unscanned/qwen3-30b-a3b-thinking-2507/14_08_2025/0/items 2>&1 | tee attn_output

Output: https://paste.googleplex.com/5980793354715136

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Jan 15, 2026

Codecov Report

❌ Patch coverage is 30.43478% with 16 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/MaxText/layers/moe.py | 0.00% | 5 Missing and 1 partial ⚠️
src/MaxText/vllm_decode.py | 0.00% | 5 Missing ⚠️
src/MaxText/model_creation_utils.py | 25.00% | 3 Missing ⚠️
src/MaxText/train_compile.py | 0.00% | 1 Missing ⚠️
src/MaxText/train_utils.py | 0.00% | 1 Missing ⚠️


@khatwanimohit force-pushed the mohit/attn_dp branch 3 times, most recently from f54a2f8 to 6417a4b on January 20, 2026 15:18
@NicoGrande (Collaborator) left a comment

LGTM! Awesome work!

@NuojCheng (Collaborator) left a comment

It is surprising to see

attentions.py:1014] bfloat16[96,1,4096]............................................................. (None, 'attn_dp', 'model').

with 'attn_dp' = 4. It should produce an error instead. Is it a typo?

@khatwanimohit (Collaborator, Author) commented

> It is surprising to see
>
> attentions.py:1014] bfloat16[96,1,4096]............................................................. (None, 'attn_dp', 'model').
>
> with 'attn_dp' = 4. It should produce an error instead. Is it a typo?

I just re-ran Llama3.1-8B with 2 KV heads and here are the updated shardings. You are right, this should have errored out; I'm not sure what happened earlier. In that earlier run we were missing an env variable (NEW_MODEL_DESIGN=1) in vLLM, which is responsible for enabling the new shardings that include attn_dp.

@NuojCheng (Collaborator) left a comment

Thank you Mohit! LGTM

copybara-service bot merged commit 3c8e0fa into main on Jan 21, 2026 (29 of 31 checks passed).
copybara-service bot deleted the mohit/attn_dp branch on January 21, 2026 19:43.