[Feature] Add LoRA Inference Support for WAN Models via Flax NNX #308
Draft: Perseus14 wants to merge 10 commits into main from wan_lora
Conversation
Collaborator
Does LoRA support the I2V pipelines as well?
* WAN Img2Vid Implementation
* Removed randn_tensor function import
* logical_axis rules and attention_sharding_uniform added in config files
* removed attn_mask from FlaxWanAttn call
* fix to prevent load_image_encoder from running for wan 2.2 iv
* boundary_ratio removed from generate_wan.py
* testing with 720p
* model restored
* attn_mask correction
* transformer corrected in wan 2.2 t2v and config files updated
* revert
* corrected
* import added in wan_checkpointer_test.py
* wan_checkpointer_test.py corrected
* wan_checkpointer_test.py corrected
* wan_checkpointer_test.py corrected
* removed redundant img attn mask
* Fix for multiple videos
* Fix for multiple videos
* Fix for multiple videos
* Fix for multiple videos
* removed redundant args
* removed redundant args
* trying dot attn fix
* reverting fix to see if that was the issue
* fix verified
* updated comments
* Added sharding
* sharding added
* ruff checks
* README updated
* sharding
* ruff check
- Rename VaceWanPipeline to VaceWanPipeline2_1.
- Make VaceWanPipeline2_1 inherit from WanPipeline2_1.
- Remove calls to self.transformer2.

This change aligns the class with the WanPipeline naming structure and resolves an initialization bug when using the from_pretrained method.
- ubuntu-20.04 has been deprecated
- using ubuntu-latest for robustness
* Fix formatting with pyink
* Update Python and pyink version to be aligned with code_style.sh
Summary
This PR introduces full Low-Rank Adaptation (LoRA) inference support for the WAN family of models in MaxDiffusion.
Unlike previous implementations in this codebase that rely on `flax.linen`, this implementation leverages `flax.nnx`. This allows for a more Pythonic, object-oriented approach to weight injection, enabling us to modify the transformer model in-place.

Key Features
1. Transition to `flax.nnx`

WAN models in MaxDiffusion are implemented using `flax.nnx`. To support LoRA, we implemented a native NNX loader rather than wrapping `linen` modules. The loader traverses the model graph (`nnx.iter_graph`) to identify target layers (`nnx.Linear`, `nnx.Conv`, `nnx.Embed`, `nnx.LayerNorm`) and merges LoRA weights directly into the kernel values.
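For illustration, a minimal sketch of this kind of NNX graph traversal is shown below. The helper name `merge_lora_into_linears` and the `lora_deltas` layout (deltas keyed by dotted module path, already in kernel layout) are assumptions for this sketch, not the PR's actual API.

```python
# Minimal sketch (assumed names, not the PR's exact API): walk an NNX model
# graph and add precomputed LoRA deltas to matching Linear kernels in-place.
import jax.numpy as jnp
from flax import nnx


def merge_lora_into_linears(model: nnx.Module, lora_deltas: dict, scale: float = 1.0):
  """Adds LoRA deltas to Linear kernels found by traversing the NNX graph.

  `lora_deltas` maps a dotted module path (e.g. "blocks.0.attn1.to_q") to a
  delta matrix already in kernel layout, i.e. shape (in_features, out_features).
  """
  for path, node in nnx.iter_graph(model):
    if isinstance(node, nnx.Linear):
      key = ".".join(str(p) for p in path)
      delta = lora_deltas.get(key)
      if delta is not None:
        # In-place update of the kernel value, preserving its dtype.
        node.kernel.value = node.kernel.value + scale * jnp.asarray(
            delta, dtype=node.kernel.value.dtype
        )
```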
2. Robust Weight Merging Strategy

This implementation solves several critical distributed training/inference challenges:

- On-device merging (`jax.jit`): To avoid `ShardingMismatch` and `DeviceArray` errors that occur when mixing sharded TPU weights with CPU-based LoRA weights, all merge computations (kernel + delta) are performed within JIT-compiled functions (`_compute_and_add_*_jit`). This ensures weight updates occur efficiently on-device across the TPU mesh (see the sketch below).
- Tensor conversion: Uses `jax.dlpack` where possible to efficiently move PyTorch tensors to JAX arrays without unnecessary memory overhead.
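As a rough sketch of these two ideas, assuming Diffusers-style `(out, rank)` / `(rank, in)` LoRA matrices; the names `compute_and_add` and `torch_to_jax` are illustrative and not the PR's exact `_compute_and_add_*_jit` helpers:

```python
# Illustrative sketch: the merge runs under jax.jit so it executes on-device
# against the kernel's existing layout, and torch tensors are handed to JAX
# via dlpack when possible instead of a host round-trip.
import jax
import jax.dlpack
import jax.numpy as jnp
import torch


@jax.jit
def compute_and_add(kernel, lora_a, lora_b, scale):
  # kernel: (in, out); lora_a: (rank, in); lora_b: (out, rank).
  delta = (lora_b @ lora_a).T  # -> (in, out), matching the kernel layout
  return kernel + scale * delta.astype(kernel.dtype)


def torch_to_jax(t: torch.Tensor) -> jax.Array:
  """Converts a torch tensor to a JAX array, preferring a dlpack hand-off."""
  try:
    return jax.dlpack.from_dlpack(t.detach().contiguous())
  except Exception:
    # Fallback: copy through host memory.
    return jnp.asarray(t.detach().cpu().numpy())
```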
3. Advanced LoRA Support

Beyond standard `Linear` rank reduction, this PR supports:

- Preparation of `diff` weights before device-side merging.
- Difference injections (`diff`, `diff_b`): Supports checkpoints that include full-parameter fine-tuning offsets (difference injections) and bias tuning, which are common in high-fidelity WAN fine-tunes (see the sketch below).
- `text_embedding`, `time_embedding`, and `LayerNorm`/`RMSNorm` scales and biases.
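A hedged sketch of how `diff` / `diff_b` offsets could be folded into a layer, assuming a torch-style `(out, in)` kernel offset in the checkpoint; the helper name `apply_diff_offsets` and the per-layer `entry` dict are assumptions for illustration:

```python
# Sketch only: fold full-parameter offsets into an nnx.Linear layer.
# "diff" is a kernel offset, "diff_b" a bias offset, keyed per layer.
import jax.numpy as jnp
from flax import nnx


def apply_diff_offsets(layer: nnx.Linear, entry: dict, scale: float = 1.0):
  if "diff" in entry:
    # Checkpoint stores (out_features, in_features); NNX kernels are (in, out).
    offset = jnp.asarray(entry["diff"]).T
    layer.kernel.value = layer.kernel.value + scale * offset.astype(layer.kernel.value.dtype)
  if "diff_b" in entry and layer.bias is not None and layer.bias.value is not None:
    bias_offset = jnp.asarray(entry["diff_b"])
    layer.bias.value = layer.bias.value + scale * bias_offset.astype(layer.bias.value.dtype)
```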
4. Scanned vs. Unscanned Layers

MaxDiffusion supports enabling `jax.scan` for transformer layers via the `scan_layers: True` configuration flag. This improves training memory efficiency by stacking weights of repeated layers (e.g., Attention, FFN) along a new leading dimension. Since users may run inference with or without this flag enabled, this LoRA implementation is designed to transparently support both modes.

The loader distinguishes between:

- Unscanned layers: The `merge_lora()` function is used, which iterates through each layer and merges weights individually via efficient, on-device JIT calls (`_compute_and_add_single_jit`).
- Scanned layers: The `merge_lora_for_scanned()` function is used. It detects which parameters are stacked (e.g., `kernel.ndim > 2`) and which are not (see the sketch below).
  - Stacked parameters are merged in a single batched call to `_compute_and_add_scanned_jit`. This updates all layers in the stack at once on-device, which is significantly more efficient than merging layer-by-layer.
  - Non-stacked parameters (`embeddings`, `proj_out`): It merges them individually using the single-layer JIT logic.

This dual approach ensures correct weight injection whether or not layers are scanned, while maximizing performance in scanned mode through batching.
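A minimal sketch of the stacked-versus-single distinction, assuming LoRA factors stacked along the same leading layer axis as the scanned kernels; `compute_and_add_scanned` and `merge_param` are illustrative names, not the PR's exact helpers:

```python
# Minimal sketch: scanned layers stack kernels as (num_layers, in, out), so a
# single jitted call can merge every layer at once; 2-D kernels fall back to
# the single-layer formula.
import jax
import jax.numpy as jnp


@jax.jit
def compute_and_add_scanned(stacked_kernel, lora_a, lora_b, scale):
  # stacked_kernel: (num_layers, in, out)
  # lora_a: (num_layers, rank, in); lora_b: (num_layers, out, rank)
  delta = jnp.einsum("lor,lri->lio", lora_b, lora_a)  # per-layer (B @ A).T
  return stacked_kernel + scale * delta.astype(stacked_kernel.dtype)


def merge_param(kernel, lora_a, lora_b, scale=1.0):
  # Stacked (scanned) parameters carry an extra leading layer axis.
  if kernel.ndim > 2:
    return compute_and_add_scanned(kernel, lora_a, lora_b, scale)
  return kernel + scale * (lora_b @ lora_a).T.astype(kernel.dtype)
```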
Files Added / Modified

- `src/maxdiffusion/models/lora_nnx.py`: [NEW] Core logic. Contains the JIT merge functions, `parse_lora_dict`, and the graph traversal logic (`merge_lora`, `merge_lora_for_scanned`) to inject weights into NNX modules.
- `src/maxdiffusion/loaders/wan_lora_nnx_loader.py`: [NEW] Orchestrates the loading process. Handles the download of safetensors, conversion of keys, and delegation to the merge functions.
- `src/maxdiffusion/generate_wan.py`: Updated the generation pipeline to identify if `lora` is enabled and trigger the loading sequence before inference.
- `src/maxdiffusion/lora_conversion_utils.py`: Updated `translate_wan_nnx_path_to_diffusers_lora` to accurately map NNX paths (including embeddings and time projections) to Diffusers-style keys.
- `base_wan_lora_14b.yml` & `base_wan_lora_27b.yml`: Added a `lora_config` section to specify LoRA checkpoints and parameters during inference.

Testing