10 changes: 6 additions & 4 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -317,12 +317,14 @@ lightning_repo: ""
lightning_ckpt: ""

# LoRA parameters
enable_lora: False
# Values are lists to support multiple LoRA loading during inference in the future.
lora_config: {
lora_model_name_or_path: [],
weight_name: [],
adapter_name: [],
scale: [],
rank: [64],
lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"],
weight_name: ["wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors"],
adapter_name: ["wan21-distill-lora"],
scale: [1.0],
from_pt: []
}
# Ex with values:
11 changes: 7 additions & 4 deletions src/maxdiffusion/configs/base_wan_27b.yml
@@ -316,12 +316,15 @@ lightning_repo: ""
lightning_ckpt: ""

# LoRA parameters
enable_lora: False
# Values are lists to support multiple LoRA loading during inference in the future.
lora_config: {
lora_model_name_or_path: [],
weight_name: [],
adapter_name: [],
scale: [],
rank: [64],
lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"],
high_noise_weight_name: ["wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors"],
low_noise_weight_name: ["wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors"],
adapter_name: ["wan22-distill-lora"],
scale: [1.0],
from_pt: []
}
# Ex with values:
10 changes: 6 additions & 4 deletions src/maxdiffusion/configs/base_wan_i2v_14b.yml
@@ -300,12 +300,14 @@ lightning_repo: ""
lightning_ckpt: ""

# LoRA parameters
enable_lora: False
# Values are lists to support multiple LoRA loading during inference in the future.
lora_config: {
lora_model_name_or_path: [],
weight_name: [],
adapter_name: [],
scale: [],
rank: [64],
lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"],
weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors"],
adapter_name: ["wan21-distill-lora-i2v"],
scale: [1.0],
from_pt: []
}
# Ex with values:
11 changes: 7 additions & 4 deletions src/maxdiffusion/configs/base_wan_i2v_27b.yml
@@ -312,12 +312,15 @@ lightning_repo: ""
lightning_ckpt: ""

# LoRA parameters
enable_lora: False
# Values are lists to support multiple LoRA loading during inference in the future.
lora_config: {
lora_model_name_or_path: [],
weight_name: [],
adapter_name: [],
scale: [],
rank: [64],
lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"],
high_noise_weight_name: ["wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors"],
low_noise_weight_name: ["wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors"],
adapter_name: ["wan22-distill-lora"],
scale: [1.0],
from_pt: []
}
# Ex with values:
37 changes: 37 additions & 0 deletions src/maxdiffusion/generate_wan.py
@@ -28,6 +28,7 @@
from google.cloud import storage
import flax
from maxdiffusion.common_types import WAN2_1, WAN2_2
from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader


def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -190,6 +191,42 @@ def run(config, pipeline=None, filename_prefix=""):
  else:
    raise ValueError(f"Unsupported model_name for checkpointer: {model_key}")
  pipeline, _, _ = checkpoint_loader.load_checkpoint()

  # If LoRA is specified, inject layers and load weights.
  if config.enable_lora and hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]:
    if model_key == WAN2_1:
      lora_loader = Wan2_1NnxLoraLoader()
      lora_config = config.lora_config

      if len(lora_config["lora_model_name_or_path"]) > 1:
        max_logging.log("Found multiple LoRAs in config, but only loading the first one.")

      pipeline = lora_loader.load_lora_weights(
          pipeline,
          lora_config["lora_model_name_or_path"][0],
          transformer_weight_name=lora_config["weight_name"][0],
          rank=lora_config["rank"][0],
          scale=lora_config["scale"][0],
          scan_layers=config.scan_layers,
      )

    if model_key == WAN2_2:
      lora_loader = Wan2_2NnxLoraLoader()
      lora_config = config.lora_config

      if len(lora_config["lora_model_name_or_path"]) > 1:
        max_logging.log("Found multiple LoRAs in config, but only loading the first one.")

      pipeline = lora_loader.load_lora_weights(
          pipeline,
          lora_config["lora_model_name_or_path"][0],
          high_noise_weight_name=lora_config["high_noise_weight_name"][0],
          low_noise_weight_name=lora_config["low_noise_weight_name"][0],
          rank=lora_config["rank"][0],
          scale=lora_config["scale"][0],
          scan_layers=config.scan_layers,
      )

  s0 = time.perf_counter()

  # Using global_batch_size_to_train_on so not to create more config variables
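For reference, the branch above reads only the first element of each list in lora_config; extra entries just trigger the log message. Below is a minimal sketch of the Python mapping the WAN 2.2 branch indexes into, using the default values from base_wan_27b.yml shown earlier (illustrative only, not a new API):

# Sketch of the config.lora_config mapping consumed by the WAN 2.2 branch above.
# Only the first element of each list is read by that code; adapter_name and
# from_pt are carried in the config but not used by this branch.
lora_config = {
    "rank": [64],
    "lora_model_name_or_path": ["lightx2v/Wan2.2-Distill-Loras"],
    "high_noise_weight_name": ["wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors"],
    "low_noise_weight_name": ["wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors"],
    "adapter_name": ["wan22-distill-lora"],
    "scale": [1.0],
    "from_pt": [],
}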
1 change: 1 addition & 0 deletions src/maxdiffusion/loaders/__init__.py
@@ -14,3 +14,4 @@

from .lora_pipeline import StableDiffusionLoraLoaderMixin
from .flux_lora_pipeline import FluxLoraLoaderMixin
from .wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader
89 changes: 89 additions & 0 deletions src/maxdiffusion/loaders/lora_conversion_utils.py
@@ -608,3 +608,92 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")

return new_state_dict


def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False):
  """
  Translates WAN NNX path to Diffusers/LoRA keys.
  Verified against wan_utils.py mappings.
  """

  # --- 1. Embeddings (Exact Matches) ---
  if nnx_path_str == 'condition_embedder.text_embedder.linear_1':
    return 'diffusion_model.text_embedding.0'
  if nnx_path_str == 'condition_embedder.text_embedder.linear_2':
    return 'diffusion_model.text_embedding.2'
  if nnx_path_str == 'condition_embedder.time_embedder.linear_1':
    return 'diffusion_model.time_embedding.0'
  if nnx_path_str == 'condition_embedder.time_embedder.linear_2':
    return 'diffusion_model.time_embedding.2'
  if nnx_path_str == 'condition_embedder.image_embedder.norm1.layer_norm':
    return 'diffusion_model.img_emb.proj.0'
  if nnx_path_str == 'condition_embedder.image_embedder.ff.net_0':
    return 'diffusion_model.img_emb.proj.1'
  if nnx_path_str == 'condition_embedder.image_embedder.ff.net_2':
    return 'diffusion_model.img_emb.proj.3'
  if nnx_path_str == 'condition_embedder.image_embedder.norm2.layer_norm':
    return 'diffusion_model.img_emb.proj.4'
  if nnx_path_str == 'patch_embedding':
    return 'diffusion_model.patch_embedding'
  if nnx_path_str == 'proj_out':
    return 'diffusion_model.head.head'
  if nnx_path_str == 'condition_embedder.time_proj':
    return 'diffusion_model.time_projection.1'

  # --- 2. Map NNX Suffixes to LoRA Suffixes ---
  suffix_map = {
      # Self Attention (attn1)
      "attn1.query": "self_attn.q",
      "attn1.key": "self_attn.k",
      "attn1.value": "self_attn.v",
      "attn1.proj_attn": "self_attn.o",

      # Self Attention Norms (QK Norm)
      "attn1.norm_q": "self_attn.norm_q",
      "attn1.norm_k": "self_attn.norm_k",

      # Cross Attention (attn2)
      "attn2.query": "cross_attn.q",
      "attn2.key": "cross_attn.k",
      "attn2.value": "cross_attn.v",
      "attn2.proj_attn": "cross_attn.o",

      # Cross Attention Norms (QK Norm)
      "attn2.norm_q": "cross_attn.norm_q",
      "attn2.norm_k": "cross_attn.norm_k",

      # Cross Attention img
      "attn2.add_k_proj": "cross_attn.k_img",
      "attn2.add_v_proj": "cross_attn.v_img",
      "attn2.norm_added_k": "cross_attn.norm_k_img",

      # Feed Forward (ffn)
      "ffn.act_fn.proj": "ffn.0",  # Up proj
      "ffn.proj_out": "ffn.2",  # Down proj

      # Global Norms & Modulation
      "norm2.layer_norm": "norm3",
      "scale_shift_table": "modulation",
      "proj_out": "head.head",
  }

  # --- 3. Translation Logic ---
  if scan_layers:
    # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q"
    if nnx_path_str.startswith("blocks."):
      inner_suffix = nnx_path_str[len("blocks."):]
      if inner_suffix in suffix_map:
        return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}"
  else:
    # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q"
    m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str)
    if m:
      idx, inner_suffix = m.group(1), m.group(2)
      if inner_suffix in suffix_map:
        return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}"

  return None
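A small usage sketch of the translation helper above; the example module paths are hypothetical inputs built from the suffix map, not taken from a real checkpoint:

# Usage sketch of translate_wan_nnx_path_to_diffusers_lora (hypothetical inputs).
from maxdiffusion.loaders import lora_conversion_utils

# Unscanned layers: a numbered block path keeps its index.
assert lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(
    "blocks.0.attn1.query", scan_layers=False
) == "diffusion_model.blocks.0.self_attn.q"

# Scanned layers: the block index becomes a "{}" placeholder filled in per layer.
assert lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(
    "blocks.ffn.act_fn.proj", scan_layers=True
) == "diffusion_model.blocks.{}.ffn.0"

# Unknown paths return None so callers can skip them.
assert lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora("vae.decoder") is None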

109 changes: 109 additions & 0 deletions src/maxdiffusion/loaders/wan_lora_nnx_loader.py
@@ -0,0 +1,109 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NNX-based LoRA loader for WAN models."""

from flax import nnx
from .lora_base import LoRABaseMixin
from .lora_pipeline import StableDiffusionLoraLoaderMixin
from ..models import lora_nnx
from .. import max_logging
from . import lora_conversion_utils

class Wan2_1NnxLoraLoader(LoRABaseMixin):
  """
  Handles loading LoRA weights into an NNX-based WAN 2.1 model.
  Assumes the WAN pipeline contains a 'transformer'
  attribute that is an NNX Module.
  """

  def load_lora_weights(
      self,
      pipeline: nnx.Module,
      lora_model_path: str,
      transformer_weight_name: str,
      rank: int,
      scale: float = 1.0,
      scan_layers: bool = False,
      **kwargs,
  ):
    """
    Merges LoRA weights into the pipeline from a checkpoint.
    """
    lora_loader = StableDiffusionLoraLoaderMixin()

    merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora

    def translate_fn(nnx_path_str):
      return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)

    # Handle the single transformer (WAN 2.1 has no high/low noise split).
    if hasattr(pipeline, "transformer") and transformer_weight_name:
      max_logging.log(f"Merging LoRA into transformer with rank={rank}")
      h_state_dict, _ = lora_loader.lora_state_dict(
          lora_model_path, weight_name=transformer_weight_name, **kwargs
      )
      merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn)
    else:
      max_logging.log("transformer not found or no weight name provided for LoRA.")

    return pipeline

class Wan2_2NnxLoraLoader(LoRABaseMixin):
  """
  Handles loading LoRA weights into an NNX-based WAN 2.2 model.
  Assumes the WAN pipeline contains 'high_noise_transformer' and 'low_noise_transformer'
  attributes that are NNX Modules.
  """

  def load_lora_weights(
      self,
      pipeline: nnx.Module,
      lora_model_path: str,
      high_noise_weight_name: str,
      low_noise_weight_name: str,
      rank: int,
      scale: float = 1.0,
      scan_layers: bool = False,
      **kwargs,
  ):
    """
    Merges LoRA weights into the pipeline from a checkpoint.
    """
    lora_loader = StableDiffusionLoraLoaderMixin()

    merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora

    def translate_fn(nnx_path_str: str):
      return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)

    # Handle high noise model
    if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name:
      max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}")
      h_state_dict, _ = lora_loader.lora_state_dict(
          lora_model_path, weight_name=high_noise_weight_name, **kwargs
      )
      merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn)
    else:
      max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.")

    # Handle low noise model
    if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name:
      max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}")
      l_state_dict, _ = lora_loader.lora_state_dict(
          lora_model_path, weight_name=low_noise_weight_name, **kwargs
      )
      merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn)
    else:
      max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.")

    return pipeline
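A hedged usage sketch of the WAN 2.2 loader, mirroring the call site added to generate_wan.py above; the pipeline object is assumed to be an already-loaded WAN 2.2 NNX pipeline exposing high_noise_transformer and low_noise_transformer, and the repo and file names are the defaults from base_wan_27b.yml:

# Usage sketch (assumes `pipeline` is a loaded WAN 2.2 NNX pipeline).
from maxdiffusion.loaders import Wan2_2NnxLoraLoader

lora_loader = Wan2_2NnxLoraLoader()
pipeline = lora_loader.load_lora_weights(
    pipeline,
    "lightx2v/Wan2.2-Distill-Loras",
    high_noise_weight_name="wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors",
    low_noise_weight_name="wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors",
    rank=64,
    scale=1.0,
    scan_layers=True,  # must match how the transformer blocks were constructed
)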