From 9b1e9dfcb362d2269a4118c1bebba78cffb6e3ef Mon Sep 17 00:00:00 2001 From: jianhan-amd Date: Tue, 20 Jan 2026 14:59:17 +0000 Subject: [PATCH 1/2] feat: add general synthetic data iterator and examples for WAN and FLUX. --- src/maxdiffusion/configs/base_flux_dev.yml | 15 +- src/maxdiffusion/configs/base_wan_14b.yml | 23 +- .../input_pipeline_interface.py | 14 +- .../input_pipeline/synthetic_data_iterator.py | 491 ++++++++++++++++++ src/maxdiffusion/trainers/flux_trainer.py | 12 + src/maxdiffusion/trainers/wan_trainer.py | 30 +- 6 files changed, 579 insertions(+), 6 deletions(-) create mode 100755 src/maxdiffusion/input_pipeline/synthetic_data_iterator.py diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 0036b363..15688880 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -177,7 +177,20 @@ allow_split_physical_axes: False # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' train_split: 'train' -dataset_type: 'tf' +dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic' +# ============================================================================== +# Synthetic Data Configuration (only used when dataset_type='synthetic') +# ============================================================================== +# To use synthetic data for testing/debugging without real datasets: +# 1. Set dataset_type: 'synthetic' above +# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000) +# 3. Optionally override dimensions +# +# synthetic_num_samples: null # null for infinite, or set a number +# +# Optional dimension overrides: +# resolution: 512 +# ============================================================================== cache_latents_text_encoder_outputs: True # cache_latents_text_encoder_outputs only apply to dataset_type="tf", # only apply to small dataset that fits in memory diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index b2a11dba..67d66401 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -199,7 +199,28 @@ allow_split_physical_axes: False # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' train_split: 'train' -dataset_type: 'tfrecord' +dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic' +# ============================================================================== +# Synthetic Data Configuration (only used when dataset_type='synthetic') +# ============================================================================== +# To use synthetic data for testing/debugging without real datasets: +# 1. Set dataset_type: 'synthetic' above +# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000) +# 3. Optionally override dimensions with synthetic_override_* flags below +# +# synthetic_num_samples: null # null for infinite, or set a number +# +# Optional dimension overrides (comment out to use pipeline/config values): +# synthetic_override_height: 720 +# synthetic_override_width: 1280 +# synthetic_override_num_frames: 121 +# synthetic_override_max_sequence_length: 512 +# synthetic_override_text_embed_dim: 4096 +# synthetic_override_num_channels_latents: 16 +# synthetic_override_vae_scale_factor_spatial: 8 +# synthetic_override_vae_scale_factor_temporal: 4 +# ============================================================================== + cache_latents_text_encoder_outputs: True # cache_latents_text_encoder_outputs only apply to dataset_type="tf", # only apply to small dataset that fits in memory diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py index 27f2ad25..1c030f14 100644 --- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py +++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py @@ -23,6 +23,7 @@ from maxdiffusion.input_pipeline import _hf_data_processing from maxdiffusion.input_pipeline import _grain_data_processing from maxdiffusion.input_pipeline import _tfds_data_processing +from maxdiffusion.input_pipeline import synthetic_data_iterator from maxdiffusion import multihost_dataloading from maxdiffusion.maxdiffusion_utils import tokenize_captions, transform_images, vae_apply from maxdiffusion.dreambooth.dreambooth_constants import ( @@ -54,8 +55,9 @@ def make_data_iterator( feature_description=None, prepare_sample_fn=None, is_training=True, + pipeline=None, ): - """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)""" + """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord, grain, synthetic)""" if config.dataset_type == "hf" or config.dataset_type == "tf": if tokenize_fn is None or image_transforms_fn is None: @@ -110,8 +112,16 @@ def make_data_iterator( prepare_sample_fn, is_training, ) + elif config.dataset_type == "synthetic": + return synthetic_data_iterator.make_synthetic_iterator( + config=config, + mesh=mesh, + global_batch_size=global_batch_size, + pipeline=pipeline, + is_training=is_training, + ) else: - assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)" + assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain, synthetic)" def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, vae, vae_params): diff --git a/src/maxdiffusion/input_pipeline/synthetic_data_iterator.py b/src/maxdiffusion/input_pipeline/synthetic_data_iterator.py new file mode 100755 index 00000000..1d9f0a5f --- /dev/null +++ b/src/maxdiffusion/input_pipeline/synthetic_data_iterator.py @@ -0,0 +1,491 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import math +from typing import Dict, Any, Optional +import numpy as np +import jax +import jax.numpy as jnp + +from maxdiffusion import multihost_dataloading, max_logging + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def get_wan_dimension( + config, + pipeline, + config_key: str, + pipeline_path: str = None, + default_value: Any = None +) -> Any: + """ + Get dimension for WAN model with override priority: + 1. Config override (synthetic_override_{config_key}) - for height, width, num_frames + 2. Pipeline path (exact path specified by caller) + 3. Config default + 4. Hardcoded default + + Args: + config: Configuration object + pipeline: WAN Pipeline object + config_key: Key to look up in config + pipeline_path: Exact dotted path in pipeline (e.g., 'transformer.config.in_channels') + default_value: Fallback value if not found elsewhere + """ + # Check overrides for height, width, num_frames (WAN-specific) + if config_key in ['height', 'width', 'num_frames']: + override_key = f'synthetic_override_{config_key}' + try: + value = getattr(config, override_key) + if value is not None: + if jax.process_index() == 0: + max_logging.log(f"[WAN] Using override {config_key}: {value}") + return value + except (AttributeError, ValueError): + pass # Override not set, continue to pipeline/config + + # Check pipeline using exact path if provided + if pipeline is not None and pipeline_path: + try: + # Navigate the dotted path (e.g., 'transformer.config.in_channels') + value = pipeline + for attr in pipeline_path.split('.'): + value = getattr(value, attr) + + if value is not None: + if jax.process_index() == 0: + max_logging.log(f"[WAN] Using {config_key} from pipeline.{pipeline_path}: {value}") + return value + except AttributeError: + pass # Path not available in pipeline + + # Check config - use try/except because config raises ValueError instead of AttributeError + try: + value = getattr(config, config_key) + if jax.process_index() == 0: + max_logging.log(f"[WAN] Using {config_key} from config: {value}") + return value + except (AttributeError, ValueError): + pass # Key not in config, use default + + # Use default + if jax.process_index() == 0: + max_logging.log(f"[WAN] Using default {config_key}: {default_value}") + return default_value + + +def get_flux_dimension( + config, + pipeline, + config_key: str, + pipeline_path: str = None, + default_value: Any = None +) -> Any: + """ + Get dimension for FLUX model with override priority: + 1. Pipeline path (exact path specified by caller) + 2. Config default + 3. Hardcoded default + + Note: FLUX does not support override flags + + Args: + config: Configuration object + pipeline: FLUX Pipeline object + config_key: Key to look up in config + pipeline_path: Exact dotted path in pipeline (e.g., 'vae_scale_factor') + default_value: Fallback value if not found elsewhere + """ + # FLUX does not check overrides - load directly from pipeline/config + + # Check pipeline using exact path if provided + if pipeline is not None and pipeline_path: + try: + # Navigate the dotted path (e.g., 'vae_scale_factor') + value = pipeline + for attr in pipeline_path.split('.'): + value = getattr(value, attr) + + if value is not None: + if jax.process_index() == 0: + max_logging.log(f"[FLUX] Using {config_key} from pipeline.{pipeline_path}: {value}") + return value + except AttributeError: + pass # Path not available in pipeline + + # Check config - use try/except because config raises ValueError instead of AttributeError + try: + value = getattr(config, config_key) + if jax.process_index() == 0: + max_logging.log(f"[FLUX] Using {config_key} from config: {value}") + return value + except (AttributeError, ValueError): + pass # Key not in config, use default + + # Use default + if jax.process_index() == 0: + max_logging.log(f"[FLUX] Using default {config_key}: {default_value}") + return default_value + + +def log_synthetic_config(model_name: str, dimensions: Dict[str, Any], per_host_batch_size: int, is_training: bool, num_samples: Optional[int]): + """Log synthetic data configuration.""" + if jax.process_index() == 0: + info = [ + "=" * 60, + f"{model_name.upper()} Synthetic Data Iterator Configuration:", + f" Per-host batch size: {per_host_batch_size}", + f" Mode: {'Training' if is_training else 'Evaluation'}", + f" Samples per iteration: {num_samples if num_samples else 'Infinite'}", + ] + for key, value in dimensions.items(): + info.append(f" {key}: {value}") + info.append("=" * 60) + max_logging.log("\n".join(info)) + + +# ============================================================================ +# Synthetic Data Source and Iterator +# ============================================================================ + + +class SyntheticDataSource: + """Wrapper for synthetic data that provides iterator interface.""" + + def __init__(self, generate_fn, num_samples, seed): + self.generate_fn = generate_fn + self.num_samples = num_samples + self.seed = seed + self.current_step = 0 + self.rng_key = jax.random.key(seed) + + def __iter__(self): + self.current_step = 0 + self.rng_key = jax.random.key(self.seed) + return self + + def __next__(self): + if self.num_samples is not None and self.current_step >= self.num_samples: + raise StopIteration + + self.rng_key, step_key = jax.random.split(self.rng_key) + data = self.generate_fn(step_key) + self.current_step += 1 + return data + + def as_numpy_iterator(self): + return iter(self) + + +# ============================================================================ +# WAN Model Synthetic Data Generator +# ============================================================================ + + +def _generate_wan_sample(rng_key: jax.Array, dimensions: Dict[str, Any], is_training: bool) -> Dict[str, np.ndarray]: + """Generate a single batch of synthetic data for WAN model.""" + keys = jax.random.split(rng_key, 3) + + per_host_batch_size = dimensions['per_host_batch_size'] + + # Generate latents: (batch, channels, frames, height, width) + latents_shape = ( + per_host_batch_size, + dimensions['num_channels_latents'], + dimensions['num_latent_frames'], + dimensions['latent_height'], + dimensions['latent_width'] + ) + latents = jax.random.normal(keys[0], shape=latents_shape, dtype=jnp.float32) + + # Generate encoder hidden states: (batch, seq_len, embed_dim) + encoder_hidden_states_shape = ( + per_host_batch_size, + dimensions['max_sequence_length'], + dimensions['text_embed_dim'] + ) + encoder_hidden_states = jax.random.normal(keys[1], shape=encoder_hidden_states_shape, dtype=jnp.float32) + + data = { + 'latents': np.array(latents), + 'encoder_hidden_states': np.array(encoder_hidden_states), + } + + # For evaluation, also generate timesteps + if not is_training: + timesteps = jax.random.randint( + keys[2], + shape=(per_host_batch_size,), + minval=0, + maxval=dimensions['num_train_timesteps'], + dtype=jnp.int64 + ) + data['timesteps'] = np.array(timesteps) + + return data + + +def _make_wan_synthetic_iterator(config, mesh, global_batch_size, pipeline, is_training, num_samples): + """Create synthetic data iterator for WAN model.""" + per_host_batch_size = global_batch_size // jax.process_count() + + # Initialize dimensions - explicitly specify pipeline paths for WAN model + height = get_wan_dimension( + config, pipeline, 'height', + pipeline_path=None, # Not in pipeline, use config/override + default_value=480 + ) + width = get_wan_dimension( + config, pipeline, 'width', + pipeline_path=None, # Not in pipeline, use config/override + default_value=832 + ) + num_frames = get_wan_dimension( + config, pipeline, 'num_frames', + pipeline_path=None, # Not in pipeline, use config/override + default_value=81 + ) + + # WAN-specific dimensions from transformer config + max_sequence_length = get_wan_dimension( + config, pipeline, 'max_sequence_length', + pipeline_path='transformer.config.rope_max_seq_len', + default_value=512 + ) + text_embed_dim = get_wan_dimension( + config, pipeline, 'text_embed_dim', + pipeline_path='transformer.config.text_dim', + default_value=4096 + ) + num_channels_latents = get_wan_dimension( + config, pipeline, 'num_channels_latents', + pipeline_path='transformer.config.in_channels', + default_value=16 + ) + + # VAE scale factors from pipeline attributes + vae_scale_factor_spatial = get_wan_dimension( + config, pipeline, 'vae_scale_factor_spatial', + pipeline_path='vae_scale_factor_spatial', + default_value=8 + ) + vae_scale_factor_temporal = get_wan_dimension( + config, pipeline, 'vae_scale_factor_temporal', + pipeline_path='vae_scale_factor_temporal', + default_value=4 + ) + + # Calculate latent dimensions + num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 + latent_height = height // vae_scale_factor_spatial + latent_width = width // vae_scale_factor_spatial + + # Get num_train_timesteps from scheduler + num_train_timesteps = get_wan_dimension( + config, pipeline, 'num_train_timesteps', + pipeline_path='scheduler.config.num_train_timesteps', + default_value=1000 + ) + # Fallback to scheduler.num_train_timesteps if config doesn't exist + if pipeline is not None and hasattr(pipeline, 'scheduler') and num_train_timesteps == 1000: + try: + num_train_timesteps = pipeline.scheduler.num_train_timesteps + if jax.process_index() == 0: + max_logging.log(f"Using num_train_timesteps from pipeline.scheduler: {num_train_timesteps}") + except AttributeError: + pass + + dimensions = { + 'per_host_batch_size': per_host_batch_size, + 'height': height, + 'width': width, + 'num_frames': num_frames, + 'num_latent_frames': num_latent_frames, + 'latent_height': latent_height, + 'latent_width': latent_width, + 'max_sequence_length': max_sequence_length, + 'text_embed_dim': text_embed_dim, + 'num_channels_latents': num_channels_latents, + 'vae_scale_factor_spatial': vae_scale_factor_spatial, + 'vae_scale_factor_temporal': vae_scale_factor_temporal, + 'num_train_timesteps': num_train_timesteps, + } + + log_synthetic_config('WAN', dimensions, per_host_batch_size, is_training, num_samples) + + # Create generate function with dimensions bound + def generate_fn(rng_key): + return _generate_wan_sample(rng_key, dimensions, is_training) + + data_source = SyntheticDataSource(generate_fn, num_samples, config.seed) + return multihost_dataloading.MultiHostDataLoadIterator(data_source, mesh) + + +# ============================================================================ +# FLUX Model Synthetic Data Generator +# ============================================================================ + + +def _generate_flux_sample(rng_key: jax.Array, dimensions: Dict[str, Any]) -> Dict[str, np.ndarray]: + """Generate a single batch of synthetic data for FLUX model.""" + keys = jax.random.split(rng_key, 4) + + per_host_batch_size = dimensions['per_host_batch_size'] + latent_height = dimensions['latent_height'] + latent_width = dimensions['latent_width'] + latent_seq_len = dimensions['latent_seq_len'] + + # Generate pixel values (packed latents) - should be float16 to match trainer + pixel_values_shape = (per_host_batch_size, latent_seq_len, dimensions['packed_latent_dim']) + pixel_values = jax.random.normal(keys[0], shape=pixel_values_shape, dtype=jnp.float16) + + # Generate text embedding IDs (position encodings) + input_ids_shape = (per_host_batch_size, dimensions['max_sequence_length'], 3) + input_ids = jax.random.normal(keys[1], shape=input_ids_shape, dtype=jnp.float32) + + # Generate text embeddings (T5) + text_embeds_shape = (per_host_batch_size, dimensions['max_sequence_length'], dimensions['t5_embed_dim']) + text_embeds = jax.random.normal(keys[2], shape=text_embeds_shape, dtype=jnp.float32) + + # Generate pooled prompt embeddings (CLIP) + prompt_embeds_shape = (per_host_batch_size, dimensions['pooled_embed_dim']) + prompt_embeds = jax.random.normal(keys[3], shape=prompt_embeds_shape, dtype=jnp.float32) + + # Generate image position IDs - matching pipeline.prepare_latent_image_ids + # Create base img_ids for single sample (without batch dimension) + img_ids_base = jnp.zeros((latent_height, latent_width, 3), dtype=jnp.float16) + # Channel 0 stays 0 + # Channel 1 = height indices + img_ids_base = img_ids_base.at[..., 1].set(jnp.arange(latent_height)[:, None]) + # Channel 2 = width indices + img_ids_base = img_ids_base.at[..., 2].set(jnp.arange(latent_width)[None, :]) + + # Reshape to (latent_seq_len, 3) + img_ids_base = img_ids_base.reshape(latent_seq_len, 3) + + # Tile for batch dimension + img_ids = jnp.tile(img_ids_base[None, ...], (per_host_batch_size, 1, 1)) + + return { + 'pixel_values': np.array(pixel_values), + 'input_ids': np.array(input_ids), + 'text_embeds': np.array(text_embeds), + 'prompt_embeds': np.array(prompt_embeds), + 'img_ids': np.array(img_ids), + } + + +def _make_flux_synthetic_iterator(config, mesh, global_batch_size, pipeline, is_training, num_samples): + """Create synthetic data iterator for FLUX model.""" + per_host_batch_size = global_batch_size // jax.process_count() + + # Initialize dimensions - explicitly specify pipeline paths for FLUX model + resolution = get_flux_dimension( + config, pipeline, 'resolution', + pipeline_path=None, # Not in pipeline, use config + default_value=512 + ) + max_sequence_length = get_flux_dimension( + config, pipeline, 'max_sequence_length', + pipeline_path=None, # Not in pipeline, use config + default_value=512 + ) + t5_embed_dim = get_flux_dimension( + config, pipeline, 't5_embed_dim', + pipeline_path='text_encoder_2.config.d_model', # T5 model dimension + default_value=4096 + ) + pooled_embed_dim = get_flux_dimension( + config, pipeline, 'pooled_embed_dim', + pipeline_path='text_encoder.config.projection_dim', # CLIP projection dimension + default_value=768 + ) + vae_scale_factor = get_flux_dimension( + config, pipeline, 'vae_scale_factor', + pipeline_path='vae_scale_factor', # Direct pipeline attribute + default_value=8 + ) + + # Calculate packed latent dimensions + latent_height = math.ceil(resolution // (vae_scale_factor * 2)) + latent_width = math.ceil(resolution // (vae_scale_factor * 2)) + latent_seq_len = latent_height * latent_width + packed_latent_dim = 64 # 16 channels * 2 * 2 packing + + dimensions = { + 'per_host_batch_size': per_host_batch_size, + 'max_sequence_length': max_sequence_length, + 't5_embed_dim': t5_embed_dim, + 'pooled_embed_dim': pooled_embed_dim, + 'resolution': resolution, + 'latent_height': latent_height, + 'latent_width': latent_width, + 'latent_seq_len': latent_seq_len, + 'packed_latent_dim': packed_latent_dim, + } + + log_synthetic_config('FLUX', dimensions, per_host_batch_size, is_training, num_samples) + + # Create generate function with dimensions bound + def generate_fn(rng_key): + return _generate_flux_sample(rng_key, dimensions) + + data_source = SyntheticDataSource(generate_fn, num_samples, config.seed) + return multihost_dataloading.MultiHostDataLoadIterator(data_source, mesh) + + +# ============================================================================ +# Public API +# ============================================================================ + + +def make_synthetic_iterator(config, mesh, global_batch_size, pipeline=None, is_training=True): + """ + Create a synthetic data iterator for the specified model. + + Args: + config: Configuration object with model_name + mesh: JAX mesh for sharding + global_batch_size: Total batch size across all devices + pipeline: Optional pipeline object to extract dimensions from + is_training: Whether this is for training or evaluation + + Returns: + MultiHostDataLoadIterator wrapping the synthetic data source + """ + num_samples = getattr(config, 'synthetic_num_samples', None) + + try: + model_name = getattr(config, 'model_name', None) + if model_name in ('wan2.1', 'wan2.2'): + return _make_wan_synthetic_iterator(config, mesh, global_batch_size, pipeline, is_training, num_samples) + except (AttributeError, ValueError): + pass + try: + model_name = getattr(config, 'flux_name', None) + if model_name in ('flux', 'flux-dev', 'flux-schnell'): + return _make_flux_synthetic_iterator(config, mesh, global_batch_size, pipeline, is_training, num_samples) + except (AttributeError, ValueError): + pass + + raise ValueError( + f"No synthetic iterator implemented for model." + f"Supported models: wan2.1, wan2.2, flux, flux-dev, flux-schnell" + ) diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index 74b4f259..9f27935e 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -247,6 +247,18 @@ def load_dataset(self, pipeline, params, train_states): total_train_batch_size = self.total_train_batch_size mesh = self.mesh + # If using synthetic data + if config.dataset_type == "synthetic": + return make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + total_train_batch_size, + pipeline=pipeline, # Pass pipeline to extract dimensions + is_training=True, + ) + encode_fn = partial( pipeline.encode_prompt, clip_tokenizer=pipeline.clip_tokenizer, diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index f23836a5..ae0d095b 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -164,7 +164,18 @@ def get_eval_data_shardings(self, mesh): data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": data_sharding} return data_sharding - def load_dataset(self, mesh, is_training=True): + def load_dataset(self, mesh, pipeline=None, is_training=True): + """ + Load dataset - supports both real tfrecord and synthetic data. + + Args: + mesh: JAX mesh for sharding + pipeline: Optional WAN pipeline to extract dimensions from (for synthetic data) + is_training: Whether this is for training or evaluation + + Returns: + Data iterator + """ # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314 # Image pre-training - txt2img 256px # Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16 @@ -173,6 +184,21 @@ def load_dataset(self, mesh, is_training=True): # prompt embeds shape: (1, 512, 4096) # For now, we will pass the same latents over and over # TODO - create a dataset + + config = self.config + + # If using synthetic data + if config.dataset_type == "synthetic": + return make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + config.global_batch_size_to_load, + pipeline=pipeline, # Pass pipeline to extract dimensions + is_training=is_training, + ) + config = self.config if config.dataset_type != "tfrecord" and not config.cache_latents_text_encoder_outputs: raise ValueError( @@ -226,7 +252,7 @@ def start_training(self): del pipeline.vae_cache mesh = pipeline.mesh - train_data_iterator = self.load_dataset(mesh, is_training=True) + train_data_iterator = self.load_dataset(mesh, pipeline=pipeline, is_training=True) # Load FlowMatch scheduler scheduler, scheduler_state = self.create_scheduler() From b3fcc99790b5090db9db97266797b24f25a799dc Mon Sep 17 00:00:00 2001 From: jianhan-amd Date: Tue, 20 Jan 2026 15:27:10 +0000 Subject: [PATCH 2/2] doc: update README.md with synthetic data iterator usage. --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) mode change 100644 => 100755 README.md diff --git a/README.md b/README.md old mode 100644 new mode 100755 index e6584d23..0e0e7a66 --- a/README.md +++ b/README.md @@ -255,6 +255,17 @@ After installation completes, run the training script. - In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, the ici_tensor_parallelism axis is used for head parallelism. - You can enable both, keeping in mind that Wan2.1 has 40 heads and 40 must be evenly divisible by ici_tensor_parallelism. - For Sequence parallelism, the code pads the sequence length to evenly divide the sequence. Try out different ici_fsdp_parallelism numbers, but we find 2 and 4 to be the best right now. + - For benchmarking training performance on multiple data dimension input without downloading/re-processing the dataset, the synthetic data iterator is supported. + - Set dataset_type='synthetic' and synthetic_num_samples=null to enable the synthetic data iterator. + - The following overrides on data dimensions are supported: + - synthetic_override_height: 720 + - synthetic_override_width: 1280 + - synthetic_override_num_frames: 85 + - synthetic_override_max_sequence_length: 512 + - synthetic_override_text_embed_dim: 4096 + - synthetic_override_num_channels_latents: 16 + - synthetic_override_vae_scale_factor_spatial: 8 + - synthetic_override_vae_scale_factor_temporal: 4 You should eventually see a training run as: