From e1b7221f752a67b8c3b071a75a92f815168c27ee Mon Sep 17 00:00:00 2001 From: Rishabh Manoj Date: Thu, 15 Jan 2026 18:55:28 +0000 Subject: [PATCH 1/3] Add LoRA support for WAN models --- src/maxdiffusion/configs/base_wan_14b.yml | 10 +- src/maxdiffusion/configs/base_wan_27b.yml | 11 +- src/maxdiffusion/configs/base_wan_i2v_14b.yml | 10 +- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 11 +- src/maxdiffusion/generate_wan.py | 37 ++ src/maxdiffusion/loaders/__init__.py | 1 + .../loaders/lora_conversion_utils.py | 89 +++ .../loaders/wan_lora_nnx_loader.py | 109 ++++ src/maxdiffusion/models/lora_nnx.py | 517 ++++++++++++++++++ 9 files changed, 779 insertions(+), 16 deletions(-) create mode 100644 src/maxdiffusion/loaders/wan_lora_nnx_loader.py create mode 100644 src/maxdiffusion/models/lora_nnx.py diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 390ea3c6..bb9a8a3b 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -317,12 +317,14 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64], + lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"], + weight_name: ["wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors"], + adapter_name: ["wan21-distill-lora"], + scale: [1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index f2839cba..e3f8e5a9 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -316,12 +316,15 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64], + lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"], + high_noise_weight_name: ["wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors"], + low_noise_weight_name: ["wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors"], + adapter_name: ["wan22-distill-lora"], + scale: [1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index 07a84419..09c05cfd 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -300,12 +300,14 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. 
lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64], + lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"], + weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors"], + adapter_name: ["wan21-distill-lora-i2v"], + scale: [1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index bcc69e66..afe8bc54 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -312,12 +312,15 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64], + lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"], + high_noise_weight_name: ["wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors"], + low_noise_weight_name: ["wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors"], + adapter_name: ["wan22-distill-lora"], + scale: [1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index b739a6f3..fef2d4fa 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -28,6 +28,7 @@ from google.cloud import storage import flax from maxdiffusion.common_types import WAN2_1, WAN2_2 +from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader def upload_video_to_gcs(output_dir: str, video_path: str): @@ -190,6 +191,42 @@ def run(config, pipeline=None, filename_prefix=""): else: raise ValueError(f"Unsupported model_name for checkpointer: {model_key}") pipeline, _, _ = checkpoint_loader.load_checkpoint() + + # If LoRA is specified, inject layers and load weights. 
+ if config.enable_lora and hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]: + if model_key == WAN2_1: + lora_loader = Wan2_1NnxLoraLoader() + lora_config = config.lora_config + + if len(lora_config["lora_model_name_or_path"]) > 1: + max_logging.log("Found multiple LoRAs in config, but only loading the first one.") + + pipeline = lora_loader.load_lora_weights( + pipeline, + lora_config["lora_model_name_or_path"][0], + transformer_weight_name=lora_config["weight_name"][0], + rank=lora_config["rank"][0], + scale=lora_config["scale"][0], + scan_layers=config.scan_layers, + ) + + if model_key == WAN2_2: + lora_loader = Wan2_2NnxLoraLoader() + lora_config = config.lora_config + + if len(lora_config["lora_model_name_or_path"]) > 1: + max_logging.log("Found multiple LoRAs in config, but only loading the first one.") + + pipeline = lora_loader.load_lora_weights( + pipeline, + lora_config["lora_model_name_or_path"][0], + high_noise_weight_name=lora_config["high_noise_weight_name"][0], + low_noise_weight_name=lora_config["low_noise_weight_name"][0], + rank=lora_config["rank"][0], + scale=lora_config["scale"][0], + scan_layers=config.scan_layers, + ) + s0 = time.perf_counter() # Using global_batch_size_to_train_on so not to create more config variables diff --git a/src/maxdiffusion/loaders/__init__.py b/src/maxdiffusion/loaders/__init__.py index 2c9e973d..e7abb88a 100644 --- a/src/maxdiffusion/loaders/__init__.py +++ b/src/maxdiffusion/loaders/__init__.py @@ -14,3 +14,4 @@ from .lora_pipeline import StableDiffusionLoraLoaderMixin from .flux_lora_pipeline import FluxLoraLoaderMixin +from .wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py index 534a440d..73227438 100644 --- a/src/maxdiffusion/loaders/lora_conversion_utils.py +++ b/src/maxdiffusion/loaders/lora_conversion_utils.py @@ -608,3 +608,92 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") return new_state_dict + + +def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False): + """ + Translates WAN NNX path to Diffusers/LoRA keys. + Verified against wan_utils.py mappings. + """ + + # --- 1. 
Embeddings (Exact Matches) --- + if nnx_path_str == 'condition_embedder.text_embedder.linear_1': + return 'diffusion_model.text_embedding.0' + if nnx_path_str == 'condition_embedder.text_embedder.linear_2': + return 'diffusion_model.text_embedding.2' + if nnx_path_str == 'condition_embedder.time_embedder.linear_1': + return 'diffusion_model.time_embedding.0' + if nnx_path_str == 'condition_embedder.time_embedder.linear_2': + return 'diffusion_model.time_embedding.2' + if nnx_path_str == 'condition_embedder.image_embedder.norm1.layer_norm': + return 'diffusion_model.img_emb.proj.0' + if nnx_path_str == 'condition_embedder.image_embedder.ff.net_0': + return 'diffusion_model.img_emb.proj.1' + if nnx_path_str == 'condition_embedder.image_embedder.ff.net_2': + return 'diffusion_model.img_emb.proj.3' + if nnx_path_str == 'condition_embedder.image_embedder.norm2.layer_norm': + return 'diffusion_model.img_emb.proj.4' + if nnx_path_str == 'patch_embedding': + return 'diffusion_model.patch_embedding' + if nnx_path_str == 'proj_out': + return 'diffusion_model.head.head' + if nnx_path_str == 'condition_embedder.time_proj': + return 'diffusion_model.time_projection.1' + + + + + # --- 2. Map NNX Suffixes to LoRA Suffixes --- + suffix_map = { + # Self Attention (attn1) + "attn1.query": "self_attn.q", + "attn1.key": "self_attn.k", + "attn1.value": "self_attn.v", + "attn1.proj_attn": "self_attn.o", + + # Self Attention Norms (QK Norm) + "attn1.norm_q": "self_attn.norm_q", + "attn1.norm_k": "self_attn.norm_k", + + # Cross Attention (attn2) + "attn2.query": "cross_attn.q", + "attn2.key": "cross_attn.k", + "attn2.value": "cross_attn.v", + "attn2.proj_attn": "cross_attn.o", + + # Cross Attention Norms (QK Norm) + "attn2.norm_q": "cross_attn.norm_q", + "attn2.norm_k": "cross_attn.norm_k", + + # Cross Attention img + "attn2.add_k_proj": "cross_attn.k_img", + "attn2.add_v_proj": "cross_attn.v_img", + "attn2.norm_added_k": "cross_attn.norm_k_img", + + # Feed Forward (ffn) + "ffn.act_fn.proj": "ffn.0", # Up proj + "ffn.proj_out": "ffn.2", # Down proj + + # Global Norms & Modulation + "norm2.layer_norm": "norm3", + "scale_shift_table": "modulation", + "proj_out": "head.head" + } + + # --- 3. Translation Logic --- + if scan_layers: + # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q" + if nnx_path_str.startswith("blocks."): + inner_suffix = nnx_path_str[len("blocks."):] + if inner_suffix in suffix_map: + return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}" + else: + # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q" + m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str) + if m: + idx, inner_suffix = m.group(1), m.group(2) + if inner_suffix in suffix_map: + return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}" + + return None + diff --git a/src/maxdiffusion/loaders/wan_lora_nnx_loader.py b/src/maxdiffusion/loaders/wan_lora_nnx_loader.py new file mode 100644 index 00000000..95b33237 --- /dev/null +++ b/src/maxdiffusion/loaders/wan_lora_nnx_loader.py @@ -0,0 +1,109 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""NNX-based LoRA loader for WAN models."""
+
+from flax import nnx
+from .lora_base import LoRABaseMixin
+from .lora_pipeline import StableDiffusionLoraLoaderMixin
+from ..models import lora_nnx
+from .. import max_logging
+from . import lora_conversion_utils
+
+class Wan2_1NnxLoraLoader(LoRABaseMixin):
+  """
+  Handles loading LoRA weights into NNX-based WAN 2.1 model.
+  Assumes the WAN pipeline contains a 'transformer'
+  attribute that is an NNX Module.
+  """
+
+  def load_lora_weights(
+      self,
+      pipeline: nnx.Module,
+      lora_model_path: str,
+      transformer_weight_name: str,
+      rank: int,
+      scale: float = 1.0,
+      scan_layers: bool = False,
+      **kwargs,
+  ):
+    """
+    Merges LoRA weights into the pipeline from a checkpoint.
+    """
+    lora_loader = StableDiffusionLoraLoaderMixin()
+
+    merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora
+    def translate_fn(nnx_path_str):
+      return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers)
+
+    # Handle the transformer
+    if hasattr(pipeline, "transformer") and transformer_weight_name:
+        max_logging.log(f"Merging LoRA into transformer with rank={rank}")
+        h_state_dict, _ = lora_loader.lora_state_dict(
+            lora_model_path, weight_name=transformer_weight_name, **kwargs
+        )
+        merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn)
+    else:
+        max_logging.log("transformer not found or no weight name provided for LoRA.")
+
+    return pipeline
+
+class Wan2_2NnxLoraLoader(LoRABaseMixin):
+  """
+  Handles loading LoRA weights into NNX-based WAN 2.2 model.
+  Assumes the WAN pipeline contains 'high_noise_transformer' and 'low_noise_transformer'
+  attributes that are NNX Modules.
+  """
+
+  def load_lora_weights(
+      self,
+      pipeline: nnx.Module,
+      lora_model_path: str,
+      high_noise_weight_name: str,
+      low_noise_weight_name: str,
+      rank: int,
+      scale: float = 1.0,
+      scan_layers: bool = False,
+      **kwargs,
+  ):
+    """
+    Merges LoRA weights into the pipeline from a checkpoint.
+ """ + lora_loader = StableDiffusionLoraLoaderMixin() + + merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora + def translate_fn(nnx_path_str: str): + return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) + + # Handle high noise model + if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name: + max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}") + h_state_dict, _ = lora_loader.lora_state_dict( + lora_model_path, weight_name=high_noise_weight_name, **kwargs + ) + merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn) + else: + max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.") + + # Handle low noise model + if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name: + max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}") + l_state_dict, _ = lora_loader.lora_state_dict( + lora_model_path, weight_name=low_noise_weight_name, **kwargs + ) + merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn) + else: + max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.") + + return pipeline diff --git a/src/maxdiffusion/models/lora_nnx.py b/src/maxdiffusion/models/lora_nnx.py new file mode 100644 index 00000000..d2dbb5f6 --- /dev/null +++ b/src/maxdiffusion/models/lora_nnx.py @@ -0,0 +1,517 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import re +import torch +import jax +from jax import dlpack +import jax.numpy as jnp +from flax import nnx +from .. import max_logging +import numpy as np + +# ----------------------------------------------------------------------------- +# JIT Helpers (The Fix for Sharding & Device-Side Computation) +# ----------------------------------------------------------------------------- + +@jax.jit +def _compute_and_add_single_jit(kernel, bias, down, up, scale, w_diff, b_diff): + """ + Applies LoRA + Weight Diff + Bias Diff on device. + """ + # 1. Apply LoRA (if valid) + if down is not None and up is not None: + # down: (Rank, In), up: (Out, Rank) -> Result: (In, Out) + # Note: We reshape to kernel shape to handle 1x1 convs + delta = (down.T @ up.T).reshape(kernel.shape) + kernel = kernel + (delta * scale).astype(kernel.dtype) + + # 2. Apply Full Weight Diff (if valid) + if w_diff is not None: + kernel = kernel + w_diff.astype(kernel.dtype) + + # 3. Apply Bias Diff (if valid and bias exists) + if bias is not None and b_diff is not None: + bias = bias + b_diff.astype(bias.dtype) + + return kernel, bias + +@jax.jit +def _compute_and_add_scanned_jit(kernel, downs, ups, alphas, global_scale, w_diffs=None, b_diffs=None, bias=None): + """ + Applies scanned LoRA + Diffs. + """ + # 1. 
Apply LoRA + if downs is not None and ups is not None: + rank = downs.shape[1] + scales = (global_scale * alphas / rank) + # Batch Matmul: (L, In, Out) + delta = jnp.matmul(jnp.swapaxes(downs, 1, 2), jnp.swapaxes(ups, 1, 2)) + delta = (delta * scales).astype(kernel.dtype) + kernel = kernel + delta.reshape(kernel.shape) + + # 2. Apply Scanned Weight Diffs (L, ...) + if w_diffs is not None: + kernel = kernel + w_diffs.astype(kernel.dtype) + + # 3. Apply Scanned Bias Diffs (L, ...) + # Note: Scanned bias is usually shape (L, Out) + if bias is not None and b_diffs is not None: + bias = bias + b_diffs.astype(bias.dtype) + + return kernel, bias + +# ----------------------------------------------------------------------------- + +def _to_jax_array(v): + if isinstance(v, torch.Tensor): + return dlpack.from_dlpack(v) + return jnp.array(v) + +def parse_lora_dict(state_dict): + """Helper to parse state_dict into structured params including diffs.""" + lora_params = {} + for k, v in state_dict.items(): + # Alpha + if k.endswith(".alpha"): + key_base = k[:-len(".alpha")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["alpha"] = _to_jax_array(v) + continue + + # Bias Diff (e.g., "layer.diff_b") + if k.endswith(".diff_b"): + key_base = k[:-len(".diff_b")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["diff_b"] = _to_jax_array(v) + continue + + # Weight Diff (e.g., "layer.diff") + if k.endswith(".diff"): + key_base = k[:-len(".diff")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["diff"] = _to_jax_array(v) + continue + + # Standard LoRA + m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k) + if not m: + m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k) + if not m: + m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k) + + if m: + key_base, weight_type = m.group(1), m.group(2).replace("lora_", "") + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base][weight_type] = _to_jax_array(v) + else: + # Fallback for exact matches of diffs if regex failed above + pass + + return lora_params + +def merge_lora(model: nnx.Module, state_dict: dict, rank: int, scale: float, translate_fn=None): + """ + Merges weights for non-scanned layers (Embeddings, singular Dense, etc). + Now supports diff and diff_b. 
+ """ + lora_params = parse_lora_dict(state_dict) + max_logging.log(f"Parsed {len(lora_params)} unique module keys.") + matched_keys = set() + + assigned_count = 0 + for path, module in nnx.iter_graph(model): + if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): + continue + + nnx_path_str = ".".join(map(str, path)) + lora_key = translate_fn(nnx_path_str) if translate_fn else None + + if lora_key and lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + + is_conv_kxk_locon = False + if isinstance(module, nnx.Conv) and module.kernel_size != (1,1) and "down" in weights and "up" in weights: + is_conv_kxk_locon = True + + # Handle Embeddings + if isinstance(module, nnx.Embed): + if "diff" in weights and hasattr(module, 'embedding'): + module.embedding.value += np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype) + assigned_count += 1 + continue + # Handle Norms + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + scale_diff = weights.get("diff", None) + bias_diff = weights.get("diff_b", None) + updated = False + if scale_diff is not None and hasattr(module, 'scale') and module.scale is not None: + module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype) + updated = True + if bias_diff is not None and isinstance(module, nnx.LayerNorm) and hasattr(module, 'bias') and module.bias is not None: + module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype) + updated = True + if updated: + assigned_count += 1 + continue + + # Prepare LoRA terms + down_w, up_w, current_scale = None, None, None + if "down" in weights and "up" in weights and not is_conv_kxk_locon: + down_w, up_w = weights["down"], weights["up"] + down_w, up_w = np.array(down_w), np.array(up_w) # CPU convert + + # Squeeze dimensions if needed (Conv 1x1 or Linear) + if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1): + down_w, up_w = np.squeeze(down_w), np.squeeze(up_w) + + rank = down_w.shape[0] if down_w.ndim > 0 else 0 + alpha = float(weights.get("alpha", rank)) + current_scale = scale * alpha / rank + + # Prepare Diff terms + w_diff = weights.get("diff", None) + b_diff = weights.get("diff_b", None) + + if w_diff is not None: + w_diff = np.array(w_diff) + # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. 
+ if isinstance(module, nnx.Conv): + if w_diff.ndim == 5: + w_diff = w_diff.transpose((2,3,4,1,0)) + elif w_diff.ndim == 4: + w_diff = w_diff.transpose((2,3,1,0)) + elif isinstance(module, nnx.Linear) and w_diff.ndim == 2: + w_diff = w_diff.transpose((1,0)) + if b_diff is not None: + b_diff = np.array(b_diff) + + # If LoCON, compute delta and add to w_diff + if is_conv_kxk_locon: + dw, uw = np.array(weights['down']), np.array(weights['up']) + rank, in_c, *k_dims = dw.shape + out_c = uw.shape[0] + alpha = float(weights.get("alpha", rank)) + + delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + + # Transpose to flax + if delta_pt.ndim == 5: + delta_fx = delta_pt.transpose((2,3,4,1,0)) + else: + delta_fx = delta_pt.transpose((2,3,1,0)) + + lora_delta = delta_fx * (scale * alpha / rank) + if w_diff is None: + w_diff = lora_delta.astype(np.float32) + else: + w_diff += lora_delta.astype(w_diff.dtype) + + # Check for Bias existence + bias_val = module.bias.value if module.bias is not None else None + + # --- EXECUTE JIT UPDATE --- + if down_w is not None or w_diff is not None or b_diff is not None: + new_kernel, new_bias = _compute_and_add_single_jit( + module.kernel.value, + bias_val, + down_w, up_w, current_scale, + w_diff, b_diff + ) + + module.kernel.value = new_kernel + if new_bias is not None: + module.bias.value = new_bias + + assigned_count +=1 + else: + max_logging.log(f"Matched key {lora_key} but found no actionable weights.") + + max_logging.log(f"Merged weights into {assigned_count} layers.") + unmatched_keys = set(lora_params.keys()) - matched_keys + if unmatched_keys: + max_logging.log(f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}") + + +def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, rank: int, scale: float, translate_fn=None): + """ + Device-Side Optimized Merge for Scanned Layers. + Now supports diff and diff_b. 
+ """ + lora_params = parse_lora_dict(state_dict) + max_logging.log(f"Parsed {len(lora_params)} keys for scanned merge.") + matched_keys = set() + + assigned_count = 0 + for path, module in nnx.iter_graph(model): + if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): + continue + + nnx_path_str = ".".join(map(str, path)) + lora_key_template = translate_fn(nnx_path_str) if translate_fn else None + + if not lora_key_template: + continue + + # Determine if layer is scanned based on parameter dimensions + is_scanned = False + if isinstance(module, nnx.Embed) and hasattr(module, 'embedding'): + is_scanned = module.embedding.ndim > 2 + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)) and hasattr(module, 'scale') and module.scale is not None: + is_scanned = module.scale.ndim > 1 + elif isinstance(module, nnx.Linear): + is_scanned = module.kernel.ndim == 3 + elif isinstance(module, nnx.Conv): + is_scanned = module.kernel.ndim == 5 + + # If layer is not scanned, merge it using single-layer logic + if not is_scanned: + lora_key = lora_key_template + if lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + is_conv_kxk_locon = isinstance(module, nnx.Conv) and module.kernel_size != (1,1) and "down" in weights and "up" in weights + + if isinstance(module, nnx.Embed): + if "diff" in weights and hasattr(module, 'embedding'): + module.embedding.value += np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype) + assigned_count += 1 + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + scale_diff = weights.get("diff", None) + bias_diff = weights.get("diff_b", None) + updated = False + if scale_diff is not None and hasattr(module, 'scale') and module.scale is not None: + module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype) + updated = True + if bias_diff is not None and isinstance(module, nnx.LayerNorm) and hasattr(module, 'bias') and module.bias is not None: + module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype) + updated = True + if updated: + assigned_count += 1 + elif isinstance(module, (nnx.Linear, nnx.Conv)): + down_w, up_w, current_scale_ = None, None, None + if "down" in weights and "up" in weights and not is_conv_kxk_locon: + down_w, up_w = np.array(weights["down"]), np.array(weights["up"]) + if isinstance(module, nnx.Conv): + down_w, up_w = np.squeeze(down_w), np.squeeze(up_w) + rank, alpha = down_w.shape[0], float(weights.get("alpha", down_w.shape[0])) + current_scale_ = scale * alpha / rank + + w_diff, b_diff = weights.get("diff", None), weights.get("diff_b", None) + if w_diff is not None: + w_diff = np.array(w_diff) + if isinstance(module, nnx.Conv): + if w_diff.ndim == 5: + w_diff = w_diff.transpose((2,3,4,1,0)) + elif w_diff.ndim == 4: + w_diff = w_diff.transpose((2,3,1,0)) + elif isinstance(module, nnx.Linear) and w_diff.ndim == 2: + w_diff = w_diff.transpose((1,0)) + if b_diff is not None: + b_diff = np.array(b_diff) + if is_conv_kxk_locon: + dw, uw = np.array(weights['down']), np.array(weights['up']) + rank, in_c, *k_dims = dw.shape + out_c = uw.shape[0] + alpha = float(weights.get("alpha", rank)) + delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + if delta_pt.ndim == 5: + delta_fx = delta_pt.transpose((2,3,4,1,0)) + else: + delta_fx = delta_pt.transpose((2,3,1,0)) + lora_delta = delta_fx * (scale * alpha / rank) + if w_diff is None: + w_diff = 
lora_delta.astype(np.float32) + else: + w_diff += lora_delta.astype(w_diff.dtype) + + bias_val = module.bias.value if module.bias is not None else None + if down_w is not None or w_diff is not None or b_diff is not None: + k, b = _compute_and_add_single_jit(module.kernel.value, bias_val, down_w, up_w, current_scale_, w_diff, b_diff) + module.kernel.value = k + if b is not None: + module.bias.value = b + assigned_count +=1 + continue + + # If we reach here, layer is SCANNED + if isinstance(module, nnx.Embed): + num_layers = module.embedding.shape[0] + embed_diffs_to_add = np.zeros_like(module.embedding.value) + updated = False + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + if "diff" in lora_params[lora_key]: + embed_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(module.embedding.shape[1:]) + updated = True + if updated: + module.embedding.value += embed_diffs_to_add.astype(module.embedding.dtype) + assigned_count += 1 + continue + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + num_layers = module.scale.shape[0] + scale_diffs_to_add = np.zeros_like(module.scale.value) + bias_diffs_to_add = np.zeros_like(module.bias.value) if isinstance(module, nnx.LayerNorm) and hasattr(module, 'bias') and module.bias is not None else None + updated_scale, updated_bias = False, False + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + if "diff" in weights: + scale_diffs_to_add[i] = np.array(weights["diff"]).reshape(module.scale.shape[1:]) + updated_scale = True + if "diff_b" in weights and bias_diffs_to_add is not None: + bias_diffs_to_add[i] = np.array(weights["diff_b"]).reshape(module.bias.shape[1:]) + updated_bias = True + if updated_scale: + module.scale.value += scale_diffs_to_add.astype(module.scale.dtype) + if updated_bias and bias_diffs_to_add is not None: + module.bias.value += bias_diffs_to_add.astype(module.bias.dtype) + if updated_scale or updated_bias: + assigned_count += 1 + continue + elif isinstance(module, (nnx.Linear, nnx.Conv)): + is_linear = isinstance(module, nnx.Linear) + is_conv = isinstance(module, nnx.Conv) + is_conv_kxk = isinstance(module, nnx.Conv) and module.kernel_size != (1,1) + if is_linear: + num_layers, in_feat, out_feat = module.kernel.shape + else: # Conv + num_layers = module.kernel.shape[0] + in_feat, out_feat = module.kernel.shape[3], module.kernel.shape[4] + else: + # Should not happen based on is_scanned logic + continue + + # 1. Scan for Rank (Fallback use rank in config file) + found_rank = rank + for i in range(num_layers): + k = lora_key_template.format(i) + if k in lora_params and "down" in lora_params[k]: + found_rank = lora_params[k]["down"].shape[0] + break + + # 2. 
Pre-allocate Buffers (CPU) + # LoRA Buffers + stack_down = np.zeros((num_layers, found_rank, in_feat), dtype=np.float32) + stack_up = np.zeros((num_layers, out_feat, found_rank), dtype=np.float32) + stack_alpha = np.zeros((num_layers, 1, 1), dtype=np.float32) + + # Diff Buffers + # Initialize as None, allocate only if found to save memory + stack_w_diff = None + stack_b_diff = None + + has_lora = False + has_diff = False + + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + w = lora_params[lora_key] + + # --- Fill LoRA --- + if "down" in w: + d, u = np.array(w["down"]), np.array(w["up"]) + alpha = float(w.get("alpha", d.shape[0])) + rank = d.shape[0] + + if is_conv_kxk: + # For LoCON kxk, compute delta and merge into stack_w_diff + rank, in_c, *k_dims = d.shape + out_c = u.shape[0] + delta_pt = (u.reshape(out_c, rank) @ d.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + if delta_pt.ndim == 5: + delta_fx = delta_pt.transpose((2,3,4,1,0)) + else: + delta_fx = delta_pt.transpose((2,3,1,0)) + + lora_delta = delta_fx * (scale * alpha / rank) + if stack_w_diff is None: + stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) + stack_w_diff[i] += lora_delta.reshape(stack_w_diff[i].shape).astype(stack_w_diff.dtype) + has_diff = True # Mark as having diff because we merged LoRA into w_diff + else: + # For Linear or 1x1 Conv, prepare for JIT + if d.ndim > 2: + d = np.squeeze(d) + if u.ndim > 2: + u = np.squeeze(u) + stack_down[i] = d + stack_up[i] = u + stack_alpha[i] = alpha + has_lora = True + + # --- Fill Weight Diff --- + if "diff" in w: + if stack_w_diff is None: + stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) + wd = np.array(w["diff"]) + # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. 
+ if is_conv: + if wd.ndim == 5: + wd = wd.transpose((2,3,4,1,0)) + elif wd.ndim == 4: + wd = wd.transpose((2,3,1,0)) + elif is_linear and wd.ndim == 2: + wd = wd.transpose((1,0)) + + stack_w_diff[i] += wd.reshape(stack_w_diff[i].shape) + has_diff = True + + # --- Fill Bias Diff --- + if "diff_b" in w: + if stack_b_diff is None: + # Bias shape: Linear (L, Out), Conv (L, Out) usually + stack_b_diff = np.zeros((num_layers, out_feat), dtype=np.float32) + bd = np.array(w["diff_b"]) + stack_b_diff[i] = bd.flatten() + has_diff = True + + if has_lora or has_diff: + bias_val = module.bias.value if module.bias is not None else None + + # Call JIT + new_k, new_b = _compute_and_add_scanned_jit( + module.kernel.value, + stack_down if has_lora else None, + stack_up if has_lora else None, + stack_alpha if has_lora else None, + scale, + stack_w_diff, + stack_b_diff, + bias_val + ) + + module.kernel.value = new_k + if new_b is not None: + module.bias.value = new_b + + assigned_count += 1 + + max_logging.log(f"Merged weights into {assigned_count} scanned layers.") + unmatched_keys = set(lora_params.keys()) - matched_keys + if unmatched_keys: + max_logging.log(f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}") From 5e1d2b32a99952f0190726402080cf01ce63fb47 Mon Sep 17 00:00:00 2001 From: Rishabh Manoj Date: Fri, 23 Jan 2026 08:39:23 +0000 Subject: [PATCH 2/3] Formatting through pyink --- .../loaders/lora_conversion_utils.py | 158 ++- .../loaders/wan_lora_nnx_loader.py | 34 +- src/maxdiffusion/models/lora_nnx.py | 977 +++++++++--------- 3 files changed, 591 insertions(+), 578 deletions(-) diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py index 73227438..55e5f8f8 100644 --- a/src/maxdiffusion/loaders/lora_conversion_utils.py +++ b/src/maxdiffusion/loaders/lora_conversion_utils.py @@ -611,89 +611,79 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False): - """ - Translates WAN NNX path to Diffusers/LoRA keys. - Verified against wan_utils.py mappings. - """ - - # --- 1. Embeddings (Exact Matches) --- - if nnx_path_str == 'condition_embedder.text_embedder.linear_1': - return 'diffusion_model.text_embedding.0' - if nnx_path_str == 'condition_embedder.text_embedder.linear_2': - return 'diffusion_model.text_embedding.2' - if nnx_path_str == 'condition_embedder.time_embedder.linear_1': - return 'diffusion_model.time_embedding.0' - if nnx_path_str == 'condition_embedder.time_embedder.linear_2': - return 'diffusion_model.time_embedding.2' - if nnx_path_str == 'condition_embedder.image_embedder.norm1.layer_norm': - return 'diffusion_model.img_emb.proj.0' - if nnx_path_str == 'condition_embedder.image_embedder.ff.net_0': - return 'diffusion_model.img_emb.proj.1' - if nnx_path_str == 'condition_embedder.image_embedder.ff.net_2': - return 'diffusion_model.img_emb.proj.3' - if nnx_path_str == 'condition_embedder.image_embedder.norm2.layer_norm': - return 'diffusion_model.img_emb.proj.4' - if nnx_path_str == 'patch_embedding': - return 'diffusion_model.patch_embedding' - if nnx_path_str == 'proj_out': - return 'diffusion_model.head.head' - if nnx_path_str == 'condition_embedder.time_proj': - return 'diffusion_model.time_projection.1' - - - - - # --- 2. 
Map NNX Suffixes to LoRA Suffixes --- - suffix_map = { - # Self Attention (attn1) - "attn1.query": "self_attn.q", - "attn1.key": "self_attn.k", - "attn1.value": "self_attn.v", - "attn1.proj_attn": "self_attn.o", - - # Self Attention Norms (QK Norm) - "attn1.norm_q": "self_attn.norm_q", - "attn1.norm_k": "self_attn.norm_k", - - # Cross Attention (attn2) - "attn2.query": "cross_attn.q", - "attn2.key": "cross_attn.k", - "attn2.value": "cross_attn.v", - "attn2.proj_attn": "cross_attn.o", - - # Cross Attention Norms (QK Norm) - "attn2.norm_q": "cross_attn.norm_q", - "attn2.norm_k": "cross_attn.norm_k", - - # Cross Attention img - "attn2.add_k_proj": "cross_attn.k_img", - "attn2.add_v_proj": "cross_attn.v_img", - "attn2.norm_added_k": "cross_attn.norm_k_img", - - # Feed Forward (ffn) - "ffn.act_fn.proj": "ffn.0", # Up proj - "ffn.proj_out": "ffn.2", # Down proj - - # Global Norms & Modulation - "norm2.layer_norm": "norm3", - "scale_shift_table": "modulation", - "proj_out": "head.head" - } - - # --- 3. Translation Logic --- - if scan_layers: - # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q" - if nnx_path_str.startswith("blocks."): - inner_suffix = nnx_path_str[len("blocks."):] - if inner_suffix in suffix_map: - return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}" - else: - # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q" - m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str) - if m: - idx, inner_suffix = m.group(1), m.group(2) - if inner_suffix in suffix_map: - return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}" + """ + Translates WAN NNX path to Diffusers/LoRA keys. + Verified against wan_utils.py mappings. + """ - return None + # --- 1. Embeddings (Exact Matches) --- + if nnx_path_str == "condition_embedder.text_embedder.linear_1": + return "diffusion_model.text_embedding.0" + if nnx_path_str == "condition_embedder.text_embedder.linear_2": + return "diffusion_model.text_embedding.2" + if nnx_path_str == "condition_embedder.time_embedder.linear_1": + return "diffusion_model.time_embedding.0" + if nnx_path_str == "condition_embedder.time_embedder.linear_2": + return "diffusion_model.time_embedding.2" + if nnx_path_str == "condition_embedder.image_embedder.norm1.layer_norm": + return "diffusion_model.img_emb.proj.0" + if nnx_path_str == "condition_embedder.image_embedder.ff.net_0": + return "diffusion_model.img_emb.proj.1" + if nnx_path_str == "condition_embedder.image_embedder.ff.net_2": + return "diffusion_model.img_emb.proj.3" + if nnx_path_str == "condition_embedder.image_embedder.norm2.layer_norm": + return "diffusion_model.img_emb.proj.4" + if nnx_path_str == "patch_embedding": + return "diffusion_model.patch_embedding" + if nnx_path_str == "proj_out": + return "diffusion_model.head.head" + if nnx_path_str == "condition_embedder.time_proj": + return "diffusion_model.time_projection.1" + + # --- 2. 
Map NNX Suffixes to LoRA Suffixes --- + suffix_map = { + # Self Attention (attn1) + "attn1.query": "self_attn.q", + "attn1.key": "self_attn.k", + "attn1.value": "self_attn.v", + "attn1.proj_attn": "self_attn.o", + # Self Attention Norms (QK Norm) + "attn1.norm_q": "self_attn.norm_q", + "attn1.norm_k": "self_attn.norm_k", + # Cross Attention (attn2) + "attn2.query": "cross_attn.q", + "attn2.key": "cross_attn.k", + "attn2.value": "cross_attn.v", + "attn2.proj_attn": "cross_attn.o", + # Cross Attention Norms (QK Norm) + "attn2.norm_q": "cross_attn.norm_q", + "attn2.norm_k": "cross_attn.norm_k", + # Cross Attention img + "attn2.add_k_proj": "cross_attn.k_img", + "attn2.add_v_proj": "cross_attn.v_img", + "attn2.norm_added_k": "cross_attn.norm_k_img", + # Feed Forward (ffn) + "ffn.act_fn.proj": "ffn.0", # Up proj + "ffn.proj_out": "ffn.2", # Down proj + # Global Norms & Modulation + "norm2.layer_norm": "norm3", + "scale_shift_table": "modulation", + "proj_out": "head.head", + } + # --- 3. Translation Logic --- + if scan_layers: + # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q" + if nnx_path_str.startswith("blocks."): + inner_suffix = nnx_path_str[len("blocks.") :] + if inner_suffix in suffix_map: + return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}" + else: + # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q" + m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str) + if m: + idx, inner_suffix = m.group(1), m.group(2) + if inner_suffix in suffix_map: + return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}" + + return None diff --git a/src/maxdiffusion/loaders/wan_lora_nnx_loader.py b/src/maxdiffusion/loaders/wan_lora_nnx_loader.py index 95b33237..2fe691c5 100644 --- a/src/maxdiffusion/loaders/wan_lora_nnx_loader.py +++ b/src/maxdiffusion/loaders/wan_lora_nnx_loader.py @@ -21,6 +21,7 @@ from .. import max_logging from . import lora_conversion_utils + class Wan2_1NnxLoraLoader(LoRABaseMixin): """ Handles loading LoRA weights into NNX-based WAN 2.1 model. @@ -44,21 +45,21 @@ def load_lora_weights( lora_loader = StableDiffusionLoraLoaderMixin() merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora + def translate_fn(nnx_path_str): return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) # Handle high noise model if hasattr(pipeline, "transformer") and transformer_weight_name: - max_logging.log(f"Merging LoRA into transformer with rank={rank}") - h_state_dict, _ = lora_loader.lora_state_dict( - lora_model_path, weight_name=transformer_weight_name, **kwargs - ) - merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn) + max_logging.log(f"Merging LoRA into transformer with rank={rank}") + h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs) + merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn) else: - max_logging.log("transformer not found or no weight name provided for LoRA.") + max_logging.log("transformer not found or no weight name provided for LoRA.") return pipeline + class Wan2_2NnxLoraLoader(LoRABaseMixin): """ Handles loading LoRA weights into NNX-based WAN 2.2 model. 
@@ -83,27 +84,24 @@ def load_lora_weights( lora_loader = StableDiffusionLoraLoaderMixin() merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora + def translate_fn(nnx_path_str: str): return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) # Handle high noise model if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name: - max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}") - h_state_dict, _ = lora_loader.lora_state_dict( - lora_model_path, weight_name=high_noise_weight_name, **kwargs - ) - merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn) + max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}") + h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=high_noise_weight_name, **kwargs) + merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn) else: - max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.") + max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.") # Handle low noise model if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name: - max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}") - l_state_dict, _ = lora_loader.lora_state_dict( - lora_model_path, weight_name=low_noise_weight_name, **kwargs - ) - merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn) + max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}") + l_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=low_noise_weight_name, **kwargs) + merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn) else: - max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.") + max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.") return pipeline diff --git a/src/maxdiffusion/models/lora_nnx.py b/src/maxdiffusion/models/lora_nnx.py index d2dbb5f6..58ec94b6 100644 --- a/src/maxdiffusion/models/lora_nnx.py +++ b/src/maxdiffusion/models/lora_nnx.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
""" import re @@ -27,491 +27,516 @@ # JIT Helpers (The Fix for Sharding & Device-Side Computation) # ----------------------------------------------------------------------------- + @jax.jit def _compute_and_add_single_jit(kernel, bias, down, up, scale, w_diff, b_diff): - """ - Applies LoRA + Weight Diff + Bias Diff on device. - """ - # 1. Apply LoRA (if valid) - if down is not None and up is not None: - # down: (Rank, In), up: (Out, Rank) -> Result: (In, Out) - # Note: We reshape to kernel shape to handle 1x1 convs - delta = (down.T @ up.T).reshape(kernel.shape) - kernel = kernel + (delta * scale).astype(kernel.dtype) - - # 2. Apply Full Weight Diff (if valid) - if w_diff is not None: - kernel = kernel + w_diff.astype(kernel.dtype) - - # 3. Apply Bias Diff (if valid and bias exists) - if bias is not None and b_diff is not None: - bias = bias + b_diff.astype(bias.dtype) - - return kernel, bias + """ + Applies LoRA + Weight Diff + Bias Diff on device. + """ + # 1. Apply LoRA (if valid) + if down is not None and up is not None: + # down: (Rank, In), up: (Out, Rank) -> Result: (In, Out) + # Note: We reshape to kernel shape to handle 1x1 convs + delta = (down.T @ up.T).reshape(kernel.shape) + kernel = kernel + (delta * scale).astype(kernel.dtype) + + # 2. Apply Full Weight Diff (if valid) + if w_diff is not None: + kernel = kernel + w_diff.astype(kernel.dtype) + + # 3. Apply Bias Diff (if valid and bias exists) + if bias is not None and b_diff is not None: + bias = bias + b_diff.astype(bias.dtype) + + return kernel, bias + @jax.jit def _compute_and_add_scanned_jit(kernel, downs, ups, alphas, global_scale, w_diffs=None, b_diffs=None, bias=None): - """ - Applies scanned LoRA + Diffs. - """ - # 1. Apply LoRA - if downs is not None and ups is not None: - rank = downs.shape[1] - scales = (global_scale * alphas / rank) - # Batch Matmul: (L, In, Out) - delta = jnp.matmul(jnp.swapaxes(downs, 1, 2), jnp.swapaxes(ups, 1, 2)) - delta = (delta * scales).astype(kernel.dtype) - kernel = kernel + delta.reshape(kernel.shape) - - # 2. Apply Scanned Weight Diffs (L, ...) - if w_diffs is not None: - kernel = kernel + w_diffs.astype(kernel.dtype) - - # 3. Apply Scanned Bias Diffs (L, ...) - # Note: Scanned bias is usually shape (L, Out) - if bias is not None and b_diffs is not None: - bias = bias + b_diffs.astype(bias.dtype) - - return kernel, bias + """ + Applies scanned LoRA + Diffs. + """ + # 1. Apply LoRA + if downs is not None and ups is not None: + rank = downs.shape[1] + scales = global_scale * alphas / rank + # Batch Matmul: (L, In, Out) + delta = jnp.matmul(jnp.swapaxes(downs, 1, 2), jnp.swapaxes(ups, 1, 2)) + delta = (delta * scales).astype(kernel.dtype) + kernel = kernel + delta.reshape(kernel.shape) + + # 2. Apply Scanned Weight Diffs (L, ...) + if w_diffs is not None: + kernel = kernel + w_diffs.astype(kernel.dtype) + + # 3. Apply Scanned Bias Diffs (L, ...) 
+ # Note: Scanned bias is usually shape (L, Out) + if bias is not None and b_diffs is not None: + bias = bias + b_diffs.astype(bias.dtype) + + return kernel, bias + # ----------------------------------------------------------------------------- + def _to_jax_array(v): - if isinstance(v, torch.Tensor): - return dlpack.from_dlpack(v) - return jnp.array(v) + if isinstance(v, torch.Tensor): + return dlpack.from_dlpack(v) + return jnp.array(v) + def parse_lora_dict(state_dict): - """Helper to parse state_dict into structured params including diffs.""" - lora_params = {} - for k, v in state_dict.items(): - # Alpha - if k.endswith(".alpha"): - key_base = k[:-len(".alpha")] - if key_base not in lora_params: - lora_params[key_base] = {} - lora_params[key_base]["alpha"] = _to_jax_array(v) - continue - - # Bias Diff (e.g., "layer.diff_b") - if k.endswith(".diff_b"): - key_base = k[:-len(".diff_b")] - if key_base not in lora_params: - lora_params[key_base] = {} - lora_params[key_base]["diff_b"] = _to_jax_array(v) - continue - - # Weight Diff (e.g., "layer.diff") - if k.endswith(".diff"): - key_base = k[:-len(".diff")] - if key_base not in lora_params: - lora_params[key_base] = {} - lora_params[key_base]["diff"] = _to_jax_array(v) - continue - - # Standard LoRA - m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k) - if not m: - m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k) - if not m: - m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k) - - if m: - key_base, weight_type = m.group(1), m.group(2).replace("lora_", "") - if key_base not in lora_params: - lora_params[key_base] = {} - lora_params[key_base][weight_type] = _to_jax_array(v) - else: - # Fallback for exact matches of diffs if regex failed above - pass + """Helper to parse state_dict into structured params including diffs.""" + lora_params = {} + for k, v in state_dict.items(): + # Alpha + if k.endswith(".alpha"): + key_base = k[: -len(".alpha")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["alpha"] = _to_jax_array(v) + continue + + # Bias Diff (e.g., "layer.diff_b") + if k.endswith(".diff_b"): + key_base = k[: -len(".diff_b")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["diff_b"] = _to_jax_array(v) + continue + + # Weight Diff (e.g., "layer.diff") + if k.endswith(".diff"): + key_base = k[: -len(".diff")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["diff"] = _to_jax_array(v) + continue + + # Standard LoRA + m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k) + if not m: + m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k) + if not m: + m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k) + + if m: + key_base, weight_type = m.group(1), m.group(2).replace("lora_", "") + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base][weight_type] = _to_jax_array(v) + else: + # Fallback for exact matches of diffs if regex failed above + pass + + return lora_params - return lora_params def merge_lora(model: nnx.Module, state_dict: dict, rank: int, scale: float, translate_fn=None): - """ - Merges weights for non-scanned layers (Embeddings, singular Dense, etc). - Now supports diff and diff_b. 
- """ - lora_params = parse_lora_dict(state_dict) - max_logging.log(f"Parsed {len(lora_params)} unique module keys.") - matched_keys = set() - - assigned_count = 0 - for path, module in nnx.iter_graph(model): - if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): - continue - - nnx_path_str = ".".join(map(str, path)) - lora_key = translate_fn(nnx_path_str) if translate_fn else None - - if lora_key and lora_key in lora_params: - matched_keys.add(lora_key) - weights = lora_params[lora_key] - - is_conv_kxk_locon = False - if isinstance(module, nnx.Conv) and module.kernel_size != (1,1) and "down" in weights and "up" in weights: - is_conv_kxk_locon = True - - # Handle Embeddings - if isinstance(module, nnx.Embed): - if "diff" in weights and hasattr(module, 'embedding'): - module.embedding.value += np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype) - assigned_count += 1 - continue - # Handle Norms - elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): - scale_diff = weights.get("diff", None) - bias_diff = weights.get("diff_b", None) - updated = False - if scale_diff is not None and hasattr(module, 'scale') and module.scale is not None: - module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype) - updated = True - if bias_diff is not None and isinstance(module, nnx.LayerNorm) and hasattr(module, 'bias') and module.bias is not None: - module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype) - updated = True - if updated: - assigned_count += 1 - continue - - # Prepare LoRA terms - down_w, up_w, current_scale = None, None, None - if "down" in weights and "up" in weights and not is_conv_kxk_locon: - down_w, up_w = weights["down"], weights["up"] - down_w, up_w = np.array(down_w), np.array(up_w) # CPU convert - - # Squeeze dimensions if needed (Conv 1x1 or Linear) - if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1): - down_w, up_w = np.squeeze(down_w), np.squeeze(up_w) - - rank = down_w.shape[0] if down_w.ndim > 0 else 0 - alpha = float(weights.get("alpha", rank)) - current_scale = scale * alpha / rank - - # Prepare Diff terms - w_diff = weights.get("diff", None) - b_diff = weights.get("diff_b", None) - - if w_diff is not None: - w_diff = np.array(w_diff) - # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. 
- if isinstance(module, nnx.Conv): - if w_diff.ndim == 5: - w_diff = w_diff.transpose((2,3,4,1,0)) - elif w_diff.ndim == 4: - w_diff = w_diff.transpose((2,3,1,0)) - elif isinstance(module, nnx.Linear) and w_diff.ndim == 2: - w_diff = w_diff.transpose((1,0)) - if b_diff is not None: - b_diff = np.array(b_diff) - - # If LoCON, compute delta and add to w_diff - if is_conv_kxk_locon: - dw, uw = np.array(weights['down']), np.array(weights['up']) - rank, in_c, *k_dims = dw.shape - out_c = uw.shape[0] - alpha = float(weights.get("alpha", rank)) - - delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) - - # Transpose to flax - if delta_pt.ndim == 5: - delta_fx = delta_pt.transpose((2,3,4,1,0)) - else: - delta_fx = delta_pt.transpose((2,3,1,0)) - - lora_delta = delta_fx * (scale * alpha / rank) - if w_diff is None: - w_diff = lora_delta.astype(np.float32) - else: - w_diff += lora_delta.astype(w_diff.dtype) - - # Check for Bias existence - bias_val = module.bias.value if module.bias is not None else None - - # --- EXECUTE JIT UPDATE --- - if down_w is not None or w_diff is not None or b_diff is not None: - new_kernel, new_bias = _compute_and_add_single_jit( - module.kernel.value, - bias_val, - down_w, up_w, current_scale, - w_diff, b_diff - ) - - module.kernel.value = new_kernel - if new_bias is not None: - module.bias.value = new_bias - - assigned_count +=1 - else: - max_logging.log(f"Matched key {lora_key} but found no actionable weights.") + """ + Merges weights for non-scanned layers (Embeddings, singular Dense, etc). + Now supports diff and diff_b. + """ + lora_params = parse_lora_dict(state_dict) + max_logging.log(f"Parsed {len(lora_params)} unique module keys.") + matched_keys = set() + + assigned_count = 0 + for path, module in nnx.iter_graph(model): + if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): + continue + + nnx_path_str = ".".join(map(str, path)) + lora_key = translate_fn(nnx_path_str) if translate_fn else None + + if lora_key and lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + + is_conv_kxk_locon = False + if isinstance(module, nnx.Conv) and module.kernel_size != (1, 1) and "down" in weights and "up" in weights: + is_conv_kxk_locon = True + + # Handle Embeddings + if isinstance(module, nnx.Embed): + if "diff" in weights and hasattr(module, "embedding"): + module.embedding.value += np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype) + assigned_count += 1 + continue + # Handle Norms + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + scale_diff = weights.get("diff", None) + bias_diff = weights.get("diff_b", None) + updated = False + if scale_diff is not None and hasattr(module, "scale") and module.scale is not None: + module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype) + updated = True + if ( + bias_diff is not None + and isinstance(module, nnx.LayerNorm) + and hasattr(module, "bias") + and module.bias is not None + ): + module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype) + updated = True + if updated: + assigned_count += 1 + continue + + # Prepare LoRA terms + down_w, up_w, current_scale = None, None, None + if "down" in weights and "up" in weights and not is_conv_kxk_locon: + down_w, up_w = weights["down"], weights["up"] + down_w, up_w = np.array(down_w), np.array(up_w) # CPU convert + + # Squeeze dimensions if needed (Conv 1x1 or 
Linear) + if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1): + down_w, up_w = np.squeeze(down_w), np.squeeze(up_w) + + rank = down_w.shape[0] if down_w.ndim > 0 else 0 + alpha = float(weights.get("alpha", rank)) + current_scale = scale * alpha / rank + + # Prepare Diff terms + w_diff = weights.get("diff", None) + b_diff = weights.get("diff_b", None) + + if w_diff is not None: + w_diff = np.array(w_diff) + # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. + if isinstance(module, nnx.Conv): + if w_diff.ndim == 5: + w_diff = w_diff.transpose((2, 3, 4, 1, 0)) + elif w_diff.ndim == 4: + w_diff = w_diff.transpose((2, 3, 1, 0)) + elif isinstance(module, nnx.Linear) and w_diff.ndim == 2: + w_diff = w_diff.transpose((1, 0)) + if b_diff is not None: + b_diff = np.array(b_diff) + + # If LoCON, compute delta and add to w_diff + if is_conv_kxk_locon: + dw, uw = np.array(weights["down"]), np.array(weights["up"]) + rank, in_c, *k_dims = dw.shape + out_c = uw.shape[0] + alpha = float(weights.get("alpha", rank)) + + delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + + # Transpose to flax + if delta_pt.ndim == 5: + delta_fx = delta_pt.transpose((2, 3, 4, 1, 0)) + else: + delta_fx = delta_pt.transpose((2, 3, 1, 0)) + + lora_delta = delta_fx * (scale * alpha / rank) + if w_diff is None: + w_diff = lora_delta.astype(np.float32) + else: + w_diff += lora_delta.astype(w_diff.dtype) + + # Check for Bias existence + bias_val = module.bias.value if module.bias is not None else None + + # --- EXECUTE JIT UPDATE --- + if down_w is not None or w_diff is not None or b_diff is not None: + new_kernel, new_bias = _compute_and_add_single_jit( + module.kernel.value, bias_val, down_w, up_w, current_scale, w_diff, b_diff + ) - max_logging.log(f"Merged weights into {assigned_count} layers.") - unmatched_keys = set(lora_params.keys()) - matched_keys - if unmatched_keys: - max_logging.log(f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}") + module.kernel.value = new_kernel + if new_bias is not None: + module.bias.value = new_bias + + assigned_count += 1 + else: + max_logging.log(f"Matched key {lora_key} but found no actionable weights.") + + max_logging.log(f"Merged weights into {assigned_count} layers.") + unmatched_keys = set(lora_params.keys()) - matched_keys + if unmatched_keys: + max_logging.log( + f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}" + ) def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, rank: int, scale: float, translate_fn=None): - """ - Device-Side Optimized Merge for Scanned Layers. - Now supports diff and diff_b. 
- """ - lora_params = parse_lora_dict(state_dict) - max_logging.log(f"Parsed {len(lora_params)} keys for scanned merge.") - matched_keys = set() - - assigned_count = 0 - for path, module in nnx.iter_graph(model): - if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): - continue - - nnx_path_str = ".".join(map(str, path)) - lora_key_template = translate_fn(nnx_path_str) if translate_fn else None - - if not lora_key_template: - continue - - # Determine if layer is scanned based on parameter dimensions - is_scanned = False - if isinstance(module, nnx.Embed) and hasattr(module, 'embedding'): - is_scanned = module.embedding.ndim > 2 - elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)) and hasattr(module, 'scale') and module.scale is not None: - is_scanned = module.scale.ndim > 1 - elif isinstance(module, nnx.Linear): - is_scanned = module.kernel.ndim == 3 - elif isinstance(module, nnx.Conv): - is_scanned = module.kernel.ndim == 5 - - # If layer is not scanned, merge it using single-layer logic - if not is_scanned: - lora_key = lora_key_template - if lora_key in lora_params: - matched_keys.add(lora_key) - weights = lora_params[lora_key] - is_conv_kxk_locon = isinstance(module, nnx.Conv) and module.kernel_size != (1,1) and "down" in weights and "up" in weights - - if isinstance(module, nnx.Embed): - if "diff" in weights and hasattr(module, 'embedding'): - module.embedding.value += np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype) - assigned_count += 1 - elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): - scale_diff = weights.get("diff", None) - bias_diff = weights.get("diff_b", None) - updated = False - if scale_diff is not None and hasattr(module, 'scale') and module.scale is not None: - module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype) - updated = True - if bias_diff is not None and isinstance(module, nnx.LayerNorm) and hasattr(module, 'bias') and module.bias is not None: - module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype) - updated = True - if updated: - assigned_count += 1 - elif isinstance(module, (nnx.Linear, nnx.Conv)): - down_w, up_w, current_scale_ = None, None, None - if "down" in weights and "up" in weights and not is_conv_kxk_locon: - down_w, up_w = np.array(weights["down"]), np.array(weights["up"]) - if isinstance(module, nnx.Conv): - down_w, up_w = np.squeeze(down_w), np.squeeze(up_w) - rank, alpha = down_w.shape[0], float(weights.get("alpha", down_w.shape[0])) - current_scale_ = scale * alpha / rank - - w_diff, b_diff = weights.get("diff", None), weights.get("diff_b", None) - if w_diff is not None: - w_diff = np.array(w_diff) - if isinstance(module, nnx.Conv): - if w_diff.ndim == 5: - w_diff = w_diff.transpose((2,3,4,1,0)) - elif w_diff.ndim == 4: - w_diff = w_diff.transpose((2,3,1,0)) - elif isinstance(module, nnx.Linear) and w_diff.ndim == 2: - w_diff = w_diff.transpose((1,0)) - if b_diff is not None: - b_diff = np.array(b_diff) - if is_conv_kxk_locon: - dw, uw = np.array(weights['down']), np.array(weights['up']) - rank, in_c, *k_dims = dw.shape - out_c = uw.shape[0] - alpha = float(weights.get("alpha", rank)) - delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) - if delta_pt.ndim == 5: - delta_fx = delta_pt.transpose((2,3,4,1,0)) - else: - delta_fx = delta_pt.transpose((2,3,1,0)) - lora_delta = delta_fx * (scale * alpha / rank) - if w_diff is None: - w_diff = 
lora_delta.astype(np.float32) - else: - w_diff += lora_delta.astype(w_diff.dtype) - - bias_val = module.bias.value if module.bias is not None else None - if down_w is not None or w_diff is not None or b_diff is not None: - k, b = _compute_and_add_single_jit(module.kernel.value, bias_val, down_w, up_w, current_scale_, w_diff, b_diff) - module.kernel.value = k - if b is not None: - module.bias.value = b - assigned_count +=1 - continue - - # If we reach here, layer is SCANNED + """ + Device-Side Optimized Merge for Scanned Layers. + Now supports diff and diff_b. + """ + lora_params = parse_lora_dict(state_dict) + max_logging.log(f"Parsed {len(lora_params)} keys for scanned merge.") + matched_keys = set() + + assigned_count = 0 + for path, module in nnx.iter_graph(model): + if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): + continue + + nnx_path_str = ".".join(map(str, path)) + lora_key_template = translate_fn(nnx_path_str) if translate_fn else None + + if not lora_key_template: + continue + + # Determine if layer is scanned based on parameter dimensions + is_scanned = False + if isinstance(module, nnx.Embed) and hasattr(module, "embedding"): + is_scanned = module.embedding.ndim > 2 + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)) and hasattr(module, "scale") and module.scale is not None: + is_scanned = module.scale.ndim > 1 + elif isinstance(module, nnx.Linear): + is_scanned = module.kernel.ndim == 3 + elif isinstance(module, nnx.Conv): + is_scanned = module.kernel.ndim == 5 + + # If layer is not scanned, merge it using single-layer logic + if not is_scanned: + lora_key = lora_key_template + if lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + is_conv_kxk_locon = ( + isinstance(module, nnx.Conv) and module.kernel_size != (1, 1) and "down" in weights and "up" in weights + ) + if isinstance(module, nnx.Embed): - num_layers = module.embedding.shape[0] - embed_diffs_to_add = np.zeros_like(module.embedding.value) - updated = False - for i in range(num_layers): - lora_key = lora_key_template.format(i) - if lora_key in lora_params: - matched_keys.add(lora_key) - if "diff" in lora_params[lora_key]: - embed_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(module.embedding.shape[1:]) - updated = True - if updated: - module.embedding.value += embed_diffs_to_add.astype(module.embedding.dtype) - assigned_count += 1 - continue + if "diff" in weights and hasattr(module, "embedding"): + module.embedding.value += ( + np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype) + ) + assigned_count += 1 elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): - num_layers = module.scale.shape[0] - scale_diffs_to_add = np.zeros_like(module.scale.value) - bias_diffs_to_add = np.zeros_like(module.bias.value) if isinstance(module, nnx.LayerNorm) and hasattr(module, 'bias') and module.bias is not None else None - updated_scale, updated_bias = False, False - for i in range(num_layers): - lora_key = lora_key_template.format(i) - if lora_key in lora_params: - matched_keys.add(lora_key) - weights = lora_params[lora_key] - if "diff" in weights: - scale_diffs_to_add[i] = np.array(weights["diff"]).reshape(module.scale.shape[1:]) - updated_scale = True - if "diff_b" in weights and bias_diffs_to_add is not None: - bias_diffs_to_add[i] = np.array(weights["diff_b"]).reshape(module.bias.shape[1:]) - updated_bias = True - if updated_scale: - module.scale.value += 
scale_diffs_to_add.astype(module.scale.dtype) - if updated_bias and bias_diffs_to_add is not None: - module.bias.value += bias_diffs_to_add.astype(module.bias.dtype) - if updated_scale or updated_bias: - assigned_count += 1 - continue + scale_diff = weights.get("diff", None) + bias_diff = weights.get("diff_b", None) + updated = False + if scale_diff is not None and hasattr(module, "scale") and module.scale is not None: + module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype) + updated = True + if ( + bias_diff is not None + and isinstance(module, nnx.LayerNorm) + and hasattr(module, "bias") + and module.bias is not None + ): + module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype) + updated = True + if updated: + assigned_count += 1 elif isinstance(module, (nnx.Linear, nnx.Conv)): - is_linear = isinstance(module, nnx.Linear) - is_conv = isinstance(module, nnx.Conv) - is_conv_kxk = isinstance(module, nnx.Conv) and module.kernel_size != (1,1) - if is_linear: - num_layers, in_feat, out_feat = module.kernel.shape - else: # Conv - num_layers = module.kernel.shape[0] - in_feat, out_feat = module.kernel.shape[3], module.kernel.shape[4] - else: - # Should not happen based on is_scanned logic - continue - - # 1. Scan for Rank (Fallback use rank in config file) - found_rank = rank - for i in range(num_layers): - k = lora_key_template.format(i) - if k in lora_params and "down" in lora_params[k]: - found_rank = lora_params[k]["down"].shape[0] - break - - # 2. Pre-allocate Buffers (CPU) - # LoRA Buffers - stack_down = np.zeros((num_layers, found_rank, in_feat), dtype=np.float32) - stack_up = np.zeros((num_layers, out_feat, found_rank), dtype=np.float32) - stack_alpha = np.zeros((num_layers, 1, 1), dtype=np.float32) - - # Diff Buffers - # Initialize as None, allocate only if found to save memory - stack_w_diff = None - stack_b_diff = None - - has_lora = False - has_diff = False - - for i in range(num_layers): - lora_key = lora_key_template.format(i) - if lora_key in lora_params: - matched_keys.add(lora_key) - w = lora_params[lora_key] - - # --- Fill LoRA --- - if "down" in w: - d, u = np.array(w["down"]), np.array(w["up"]) - alpha = float(w.get("alpha", d.shape[0])) - rank = d.shape[0] - - if is_conv_kxk: - # For LoCON kxk, compute delta and merge into stack_w_diff - rank, in_c, *k_dims = d.shape - out_c = u.shape[0] - delta_pt = (u.reshape(out_c, rank) @ d.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) - if delta_pt.ndim == 5: - delta_fx = delta_pt.transpose((2,3,4,1,0)) - else: - delta_fx = delta_pt.transpose((2,3,1,0)) - - lora_delta = delta_fx * (scale * alpha / rank) - if stack_w_diff is None: - stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) - stack_w_diff[i] += lora_delta.reshape(stack_w_diff[i].shape).astype(stack_w_diff.dtype) - has_diff = True # Mark as having diff because we merged LoRA into w_diff - else: - # For Linear or 1x1 Conv, prepare for JIT - if d.ndim > 2: - d = np.squeeze(d) - if u.ndim > 2: - u = np.squeeze(u) - stack_down[i] = d - stack_up[i] = u - stack_alpha[i] = alpha - has_lora = True - - # --- Fill Weight Diff --- - if "diff" in w: - if stack_w_diff is None: - stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) - wd = np.array(w["diff"]) - # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. 
- if is_conv: - if wd.ndim == 5: - wd = wd.transpose((2,3,4,1,0)) - elif wd.ndim == 4: - wd = wd.transpose((2,3,1,0)) - elif is_linear and wd.ndim == 2: - wd = wd.transpose((1,0)) - - stack_w_diff[i] += wd.reshape(stack_w_diff[i].shape) - has_diff = True - - # --- Fill Bias Diff --- - if "diff_b" in w: - if stack_b_diff is None: - # Bias shape: Linear (L, Out), Conv (L, Out) usually - stack_b_diff = np.zeros((num_layers, out_feat), dtype=np.float32) - bd = np.array(w["diff_b"]) - stack_b_diff[i] = bd.flatten() - has_diff = True - - if has_lora or has_diff: - bias_val = module.bias.value if module.bias is not None else None - - # Call JIT - new_k, new_b = _compute_and_add_scanned_jit( - module.kernel.value, - stack_down if has_lora else None, - stack_up if has_lora else None, - stack_alpha if has_lora else None, - scale, - stack_w_diff, - stack_b_diff, - bias_val - ) - - module.kernel.value = new_k - if new_b is not None: - module.bias.value = new_b - + down_w, up_w, current_scale_ = None, None, None + if "down" in weights and "up" in weights and not is_conv_kxk_locon: + down_w, up_w = np.array(weights["down"]), np.array(weights["up"]) + if isinstance(module, nnx.Conv): + down_w, up_w = np.squeeze(down_w), np.squeeze(up_w) + rank, alpha = down_w.shape[0], float(weights.get("alpha", down_w.shape[0])) + current_scale_ = scale * alpha / rank + + w_diff, b_diff = weights.get("diff", None), weights.get("diff_b", None) + if w_diff is not None: + w_diff = np.array(w_diff) + if isinstance(module, nnx.Conv): + if w_diff.ndim == 5: + w_diff = w_diff.transpose((2, 3, 4, 1, 0)) + elif w_diff.ndim == 4: + w_diff = w_diff.transpose((2, 3, 1, 0)) + elif isinstance(module, nnx.Linear) and w_diff.ndim == 2: + w_diff = w_diff.transpose((1, 0)) + if b_diff is not None: + b_diff = np.array(b_diff) + if is_conv_kxk_locon: + dw, uw = np.array(weights["down"]), np.array(weights["up"]) + rank, in_c, *k_dims = dw.shape + out_c = uw.shape[0] + alpha = float(weights.get("alpha", rank)) + delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + if delta_pt.ndim == 5: + delta_fx = delta_pt.transpose((2, 3, 4, 1, 0)) + else: + delta_fx = delta_pt.transpose((2, 3, 1, 0)) + lora_delta = delta_fx * (scale * alpha / rank) + if w_diff is None: + w_diff = lora_delta.astype(np.float32) + else: + w_diff += lora_delta.astype(w_diff.dtype) + + bias_val = module.bias.value if module.bias is not None else None + if down_w is not None or w_diff is not None or b_diff is not None: + k, b = _compute_and_add_single_jit(module.kernel.value, bias_val, down_w, up_w, current_scale_, w_diff, b_diff) + module.kernel.value = k + if b is not None: + module.bias.value = b assigned_count += 1 - - max_logging.log(f"Merged weights into {assigned_count} scanned layers.") - unmatched_keys = set(lora_params.keys()) - matched_keys - if unmatched_keys: - max_logging.log(f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}") + continue + + # If we reach here, layer is SCANNED + if isinstance(module, nnx.Embed): + num_layers = module.embedding.shape[0] + embed_diffs_to_add = np.zeros_like(module.embedding.value) + updated = False + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + if "diff" in lora_params[lora_key]: + embed_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(module.embedding.shape[1:]) + updated = True + if updated: + module.embedding.value += 
embed_diffs_to_add.astype(module.embedding.dtype) + assigned_count += 1 + continue + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + num_layers = module.scale.shape[0] + scale_diffs_to_add = np.zeros_like(module.scale.value) + bias_diffs_to_add = ( + np.zeros_like(module.bias.value) + if isinstance(module, nnx.LayerNorm) and hasattr(module, "bias") and module.bias is not None + else None + ) + updated_scale, updated_bias = False, False + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + if "diff" in weights: + scale_diffs_to_add[i] = np.array(weights["diff"]).reshape(module.scale.shape[1:]) + updated_scale = True + if "diff_b" in weights and bias_diffs_to_add is not None: + bias_diffs_to_add[i] = np.array(weights["diff_b"]).reshape(module.bias.shape[1:]) + updated_bias = True + if updated_scale: + module.scale.value += scale_diffs_to_add.astype(module.scale.dtype) + if updated_bias and bias_diffs_to_add is not None: + module.bias.value += bias_diffs_to_add.astype(module.bias.dtype) + if updated_scale or updated_bias: + assigned_count += 1 + continue + elif isinstance(module, (nnx.Linear, nnx.Conv)): + is_linear = isinstance(module, nnx.Linear) + is_conv = isinstance(module, nnx.Conv) + is_conv_kxk = isinstance(module, nnx.Conv) and module.kernel_size != (1, 1) + if is_linear: + num_layers, in_feat, out_feat = module.kernel.shape + else: # Conv + num_layers = module.kernel.shape[0] + in_feat, out_feat = module.kernel.shape[3], module.kernel.shape[4] + else: + # Should not happen based on is_scanned logic + continue + + # 1. Scan for Rank (Fallback use rank in config file) + found_rank = rank + for i in range(num_layers): + k = lora_key_template.format(i) + if k in lora_params and "down" in lora_params[k]: + found_rank = lora_params[k]["down"].shape[0] + break + + # 2. 
Pre-allocate Buffers (CPU) + # LoRA Buffers + stack_down = np.zeros((num_layers, found_rank, in_feat), dtype=np.float32) + stack_up = np.zeros((num_layers, out_feat, found_rank), dtype=np.float32) + stack_alpha = np.zeros((num_layers, 1, 1), dtype=np.float32) + + # Diff Buffers + # Initialize as None, allocate only if found to save memory + stack_w_diff = None + stack_b_diff = None + + has_lora = False + has_diff = False + + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + w = lora_params[lora_key] + + # --- Fill LoRA --- + if "down" in w: + d, u = np.array(w["down"]), np.array(w["up"]) + alpha = float(w.get("alpha", d.shape[0])) + rank = d.shape[0] + + if is_conv_kxk: + # For LoCON kxk, compute delta and merge into stack_w_diff + rank, in_c, *k_dims = d.shape + out_c = u.shape[0] + delta_pt = (u.reshape(out_c, rank) @ d.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + if delta_pt.ndim == 5: + delta_fx = delta_pt.transpose((2, 3, 4, 1, 0)) + else: + delta_fx = delta_pt.transpose((2, 3, 1, 0)) + + lora_delta = delta_fx * (scale * alpha / rank) + if stack_w_diff is None: + stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) + stack_w_diff[i] += lora_delta.reshape(stack_w_diff[i].shape).astype(stack_w_diff.dtype) + has_diff = True # Mark as having diff because we merged LoRA into w_diff + else: + # For Linear or 1x1 Conv, prepare for JIT + if d.ndim > 2: + d = np.squeeze(d) + if u.ndim > 2: + u = np.squeeze(u) + stack_down[i] = d + stack_up[i] = u + stack_alpha[i] = alpha + has_lora = True + + # --- Fill Weight Diff --- + if "diff" in w: + if stack_w_diff is None: + stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) + wd = np.array(w["diff"]) + # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. 
+ if is_conv: + if wd.ndim == 5: + wd = wd.transpose((2, 3, 4, 1, 0)) + elif wd.ndim == 4: + wd = wd.transpose((2, 3, 1, 0)) + elif is_linear and wd.ndim == 2: + wd = wd.transpose((1, 0)) + + stack_w_diff[i] += wd.reshape(stack_w_diff[i].shape) + has_diff = True + + # --- Fill Bias Diff --- + if "diff_b" in w: + if stack_b_diff is None: + # Bias shape: Linear (L, Out), Conv (L, Out) usually + stack_b_diff = np.zeros((num_layers, out_feat), dtype=np.float32) + bd = np.array(w["diff_b"]) + stack_b_diff[i] = bd.flatten() + has_diff = True + + if has_lora or has_diff: + bias_val = module.bias.value if module.bias is not None else None + + # Call JIT + new_k, new_b = _compute_and_add_scanned_jit( + module.kernel.value, + stack_down if has_lora else None, + stack_up if has_lora else None, + stack_alpha if has_lora else None, + scale, + stack_w_diff, + stack_b_diff, + bias_val, + ) + + module.kernel.value = new_k + if new_b is not None: + module.bias.value = new_b + + assigned_count += 1 + + max_logging.log(f"Merged weights into {assigned_count} scanned layers.") + unmatched_keys = set(lora_params.keys()) - matched_keys + if unmatched_keys: + max_logging.log( + f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}" + ) From 41bedde342b5994b150792e37f3f7071e8e3ef10 Mon Sep 17 00:00:00 2001 From: Rishabh Manoj Date: Fri, 23 Jan 2026 09:46:58 +0000 Subject: [PATCH 3/3] Changed lora weights --- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index afe8bc54..86fb74cd 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -316,10 +316,10 @@ enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. lora_config: { rank: [64], - lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"], - high_noise_weight_name: ["wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors"], - low_noise_weight_name: ["wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors"], - adapter_name: ["wan22-distill-lora"], + lora_model_name_or_path: ["lightx2v/Wan2.2-Lightning"], + high_noise_weight_name: ["Wan2.2-I2V-A14B-4steps-lora-rank64-Seko-V1/high_noise_model.safetensors"], + low_noise_weight_name: ["Wan2.2-I2V-A14B-4steps-lora-rank64-Seko-V1/low_noise_model.safetensors"], + adapter_name: ["wan22-lightning-lora"], scale: [1.0], from_pt: [] }
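
For reference, the per-layer update applied by the merge utilities above amounts to folding `up @ down`, rescaled by `scale * alpha / rank`, into the frozen kernel after converting from the PyTorch weight layout to the Flax one, then adding any full-weight/bias diffs on top. The sketch below is illustrative only: `merge_linear_lora` is a hypothetical helper name, not the `_compute_and_add_single_jit` used in the patch, and it covers just the plain `nnx.Linear` case.

import numpy as np

def merge_linear_lora(kernel, down, up, alpha, scale, w_diff=None, bias=None, b_diff=None):
    # kernel: (in_features, out_features) in Flax layout.
    # down:   (rank, in_features) and up: (out_features, rank) in PyTorch layout.
    rank = down.shape[0]
    delta = (up @ down).T * (scale * alpha / rank)          # -> (in_features, out_features)
    new_kernel = kernel + delta.astype(kernel.dtype)
    if w_diff is not None:                                  # full-weight diff, already transposed to Flax layout
        new_kernel = new_kernel + w_diff.astype(kernel.dtype)
    new_bias = bias
    if bias is not None and b_diff is not None:
        new_bias = bias + b_diff.astype(bias.dtype)
    return new_kernel, new_bias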
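
For the scanned path, the stacked buffers built per layer reduce to a single batched contraction over the leading layer axis. Again a sketch only: `merge_scanned_linear_lora` is a hypothetical name that mirrors just the Linear branch of what `_compute_and_add_scanned_jit` computes, assuming rows for layers without LoRA weights are left zero-filled so they contribute nothing.

import numpy as np

def merge_scanned_linear_lora(kernels, stack_down, stack_up, stack_alpha, scale):
    # kernels:     (L, in_features, out_features) in Flax layout.
    # stack_down:  (L, rank, in_features), stack_up: (L, out_features, rank), stack_alpha: (L, 1, 1).
    rank = stack_down.shape[1]
    delta = np.einsum("lor,lri->lio", stack_up, stack_down)  # (L, in_features, out_features)
    delta = delta * (scale * stack_alpha / rank)             # per-layer alpha; zero-filled rows stay zero
    return kernels + delta.astype(kernels.dtype)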