
Conversation

@tdophung
Collaborator

Description

To enable expert parallelism in MoE, we need to perform a ragged all-to-all after permutation to redistribute the permuted tokens (grouped by expert) across GPUs so that each GPU only stores a subset of the experts. A ragged all-to-all is used because the number of tokens per expert usually differs from expert to expert. To perform this ragged all-to-all operation, we need to supply the number of tokens per expert (tokens_per_expert), often called group sizes in MaxText.

In this PR, we compute these group sizes as part of the permutation operation and return them by default.

#2585
#2536
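For illustration, here is a minimal sketch (shapes and names are illustrative assumptions, not the TransformerEngine API) of how the group sizes fall out of a boolean routing map, and why they are ragged:

```python
# Minimal sketch, assuming a boolean routing map of shape [num_tokens, num_experts].
# Names are illustrative; this is not the TransformerEngine API.
import jax.numpy as jnp

routing_map = jnp.array(
    [[1, 0, 0, 0],
     [1, 0, 0, 0],
     [0, 1, 0, 0],
     [0, 0, 0, 1]],
    dtype=jnp.bool_,
)  # 4 tokens routed across 4 experts

# Summing each expert column yields tokens_per_expert (the "group sizes"):
tokens_per_expert = jnp.sum(routing_map, axis=0, dtype=jnp.int32)
# -> [2, 1, 0, 1]: unequal counts, which is why a *ragged* all-to-all is needed.

# An exclusive cumulative sum of the counts gives per-expert offsets into the
# permuted (expert-grouped) token buffer; the counts and offsets are the size
# arguments that a ragged all-to-all consumes.
send_offsets = jnp.concatenate(
    [jnp.zeros((1,), jnp.int32), jnp.cumsum(tokens_per_expert)[:-1]]
)
```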

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Return tokens_per_expert by default, computed by summing the expert columns of the routing map.
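As a hedged reference for what this contract looks like from the caller's side (the function below is an illustrative stand-in, not the actual transformer_engine.jax implementation or signature):

```python
# Illustrative stand-in only; not the real TE kernel or its signature.
import jax.numpy as jnp

def permute_with_group_sizes(tokens, routing_map):
    """Group tokens by expert and return the per-expert group sizes by default."""
    # Group sizes: the sum of each expert column of the routing map.
    tokens_per_expert = jnp.sum(routing_map, axis=0, dtype=jnp.int32)
    # Top-1 routing assumed for brevity: sort tokens by their assigned expert id.
    expert_of_token = jnp.argmax(routing_map, axis=1)
    order = jnp.argsort(expert_of_token)
    return tokens[order], order, tokens_per_expert

tokens = jnp.arange(8.0).reshape(4, 2)
routing_map = jnp.eye(4, dtype=jnp.bool_)[jnp.array([1, 0, 1, 3])]
permuted, order, tokens_per_expert = permute_with_group_sizes(tokens, routing_map)
# tokens_per_expert -> [1, 2, 0, 1]
```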

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

tdophung and others added 16 commits January 7, 2026 10:00
Signed-off-by: tdophung <tdophung@nvidia.com>
…ging_probs booleans. Implement partitioning for all permutation primitives

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…ernel zero initialize output permuted scales, permuted probs and output tokens

Signed-off-by: tdophung <tdophung@nvidia.com>
…tead, add extra input (aliased with output) buffer to inner primitive of permutation on jax side to pass in zero-initialized buffers done with jnp.zeros

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…s in utils

Signed-off-by: tdophung <tdophung@nvidia.com>
…/TransformerEngine into custom_partitioning_permutation
Signed-off-by: tdophung <tdophung@nvidia.com>
…/TransformerEngine into custom_partitioning_permutation
Signed-off-by: tdophung <tdophung@nvidia.com>
…/TransformerEngine into custom_partitioning_permutation
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung changed the title from "Add group size permutation" to "[Draft] Add group size permutation" on Jan 21, 2026
@tdophung tdophung marked this pull request as draft January 21, 2026 19:50
Signed-off-by: tdophung <tdophung@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Jan 21, 2026

Greptile Summary

  • Modifies the MoE permutation API to always return tokens_per_expert counts, enabling expert parallelism for distributed Mixture of Experts models
  • Adds comprehensive distributed sharding support to the JAX permutation operations, including new partition methods and consolidated primitives
  • Implements JAX input_output_aliases support across the PyTorch and common Triton kernels by adding buffer parameters for cross-framework compatibility (a hedged sketch of this aliasing pattern follows this list)
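For readers unfamiliar with the aliasing pattern mentioned above, a minimal, hedged sketch from the caller's side. It uses jax.jit argument donation as an analogue; the PR itself wires input_output_aliases into the inner primitive, and every name below is illustrative rather than the TE API:

```python
# Hedged sketch only: NOT the TE primitive. It shows the pattern of pre-allocating
# a zero-initialized output buffer with jnp.zeros and handing it to the computation
# so its memory may be reused (aliased) for the result. At the primitive level,
# input_output_aliases expresses the same input/output aliasing contract.
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=(1,))
def scatter_permuted(tokens, out_buf, order):
    # Write each token at its permuted destination row; rows that receive no
    # token stay zero, which is why the caller must zero-initialize the buffer.
    return out_buf.at[order].set(tokens)

tokens = jnp.arange(12.0).reshape(4, 3)
order = jnp.array([2, 0, 3, 1])
permuted = scatter_permuted(tokens, jnp.zeros((6, 3), tokens.dtype), order)
```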

Important Files Changed

  • transformer_engine/jax/permutation.py: critical undefined-variable bug on line 171 (out_tokens_per_expert used before assignment); changes the return type to always provide tokens_per_expert counts
  • transformer_engine/jax/triton_extensions/permutation.py: major architectural changes adding distributed sharding methods and consolidating unpermute primitives
  • transformer_engine/jax/triton_extensions/utils.py: works around a memory corruption bug in jaxlib autotuning by removing input_output_aliases

Confidence score: 2/5

  • This PR contains a critical runtime bug that will cause immediate failures when the modified code is executed
  • Score lowered due to the undefined variable out_tokens_per_expert on line 171 of transformer_engine/jax/permutation.py, which will raise a NameError
  • Pay close attention to the variable naming issue in the main permutation file; it needs to be fixed before merge


@greptile-apps greptile-apps bot left a comment


6 files reviewed, 1 comment


```diff
 )
 )
-return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
+return output, permuted_probs, row_id_map, pad_offsets, out_tokens_per_expert
```

syntax: out_tokens_per_expert is undefined - should be target_tokens_per_expert

Suggested change

```diff
-return output, permuted_probs, row_id_map, pad_offsets, out_tokens_per_expert
+return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
```

@tdophung tdophung closed this Jan 21, 2026
@tdophung
Collaborator Author

Opened this on the wrong branch. Closing.

