11 changes: 11 additions & 0 deletions README.md
100644 → 100755
@@ -255,6 +255,17 @@ After installation completes, run the training script.
- In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, and the ici_tensor_parallelism axis is used for head parallelism.
- You can enable both at once; keep in mind that Wan2.1 has 40 attention heads, so 40 must be evenly divisible by ici_tensor_parallelism.
- For sequence parallelism, the code pads the sequence length so that it divides evenly across the axis. Try different ici_fsdp_parallelism values; we currently find 2 and 4 work best.
- To benchmark training performance across different data dimensions without downloading or re-processing the dataset, a synthetic data iterator is supported.
- Set dataset_type='synthetic' to enable the synthetic data iterator; synthetic_num_samples=null yields an infinite stream (see the example command after this list).
- The following overrides on data dimensions are supported:
- synthetic_override_height: 720
- synthetic_override_width: 1280
- synthetic_override_num_frames: 85
- synthetic_override_max_sequence_length: 512
- synthetic_override_text_embed_dim: 4096
- synthetic_override_num_channels_latents: 16
- synthetic_override_vae_scale_factor_spatial: 8
- synthetic_override_vae_scale_factor_temporal: 4
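
As an illustration, a synthetic benchmarking run on Wan2.1 might be launched as below. This is a sketch, not a prescribed command: the entry point and the key=value override syntax are assumed to follow the training command shown above, and the run_name is a placeholder.

```bash
# Sketch only: entry point and override syntax are assumed to match the
# training command used earlier in this README; adapt as needed.
python src/maxdiffusion/train.py src/maxdiffusion/configs/base_wan_14b.yml \
  run_name=wan-synthetic-bench \
  dataset_type=synthetic \
  synthetic_num_samples=null \
  synthetic_override_height=720 \
  synthetic_override_width=1280 \
  synthetic_override_num_frames=85 \
  ici_fsdp_parallelism=2 \
  ici_tensor_parallelism=4   # 40 heads / 4 shards = 10 heads per shard
```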

You should eventually see a training run like:

15 changes: 14 additions & 1 deletion src/maxdiffusion/configs/base_flux_dev.yml
@@ -177,7 +177,20 @@ allow_split_physical_axes: False
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
- dataset_type: 'tf'
+ dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic'
# ==============================================================================
# Synthetic Data Configuration (only used when dataset_type='synthetic')
# ==============================================================================
# To use synthetic data for testing/debugging without real datasets:
# 1. Set dataset_type: 'synthetic' above
# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000)
# 3. Optionally override dimensions
#
# synthetic_num_samples: null # null for infinite, or set a number
#
# Optional dimension overrides:
# resolution: 512
# ==============================================================================
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
# and only to small datasets that fit in memory
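For reference, enabling the synthetic iterator in this flux config would look roughly like the following. The flag names are taken from the commented block above; the values are illustrative only.

```yaml
# Illustrative only; flag names come from the commented block above.
dataset_type: 'synthetic'
synthetic_num_samples: null   # null = infinite stream
resolution: 512               # optional dimension override
```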
23 changes: 22 additions & 1 deletion src/maxdiffusion/configs/base_wan_14b.yml
@@ -199,7 +199,28 @@ allow_split_physical_axes: False
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
- dataset_type: 'tfrecord'
+ dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic'
# ==============================================================================
# Synthetic Data Configuration (only used when dataset_type='synthetic')
# ==============================================================================
# To use synthetic data for testing/debugging without real datasets:
# 1. Set dataset_type: 'synthetic' above
# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000)
# 3. Optionally override dimensions with synthetic_override_* flags below
#
# synthetic_num_samples: null # null for infinite, or set a number
#
# Optional dimension overrides (leave commented out to use pipeline/config values):
# synthetic_override_height: 720
# synthetic_override_width: 1280
# synthetic_override_num_frames: 121
# synthetic_override_max_sequence_length: 512
# synthetic_override_text_embed_dim: 4096
# synthetic_override_num_channels_latents: 16
# synthetic_override_vae_scale_factor_spatial: 8
# synthetic_override_vae_scale_factor_temporal: 4
# ==============================================================================

cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
# and only to small datasets that fit in memory
14 changes: 12 additions & 2 deletions src/maxdiffusion/input_pipeline/input_pipeline_interface.py
@@ -23,6 +23,7 @@
from maxdiffusion.input_pipeline import _hf_data_processing
from maxdiffusion.input_pipeline import _grain_data_processing
from maxdiffusion.input_pipeline import _tfds_data_processing
from maxdiffusion.input_pipeline import synthetic_data_iterator
from maxdiffusion import multihost_dataloading
from maxdiffusion.maxdiffusion_utils import tokenize_captions, transform_images, vae_apply
from maxdiffusion.dreambooth.dreambooth_constants import (
@@ -54,8 +55,9 @@ def make_data_iterator(
feature_description=None,
prepare_sample_fn=None,
is_training=True,
pipeline=None,
):
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord, grain, synthetic)"""

if config.dataset_type == "hf" or config.dataset_type == "tf":
if tokenize_fn is None or image_transforms_fn is None:
@@ -110,8 +112,16 @@
prepare_sample_fn,
is_training,
)
elif config.dataset_type == "synthetic":
return synthetic_data_iterator.make_synthetic_iterator(
config=config,
mesh=mesh,
global_batch_size=global_batch_size,
pipeline=pipeline,
is_training=is_training,
)
else:
- assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
+ assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain, synthetic)"


def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, vae, vae_params):
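The synthetic branch in make_data_iterator delegates to synthetic_data_iterator.make_synthetic_iterator, whose body is not part of this diff. The sketch below shows one way such an iterator could work: it yields randomly generated, device-placed batches with no dataset download. The batch keys, the latent-shape arithmetic, and the "data" mesh-axis name are assumptions for illustration, the sharding shown is single-host only, and this is not the repository's actual implementation.

```python
# Hypothetical single-host sketch of a synthetic iterator; not the actual
# synthetic_data_iterator module. Batch keys, shapes, and the "data" mesh
# axis name are assumptions.
import itertools
import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P


def make_synthetic_iterator_sketch(
    mesh,
    global_batch_size,
    height=720,
    width=1280,
    num_frames=85,
    max_sequence_length=512,
    text_embed_dim=4096,
    num_channels_latents=16,
    vae_scale_factor_spatial=8,
    vae_scale_factor_temporal=4,
    num_samples=None,  # None = infinite stream, mirroring synthetic_num_samples=null
):
  """Yields random latent/text-embedding batches without touching a dataset."""
  latent_shape = (
      global_batch_size,
      num_channels_latents,
      (num_frames - 1) // vae_scale_factor_temporal + 1,  # assumed causal-VAE frame count
      height // vae_scale_factor_spatial,
      width // vae_scale_factor_spatial,
  )
  embed_shape = (global_batch_size, max_sequence_length, text_embed_dim)
  sharding = NamedSharding(mesh, P("data"))  # assumed data-parallel axis name
  rng = np.random.default_rng(0)
  steps = itertools.count() if num_samples is None else range(num_samples)
  for _ in steps:
    batch = {
        "latents": rng.standard_normal(latent_shape, dtype=np.float32),
        "encoder_hidden_states": rng.standard_normal(embed_shape, dtype=np.float32),
    }
    # device_put with a NamedSharding shards each array's batch dim over the mesh.
    yield jax.tree_util.tree_map(lambda x: jax.device_put(x, sharding), batch)
```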