3 changes: 3 additions & 0 deletions README.md
@@ -255,6 +255,9 @@ After installation completes, run the training script.
- In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, and the ici_tensor_parallelism axis is used for head parallelism.
- You can enable both, keeping in mind that Wan2.1 has 40 attention heads, so ici_tensor_parallelism must evenly divide 40.
- For sequence parallelism, the code pads the sequence length so it divides evenly across devices. Try different ici_fsdp_parallelism values; we currently find 2 and 4 to work best.
- On GPU, enabling the cudnn_te_flash attention kernel is recommended for optimal performance.
- Best performance is achieved with batch parallelism, which can be enabled via the ici_fsdp_batch_parallelism axis. Note that this parallelism strategy does not support fractional batch sizes.
- ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow fractional batch sizes. However, padding is not currently supported by the cudnn_te_flash attention kernel, so the sequence length must be divisible by the number of devices in the ici_fsdp_parallelism axis (a sketch of these constraints follows this file's diff).

You should eventually see a training run like the following:

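The padding and divisibility rules in the README notes above can be sketched as follows. This is an illustrative Python sketch, not maxdiffusion code: the helper name, the kernel-name string check, and the example sequence length are assumptions.

```python
# Illustrative sketch of the sequence-parallel padding rules above; not maxdiffusion's API.
import math


def per_device_seq_len(seq_len: int, seq_shards: int, attention_kernel: str) -> int:
  """Return the per-device sequence length after sharding over `seq_shards` devices."""
  if seq_len % seq_shards == 0:
    return seq_len // seq_shards
  if attention_kernel == "cudnn_te_flash":
    # The flash kernel does not support padded sequences, so the length must
    # already divide evenly across the ici_fsdp_parallelism axis.
    raise ValueError(
        f"seq_len={seq_len} must be divisible by ici_fsdp_parallelism={seq_shards} "
        "when using cudnn_te_flash"
    )
  # Otherwise, pad up to the next multiple and shard the padded sequence.
  padded = math.ceil(seq_len / seq_shards) * seq_shards
  return padded // seq_shards


# Example (hypothetical numbers): 75,600 tokens over ici_fsdp_parallelism=4.
print(per_device_seq_len(75600, 4, "flash"))  # 18900
```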
13 changes: 7 additions & 6 deletions src/maxdiffusion/common_types.py
@@ -36,6 +36,7 @@
# Physical axis names for device meshes.
DATA = "data"
FSDP = "fsdp"
CONTEXT = "context"
TENSOR = "tensor"
# Logical axis names for model parameters and activations.
BATCH = "activation_batch"
@@ -67,18 +68,18 @@
### Common axis rules for ring attention ###
RING_ATTENTION_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_KV_LENGTH, FSDP],
[SELF_ATTN_Q_LENGTH, CONTEXT],
[SELF_ATTN_KV_LENGTH, CONTEXT],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_KV_LENGTH, FSDP],
[CROSS_ATTN_Q_LENGTH, CONTEXT],
[CROSS_ATTN_KV_LENGTH, CONTEXT],
]

SEQUENCE_PARALLEL_AXIS_RULES = [
[SELF_ATTN_HEAD, None],
[SELF_ATTN_Q_LENGTH, FSDP],
[SELF_ATTN_Q_LENGTH, CONTEXT],
[SELF_ATTN_KV_LENGTH, None],
[CROSS_ATTN_HEAD, None],
[CROSS_ATTN_Q_LENGTH, FSDP],
[CROSS_ATTN_Q_LENGTH, CONTEXT],
[CROSS_ATTN_KV_LENGTH, None],
]
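A minimal sketch of how rule tables like RING_ATTENTION_AXIS_RULES and SEQUENCE_PARALLEL_AXIS_RULES can be resolved into a PartitionSpec. The constant string values and the resolver function below are assumptions for illustration; maxdiffusion's actual lookup lives in its sharding utilities.

```python
# Illustrative resolver: map each logical activation axis to a mesh axis via the
# matching rule; unmatched axes stay unsharded (None). The string values of the
# logical-axis constants are assumed, not copied from common_types.py.
from jax.sharding import PartitionSpec

CONTEXT = "context"
SELF_ATTN_HEAD = "activation_self_attn_heads"            # assumed value
SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"     # assumed value
SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"   # assumed value

RING_ATTENTION_AXIS_RULES = [
    [SELF_ATTN_HEAD, None],
    [SELF_ATTN_Q_LENGTH, CONTEXT],
    [SELF_ATTN_KV_LENGTH, CONTEXT],
]


def to_partition_spec(logical_axes, rules):
  lookup = {logical: mesh_axis for logical, mesh_axis in rules}
  return PartitionSpec(*(lookup.get(axis) for axis in logical_axes))


# Q activations laid out as (batch, heads, q_length, head_dim):
spec = to_partition_spec((None, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, None), RING_ATTENTION_AXIS_RULES)
print(spec)  # PartitionSpec(None, None, 'context', None)
```

Under SEQUENCE_PARALLEL_AXIS_RULES only the query length would map to 'context', leaving the key/value length replicated; that is the difference between the two tables.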
6 changes: 4 additions & 2 deletions src/maxdiffusion/configs/base14.yml
@@ -106,7 +106,7 @@ skip_jax_distributed_system: False
base_output_directory: ""

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -131,17 +131,19 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_tensor_parallelism: 1

allow_split_physical_axes: False
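The comment block in this config constrains the axis products: the DCN axes multiply to the number of slices and the ICI axes multiply to the number of devices per slice, with -1 acting as a placeholder that absorbs the remainder. A small JAX sketch of that arithmetic follows; how maxdiffusion itself resolves -1 is assumed here, not copied from its code.

```python
# Sketch: resolve the -1 placeholder by hand and build the 4-axis mesh.
import math

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

mesh_axes = ("data", "fsdp", "context", "tensor")
ici = {"data": -1, "fsdp": 1, "context": 1, "tensor": 1}  # values from this config

fixed = math.prod(v for v in ici.values() if v != -1)
resolved = tuple(jax.device_count() // fixed if v == -1 else v for v in ici.values())

# Product of the ICI axes must equal the number of devices in the slice.
assert math.prod(resolved) == jax.device_count()

mesh = Mesh(mesh_utils.create_device_mesh(resolved), axis_names=mesh_axes)
print(dict(mesh.shape))  # e.g. {'data': 8, 'fsdp': 1, 'context': 1, 'tensor': 1} on 8 devices
```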
6 changes: 4 additions & 2 deletions src/maxdiffusion/configs/base21.yml
@@ -108,7 +108,7 @@ skip_jax_distributed_system: False
base_output_directory: ""

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -133,17 +133,19 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_tensor_parallelism: 1

allow_split_physical_axes: False
6 changes: 4 additions & 2 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -121,7 +121,7 @@ skip_jax_distributed_system: False
base_output_directory: ""

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -146,17 +146,19 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_tensor_parallelism: 1

allow_split_physical_axes: False
6 changes: 4 additions & 2 deletions src/maxdiffusion/configs/base_flux_dev.yml
@@ -132,7 +132,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -158,17 +158,19 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_tensor_parallelism: 1

allow_split_physical_axes: False
6 changes: 4 additions & 2 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
@@ -132,7 +132,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -158,17 +158,19 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_tensor_parallelism: 1

allow_split_physical_axes: False
6 changes: 4 additions & 2 deletions src/maxdiffusion/configs/base_flux_schnell.yml
@@ -140,7 +140,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -166,17 +166,19 @@ logical_axis_rules: [
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: -1
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_tensor_parallelism: 1

allow_split_physical_axes: False
26 changes: 14 additions & 12 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -151,7 +151,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -166,31 +166,33 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_self_attn_heads', ['fsdp', 'tensor']],
['activation_cross_attn_q_length', ['fsdp', 'tensor']],
['activation_length', 'fsdp'],
['batch', ['data', 'fsdp']],
['activation_batch', ['data', 'fsdp']],
['activation_self_attn_heads', ['context', 'tensor']],
['activation_cross_attn_q_length', ['context', 'tensor']],
['activation_length', 'context'],
['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['embed', ['context', 'fsdp']],
['heads', 'tensor'],
['norm', 'tensor'],
['conv_batch', ['data','fsdp']],
['conv_batch', ['data', 'context', 'fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_out', 'context'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_fsdp_parallelism: 1
dcn_context_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_parallelism: 1
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False
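In the updated rules above, a logical axis can map to a list of mesh axes (for example ['batch', ['data', 'fsdp']]), which shards that dimension over the product of those axes. Below is a hedged sketch using Flax's logical sharding helper, with an assumed 8-device mesh and only a subset of the rules; it is an illustration, not how maxdiffusion applies the config.

```python
# Sketch of composing mesh axes per logical axis; requires 8 local devices for
# this mesh shape (shrink it to match your hardware). The rule subset mirrors
# the config above; everything else is illustrative.
import flax.linen as nn
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec

rules = (
    ("activation_batch", ("data", "fsdp")),
    ("activation_length", "context"),
    ("activation_heads", "tensor"),
)
mesh = Mesh(
    mesh_utils.create_device_mesh((1, 2, 2, 2)),
    axis_names=("data", "fsdp", "context", "tensor"),
)

logical = PartitionSpec("activation_batch", "activation_length", "activation_heads", None)
sharding = nn.logical_to_mesh_sharding(logical, mesh, rules)
print(sharding.spec)  # PartitionSpec(('data', 'fsdp'), 'context', 'tensor', None)
```

Mapping 'activation_batch' to both 'data' and 'fsdp' splits the batch over the product of those two axes, while the sequence dimension moves from 'fsdp' to the new 'context' axis.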
25 changes: 14 additions & 11 deletions src/maxdiffusion/configs/base_wan_27b.yml
@@ -139,7 +139,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -154,30 +154,33 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_length', 'fsdp'],

['batch', ['data', 'fsdp']],
['activation_batch', ['data', 'fsdp']],
['activation_self_attn_heads', ['context', 'tensor']],
['activation_cross_attn_q_length', ['context', 'tensor']],
['activation_length', 'context'],
['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['embed', ['context', 'fsdp']],
['heads', 'tensor'],
['norm', 'tensor'],
['conv_batch', ['data','fsdp']],
['conv_batch', ['data', 'context', 'fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_out', 'context'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_fsdp_parallelism: 1
dcn_context_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_parallelism: 1
ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1

allow_split_physical_axes: False
22 changes: 12 additions & 10 deletions src/maxdiffusion/configs/base_wan_i2v_14b.yml
@@ -134,7 +134,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -149,31 +149,33 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['activation_self_attn_heads', ['fsdp', 'tensor']],
['activation_cross_attn_q_length', ['fsdp', 'tensor']],
['activation_length', 'fsdp'],
['batch', ['data', 'fsdp']],
['activation_batch', ['data', 'fsdp']],
['activation_self_attn_heads', ['context', 'tensor']],
['activation_cross_attn_q_length', ['context', 'tensor']],
['activation_length', 'context'],
['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['embed', ['context', 'fsdp']],
['heads', 'tensor'],
['norm', 'tensor'],
['conv_batch', ['data','fsdp']],
['conv_batch', ['data', 'context', 'fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_out', 'context'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp', 'context', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_context_parallelism: 1
ici_tensor_parallelism: 1

allow_split_physical_axes: False