diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 390ea3c6..bb9a8a3b 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -317,12 +317,14 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64], + lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"], + weight_name: ["wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors"], + adapter_name: ["wan21-distill-lora"], + scale: [1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index f2839cba..e3f8e5a9 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -316,12 +316,15 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64], + lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"], + high_noise_weight_name: ["wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors"], + low_noise_weight_name: ["wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors"], + adapter_name: ["wan22-distill-lora"], + scale: [1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index 07a84419..09c05cfd 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -300,12 +300,14 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64], + lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"], + weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors"], + adapter_name: ["wan21-distill-lora-i2v"], + scale: [1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index bcc69e66..86fb74cd 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -312,12 +312,15 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. 
lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64], + lora_model_name_or_path: ["lightx2v/Wan2.2-Lightning"], + high_noise_weight_name: ["Wan2.2-I2V-A14B-4steps-lora-rank64-Seko-V1/high_noise_model.safetensors"], + low_noise_weight_name: ["Wan2.2-I2V-A14B-4steps-lora-rank64-Seko-V1/low_noise_model.safetensors"], + adapter_name: ["wan22-lightning-lora"], + scale: [1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index b739a6f3..fef2d4fa 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -28,6 +28,7 @@ from google.cloud import storage import flax from maxdiffusion.common_types import WAN2_1, WAN2_2 +from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader def upload_video_to_gcs(output_dir: str, video_path: str): @@ -190,6 +191,42 @@ def run(config, pipeline=None, filename_prefix=""): else: raise ValueError(f"Unsupported model_name for checkpointer: {model_key}") pipeline, _, _ = checkpoint_loader.load_checkpoint() + + # If LoRA is specified, inject layers and load weights. + if config.enable_lora and hasattr(config, "lora_config") and config.lora_config and config.lora_config["lora_model_name_or_path"]: + if model_key == WAN2_1: + lora_loader = Wan2_1NnxLoraLoader() + lora_config = config.lora_config + + if len(lora_config["lora_model_name_or_path"]) > 1: + max_logging.log("Found multiple LoRAs in config, but only loading the first one.") + + pipeline = lora_loader.load_lora_weights( + pipeline, + lora_config["lora_model_name_or_path"][0], + transformer_weight_name=lora_config["weight_name"][0], + rank=lora_config["rank"][0], + scale=lora_config["scale"][0], + scan_layers=config.scan_layers, + ) + + if model_key == WAN2_2: + lora_loader = Wan2_2NnxLoraLoader() + lora_config = config.lora_config + + if len(lora_config["lora_model_name_or_path"]) > 1: + max_logging.log("Found multiple LoRAs in config, but only loading the first one.") + + pipeline = lora_loader.load_lora_weights( + pipeline, + lora_config["lora_model_name_or_path"][0], + high_noise_weight_name=lora_config["high_noise_weight_name"][0], + low_noise_weight_name=lora_config["low_noise_weight_name"][0], + rank=lora_config["rank"][0], + scale=lora_config["scale"][0], + scan_layers=config.scan_layers, + ) + s0 = time.perf_counter() # Using global_batch_size_to_train_on so not to create more config variables diff --git a/src/maxdiffusion/loaders/__init__.py b/src/maxdiffusion/loaders/__init__.py index 2c9e973d..e7abb88a 100644 --- a/src/maxdiffusion/loaders/__init__.py +++ b/src/maxdiffusion/loaders/__init__.py @@ -14,3 +14,4 @@ from .lora_pipeline import StableDiffusionLoraLoaderMixin from .flux_lora_pipeline import FluxLoraLoaderMixin +from .wan_lora_nnx_loader import Wan2_1NnxLoraLoader, Wan2_2NnxLoraLoader diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py index 534a440d..55e5f8f8 100644 --- a/src/maxdiffusion/loaders/lora_conversion_utils.py +++ b/src/maxdiffusion/loaders/lora_conversion_utils.py @@ -608,3 +608,82 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") return new_state_dict + + +def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False): + """ + Translates WAN NNX path to Diffusers/LoRA keys. 
+ Verified against wan_utils.py mappings. + """ + + # --- 1. Embeddings (Exact Matches) --- + if nnx_path_str == "condition_embedder.text_embedder.linear_1": + return "diffusion_model.text_embedding.0" + if nnx_path_str == "condition_embedder.text_embedder.linear_2": + return "diffusion_model.text_embedding.2" + if nnx_path_str == "condition_embedder.time_embedder.linear_1": + return "diffusion_model.time_embedding.0" + if nnx_path_str == "condition_embedder.time_embedder.linear_2": + return "diffusion_model.time_embedding.2" + if nnx_path_str == "condition_embedder.image_embedder.norm1.layer_norm": + return "diffusion_model.img_emb.proj.0" + if nnx_path_str == "condition_embedder.image_embedder.ff.net_0": + return "diffusion_model.img_emb.proj.1" + if nnx_path_str == "condition_embedder.image_embedder.ff.net_2": + return "diffusion_model.img_emb.proj.3" + if nnx_path_str == "condition_embedder.image_embedder.norm2.layer_norm": + return "diffusion_model.img_emb.proj.4" + if nnx_path_str == "patch_embedding": + return "diffusion_model.patch_embedding" + if nnx_path_str == "proj_out": + return "diffusion_model.head.head" + if nnx_path_str == "condition_embedder.time_proj": + return "diffusion_model.time_projection.1" + + # --- 2. Map NNX Suffixes to LoRA Suffixes --- + suffix_map = { + # Self Attention (attn1) + "attn1.query": "self_attn.q", + "attn1.key": "self_attn.k", + "attn1.value": "self_attn.v", + "attn1.proj_attn": "self_attn.o", + # Self Attention Norms (QK Norm) + "attn1.norm_q": "self_attn.norm_q", + "attn1.norm_k": "self_attn.norm_k", + # Cross Attention (attn2) + "attn2.query": "cross_attn.q", + "attn2.key": "cross_attn.k", + "attn2.value": "cross_attn.v", + "attn2.proj_attn": "cross_attn.o", + # Cross Attention Norms (QK Norm) + "attn2.norm_q": "cross_attn.norm_q", + "attn2.norm_k": "cross_attn.norm_k", + # Cross Attention img + "attn2.add_k_proj": "cross_attn.k_img", + "attn2.add_v_proj": "cross_attn.v_img", + "attn2.norm_added_k": "cross_attn.norm_k_img", + # Feed Forward (ffn) + "ffn.act_fn.proj": "ffn.0", # Up proj + "ffn.proj_out": "ffn.2", # Down proj + # Global Norms & Modulation + "norm2.layer_norm": "norm3", + "scale_shift_table": "modulation", + "proj_out": "head.head", + } + + # --- 3. Translation Logic --- + if scan_layers: + # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q" + if nnx_path_str.startswith("blocks."): + inner_suffix = nnx_path_str[len("blocks.") :] + if inner_suffix in suffix_map: + return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}" + else: + # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q" + m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str) + if m: + idx, inner_suffix = m.group(1), m.group(2) + if inner_suffix in suffix_map: + return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}" + + return None diff --git a/src/maxdiffusion/loaders/wan_lora_nnx_loader.py b/src/maxdiffusion/loaders/wan_lora_nnx_loader.py new file mode 100644 index 00000000..2fe691c5 --- /dev/null +++ b/src/maxdiffusion/loaders/wan_lora_nnx_loader.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+ # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + + """NNX-based LoRA loader for WAN models.""" + + from flax import nnx + from .lora_base import LoRABaseMixin + from .lora_pipeline import StableDiffusionLoraLoaderMixin + from ..models import lora_nnx + from .. import max_logging + from . import lora_conversion_utils + + + class Wan2_1NnxLoraLoader(LoRABaseMixin): + """ + Handles loading LoRA weights into the NNX-based WAN 2.1 model. + Assumes the WAN pipeline contains a 'transformer' + attribute that is an NNX Module. + """ + + def load_lora_weights( + self, + pipeline: nnx.Module, + lora_model_path: str, + transformer_weight_name: str, + rank: int, + scale: float = 1.0, + scan_layers: bool = False, + **kwargs, + ): + """ + Merges LoRA weights into the pipeline from a checkpoint. + """ + lora_loader = StableDiffusionLoraLoaderMixin() + + merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora + + def translate_fn(nnx_path_str: str): + return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) + + # Handle the single WAN 2.1 transformer + if hasattr(pipeline, "transformer") and transformer_weight_name: + max_logging.log(f"Merging LoRA into transformer with rank={rank}") + state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs) + merge_fn(pipeline.transformer, state_dict, rank, scale, translate_fn) + else: + max_logging.log("transformer not found or no weight name provided for LoRA.") + + return pipeline + + + class Wan2_2NnxLoraLoader(LoRABaseMixin): + """ + Handles loading LoRA weights into the NNX-based WAN 2.2 model. + Assumes the WAN pipeline contains 'high_noise_transformer' and 'low_noise_transformer' + attributes that are NNX Modules. + """ + + def load_lora_weights( + self, + pipeline: nnx.Module, + lora_model_path: str, + high_noise_weight_name: str, + low_noise_weight_name: str, + rank: int, + scale: float = 1.0, + scan_layers: bool = False, + **kwargs, + ): + """ + Merges LoRA weights into the pipeline from a checkpoint.
+ """ + lora_loader = StableDiffusionLoraLoaderMixin() + + merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora + + def translate_fn(nnx_path_str: str): + return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) + + # Handle high noise model + if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name: + max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}") + h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=high_noise_weight_name, **kwargs) + merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn) + else: + max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.") + + # Handle low noise model + if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name: + max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}") + l_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=low_noise_weight_name, **kwargs) + merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn) + else: + max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.") + + return pipeline diff --git a/src/maxdiffusion/models/lora_nnx.py b/src/maxdiffusion/models/lora_nnx.py new file mode 100644 index 00000000..58ec94b6 --- /dev/null +++ b/src/maxdiffusion/models/lora_nnx.py @@ -0,0 +1,542 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import re +import torch +import jax +from jax import dlpack +import jax.numpy as jnp +from flax import nnx +from .. import max_logging +import numpy as np + +# ----------------------------------------------------------------------------- +# JIT Helpers (The Fix for Sharding & Device-Side Computation) +# ----------------------------------------------------------------------------- + + +@jax.jit +def _compute_and_add_single_jit(kernel, bias, down, up, scale, w_diff, b_diff): + """ + Applies LoRA + Weight Diff + Bias Diff on device. + """ + # 1. Apply LoRA (if valid) + if down is not None and up is not None: + # down: (Rank, In), up: (Out, Rank) -> Result: (In, Out) + # Note: We reshape to kernel shape to handle 1x1 convs + delta = (down.T @ up.T).reshape(kernel.shape) + kernel = kernel + (delta * scale).astype(kernel.dtype) + + # 2. Apply Full Weight Diff (if valid) + if w_diff is not None: + kernel = kernel + w_diff.astype(kernel.dtype) + + # 3. Apply Bias Diff (if valid and bias exists) + if bias is not None and b_diff is not None: + bias = bias + b_diff.astype(bias.dtype) + + return kernel, bias + + +@jax.jit +def _compute_and_add_scanned_jit(kernel, downs, ups, alphas, global_scale, w_diffs=None, b_diffs=None, bias=None): + """ + Applies scanned LoRA + Diffs. + """ + # 1. 
Apply LoRA + if downs is not None and ups is not None: + rank = downs.shape[1] + scales = global_scale * alphas / rank + # Batch Matmul: (L, In, Out) + delta = jnp.matmul(jnp.swapaxes(downs, 1, 2), jnp.swapaxes(ups, 1, 2)) + delta = (delta * scales).astype(kernel.dtype) + kernel = kernel + delta.reshape(kernel.shape) + + # 2. Apply Scanned Weight Diffs (L, ...) + if w_diffs is not None: + kernel = kernel + w_diffs.astype(kernel.dtype) + + # 3. Apply Scanned Bias Diffs (L, ...) + # Note: Scanned bias is usually shape (L, Out) + if bias is not None and b_diffs is not None: + bias = bias + b_diffs.astype(bias.dtype) + + return kernel, bias + + +# ----------------------------------------------------------------------------- + + +def _to_jax_array(v): + if isinstance(v, torch.Tensor): + return dlpack.from_dlpack(v) + return jnp.array(v) + + +def parse_lora_dict(state_dict): + """Helper to parse state_dict into structured params including diffs.""" + lora_params = {} + for k, v in state_dict.items(): + # Alpha + if k.endswith(".alpha"): + key_base = k[: -len(".alpha")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["alpha"] = _to_jax_array(v) + continue + + # Bias Diff (e.g., "layer.diff_b") + if k.endswith(".diff_b"): + key_base = k[: -len(".diff_b")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["diff_b"] = _to_jax_array(v) + continue + + # Weight Diff (e.g., "layer.diff") + if k.endswith(".diff"): + key_base = k[: -len(".diff")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["diff"] = _to_jax_array(v) + continue + + # Standard LoRA + m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k) + if not m: + m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k) + if not m: + m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k) + + if m: + key_base, weight_type = m.group(1), m.group(2).replace("lora_", "") + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base][weight_type] = _to_jax_array(v) + else: + # Fallback for exact matches of diffs if regex failed above + pass + + return lora_params + + +def merge_lora(model: nnx.Module, state_dict: dict, rank: int, scale: float, translate_fn=None): + """ + Merges weights for non-scanned layers (Embeddings, singular Dense, etc). + Now supports diff and diff_b. 
+ """ + lora_params = parse_lora_dict(state_dict) + max_logging.log(f"Parsed {len(lora_params)} unique module keys.") + matched_keys = set() + + assigned_count = 0 + for path, module in nnx.iter_graph(model): + if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): + continue + + nnx_path_str = ".".join(map(str, path)) + lora_key = translate_fn(nnx_path_str) if translate_fn else None + + if lora_key and lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + + is_conv_kxk_locon = False + if isinstance(module, nnx.Conv) and module.kernel_size != (1, 1) and "down" in weights and "up" in weights: + is_conv_kxk_locon = True + + # Handle Embeddings + if isinstance(module, nnx.Embed): + if "diff" in weights and hasattr(module, "embedding"): + module.embedding.value += np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype) + assigned_count += 1 + continue + # Handle Norms + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + scale_diff = weights.get("diff", None) + bias_diff = weights.get("diff_b", None) + updated = False + if scale_diff is not None and hasattr(module, "scale") and module.scale is not None: + module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype) + updated = True + if ( + bias_diff is not None + and isinstance(module, nnx.LayerNorm) + and hasattr(module, "bias") + and module.bias is not None + ): + module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype) + updated = True + if updated: + assigned_count += 1 + continue + + # Prepare LoRA terms + down_w, up_w, current_scale = None, None, None + if "down" in weights and "up" in weights and not is_conv_kxk_locon: + down_w, up_w = weights["down"], weights["up"] + down_w, up_w = np.array(down_w), np.array(up_w) # CPU convert + + # Squeeze dimensions if needed (Conv 1x1 or Linear) + if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1): + down_w, up_w = np.squeeze(down_w), np.squeeze(up_w) + + rank = down_w.shape[0] if down_w.ndim > 0 else 0 + alpha = float(weights.get("alpha", rank)) + current_scale = scale * alpha / rank + + # Prepare Diff terms + w_diff = weights.get("diff", None) + b_diff = weights.get("diff_b", None) + + if w_diff is not None: + w_diff = np.array(w_diff) + # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. 
+ if isinstance(module, nnx.Conv): + if w_diff.ndim == 5: + w_diff = w_diff.transpose((2, 3, 4, 1, 0)) + elif w_diff.ndim == 4: + w_diff = w_diff.transpose((2, 3, 1, 0)) + elif isinstance(module, nnx.Linear) and w_diff.ndim == 2: + w_diff = w_diff.transpose((1, 0)) + if b_diff is not None: + b_diff = np.array(b_diff) + + # If LoCON, compute delta and add to w_diff + if is_conv_kxk_locon: + dw, uw = np.array(weights["down"]), np.array(weights["up"]) + rank, in_c, *k_dims = dw.shape + out_c = uw.shape[0] + alpha = float(weights.get("alpha", rank)) + + delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + + # Transpose to flax + if delta_pt.ndim == 5: + delta_fx = delta_pt.transpose((2, 3, 4, 1, 0)) + else: + delta_fx = delta_pt.transpose((2, 3, 1, 0)) + + lora_delta = delta_fx * (scale * alpha / rank) + if w_diff is None: + w_diff = lora_delta.astype(np.float32) + else: + w_diff += lora_delta.astype(w_diff.dtype) + + # Check for Bias existence + bias_val = module.bias.value if module.bias is not None else None + + # --- EXECUTE JIT UPDATE --- + if down_w is not None or w_diff is not None or b_diff is not None: + new_kernel, new_bias = _compute_and_add_single_jit( + module.kernel.value, bias_val, down_w, up_w, current_scale, w_diff, b_diff + ) + + module.kernel.value = new_kernel + if new_bias is not None: + module.bias.value = new_bias + + assigned_count += 1 + else: + max_logging.log(f"Matched key {lora_key} but found no actionable weights.") + + max_logging.log(f"Merged weights into {assigned_count} layers.") + unmatched_keys = set(lora_params.keys()) - matched_keys + if unmatched_keys: + max_logging.log( + f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}" + ) + + +def merge_lora_for_scanned(model: nnx.Module, state_dict: dict, rank: int, scale: float, translate_fn=None): + """ + Device-Side Optimized Merge for Scanned Layers. + Now supports diff and diff_b. 
+ """ + lora_params = parse_lora_dict(state_dict) + max_logging.log(f"Parsed {len(lora_params)} keys for scanned merge.") + matched_keys = set() + + assigned_count = 0 + for path, module in nnx.iter_graph(model): + if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed)): + continue + + nnx_path_str = ".".join(map(str, path)) + lora_key_template = translate_fn(nnx_path_str) if translate_fn else None + + if not lora_key_template: + continue + + # Determine if layer is scanned based on parameter dimensions + is_scanned = False + if isinstance(module, nnx.Embed) and hasattr(module, "embedding"): + is_scanned = module.embedding.ndim > 2 + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)) and hasattr(module, "scale") and module.scale is not None: + is_scanned = module.scale.ndim > 1 + elif isinstance(module, nnx.Linear): + is_scanned = module.kernel.ndim == 3 + elif isinstance(module, nnx.Conv): + is_scanned = module.kernel.ndim == 5 + + # If layer is not scanned, merge it using single-layer logic + if not is_scanned: + lora_key = lora_key_template + if lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + is_conv_kxk_locon = ( + isinstance(module, nnx.Conv) and module.kernel_size != (1, 1) and "down" in weights and "up" in weights + ) + + if isinstance(module, nnx.Embed): + if "diff" in weights and hasattr(module, "embedding"): + module.embedding.value += ( + np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype) + ) + assigned_count += 1 + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + scale_diff = weights.get("diff", None) + bias_diff = weights.get("diff_b", None) + updated = False + if scale_diff is not None and hasattr(module, "scale") and module.scale is not None: + module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype) + updated = True + if ( + bias_diff is not None + and isinstance(module, nnx.LayerNorm) + and hasattr(module, "bias") + and module.bias is not None + ): + module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype) + updated = True + if updated: + assigned_count += 1 + elif isinstance(module, (nnx.Linear, nnx.Conv)): + down_w, up_w, current_scale_ = None, None, None + if "down" in weights and "up" in weights and not is_conv_kxk_locon: + down_w, up_w = np.array(weights["down"]), np.array(weights["up"]) + if isinstance(module, nnx.Conv): + down_w, up_w = np.squeeze(down_w), np.squeeze(up_w) + rank, alpha = down_w.shape[0], float(weights.get("alpha", down_w.shape[0])) + current_scale_ = scale * alpha / rank + + w_diff, b_diff = weights.get("diff", None), weights.get("diff_b", None) + if w_diff is not None: + w_diff = np.array(w_diff) + if isinstance(module, nnx.Conv): + if w_diff.ndim == 5: + w_diff = w_diff.transpose((2, 3, 4, 1, 0)) + elif w_diff.ndim == 4: + w_diff = w_diff.transpose((2, 3, 1, 0)) + elif isinstance(module, nnx.Linear) and w_diff.ndim == 2: + w_diff = w_diff.transpose((1, 0)) + if b_diff is not None: + b_diff = np.array(b_diff) + if is_conv_kxk_locon: + dw, uw = np.array(weights["down"]), np.array(weights["up"]) + rank, in_c, *k_dims = dw.shape + out_c = uw.shape[0] + alpha = float(weights.get("alpha", rank)) + delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + if delta_pt.ndim == 5: + delta_fx = delta_pt.transpose((2, 3, 4, 1, 0)) + else: + delta_fx = delta_pt.transpose((2, 3, 1, 0)) + lora_delta = delta_fx * (scale * alpha / 
rank) + if w_diff is None: + w_diff = lora_delta.astype(np.float32) + else: + w_diff += lora_delta.astype(w_diff.dtype) + + bias_val = module.bias.value if module.bias is not None else None + if down_w is not None or w_diff is not None or b_diff is not None: + k, b = _compute_and_add_single_jit(module.kernel.value, bias_val, down_w, up_w, current_scale_, w_diff, b_diff) + module.kernel.value = k + if b is not None: + module.bias.value = b + assigned_count += 1 + continue + + # If we reach here, layer is SCANNED + if isinstance(module, nnx.Embed): + num_layers = module.embedding.shape[0] + embed_diffs_to_add = np.zeros_like(module.embedding.value) + updated = False + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + if "diff" in lora_params[lora_key]: + embed_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(module.embedding.shape[1:]) + updated = True + if updated: + module.embedding.value += embed_diffs_to_add.astype(module.embedding.dtype) + assigned_count += 1 + continue + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + num_layers = module.scale.shape[0] + scale_diffs_to_add = np.zeros_like(module.scale.value) + bias_diffs_to_add = ( + np.zeros_like(module.bias.value) + if isinstance(module, nnx.LayerNorm) and hasattr(module, "bias") and module.bias is not None + else None + ) + updated_scale, updated_bias = False, False + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + if "diff" in weights: + scale_diffs_to_add[i] = np.array(weights["diff"]).reshape(module.scale.shape[1:]) + updated_scale = True + if "diff_b" in weights and bias_diffs_to_add is not None: + bias_diffs_to_add[i] = np.array(weights["diff_b"]).reshape(module.bias.shape[1:]) + updated_bias = True + if updated_scale: + module.scale.value += scale_diffs_to_add.astype(module.scale.dtype) + if updated_bias and bias_diffs_to_add is not None: + module.bias.value += bias_diffs_to_add.astype(module.bias.dtype) + if updated_scale or updated_bias: + assigned_count += 1 + continue + elif isinstance(module, (nnx.Linear, nnx.Conv)): + is_linear = isinstance(module, nnx.Linear) + is_conv = isinstance(module, nnx.Conv) + is_conv_kxk = isinstance(module, nnx.Conv) and module.kernel_size != (1, 1) + if is_linear: + num_layers, in_feat, out_feat = module.kernel.shape + else: # Conv + num_layers = module.kernel.shape[0] + in_feat, out_feat = module.kernel.shape[3], module.kernel.shape[4] + else: + # Should not happen based on is_scanned logic + continue + + # 1. Scan for Rank (Fallback use rank in config file) + found_rank = rank + for i in range(num_layers): + k = lora_key_template.format(i) + if k in lora_params and "down" in lora_params[k]: + found_rank = lora_params[k]["down"].shape[0] + break + + # 2. 
Pre-allocate Buffers (CPU) + # LoRA Buffers + stack_down = np.zeros((num_layers, found_rank, in_feat), dtype=np.float32) + stack_up = np.zeros((num_layers, out_feat, found_rank), dtype=np.float32) + stack_alpha = np.zeros((num_layers, 1, 1), dtype=np.float32) + + # Diff Buffers + # Initialize as None, allocate only if found to save memory + stack_w_diff = None + stack_b_diff = None + + has_lora = False + has_diff = False + + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + w = lora_params[lora_key] + + # --- Fill LoRA --- + if "down" in w: + d, u = np.array(w["down"]), np.array(w["up"]) + alpha = float(w.get("alpha", d.shape[0])) + rank = d.shape[0] + + if is_conv_kxk: + # For LoCON kxk, compute delta and merge into stack_w_diff + rank, in_c, *k_dims = d.shape + out_c = u.shape[0] + delta_pt = (u.reshape(out_c, rank) @ d.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + if delta_pt.ndim == 5: + delta_fx = delta_pt.transpose((2, 3, 4, 1, 0)) + else: + delta_fx = delta_pt.transpose((2, 3, 1, 0)) + + lora_delta = delta_fx * (scale * alpha / rank) + if stack_w_diff is None: + stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) + stack_w_diff[i] += lora_delta.reshape(stack_w_diff[i].shape).astype(stack_w_diff.dtype) + has_diff = True # Mark as having diff because we merged LoRA into w_diff + else: + # For Linear or 1x1 Conv, prepare for JIT + if d.ndim > 2: + d = np.squeeze(d) + if u.ndim > 2: + u = np.squeeze(u) + stack_down[i] = d + stack_up[i] = u + stack_alpha[i] = alpha + has_lora = True + + # --- Fill Weight Diff --- + if "diff" in w: + if stack_w_diff is None: + stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) + wd = np.array(w["diff"]) + # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. + if is_conv: + if wd.ndim == 5: + wd = wd.transpose((2, 3, 4, 1, 0)) + elif wd.ndim == 4: + wd = wd.transpose((2, 3, 1, 0)) + elif is_linear and wd.ndim == 2: + wd = wd.transpose((1, 0)) + + stack_w_diff[i] += wd.reshape(stack_w_diff[i].shape) + has_diff = True + + # --- Fill Bias Diff --- + if "diff_b" in w: + if stack_b_diff is None: + # Bias shape: Linear (L, Out), Conv (L, Out) usually + stack_b_diff = np.zeros((num_layers, out_feat), dtype=np.float32) + bd = np.array(w["diff_b"]) + stack_b_diff[i] = bd.flatten() + has_diff = True + + if has_lora or has_diff: + bias_val = module.bias.value if module.bias is not None else None + + # Call JIT + new_k, new_b = _compute_and_add_scanned_jit( + module.kernel.value, + stack_down if has_lora else None, + stack_up if has_lora else None, + stack_alpha if has_lora else None, + scale, + stack_w_diff, + stack_b_diff, + bias_val, + ) + + module.kernel.value = new_k + if new_b is not None: + module.bias.value = new_b + + assigned_count += 1 + + max_logging.log(f"Merged weights into {assigned_count} scanned layers.") + unmatched_keys = set(lora_params.keys()) - matched_keys + if unmatched_keys: + max_logging.log( + f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}" + )
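
A minimal usage sketch of the new WAN 2.1 loader, mirroring the wiring added to generate_wan.py above. It assumes pipeline is the object returned by the existing checkpoint-loading flow and that it exposes a transformer NNX module; the repository and weight names are the defaults now shipped in base_wan_14b.yml.

from maxdiffusion.loaders import Wan2_1NnxLoraLoader

# `pipeline` is assumed to come from the existing checkpoint-loading flow in
# generate_wan.py and to expose a `.transformer` nnx.Module.
lora_loader = Wan2_1NnxLoraLoader()
pipeline = lora_loader.load_lora_weights(
    pipeline,
    "lightx2v/Wan2.1-Distill-Loras",  # lora_config["lora_model_name_or_path"][0]
    transformer_weight_name="wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors",
    rank=64,
    scale=1.0,
    scan_layers=True,  # must match config.scan_layers used when restoring the transformer
)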
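
A quick sanity check of the key-translation helper; the expected strings are read off the mappings added to lora_conversion_utils.py in this change.

from maxdiffusion.loaders.lora_conversion_utils import translate_wan_nnx_path_to_diffusers_lora

# Unscanned blocks keep their explicit layer index in the translated key.
assert (
    translate_wan_nnx_path_to_diffusers_lora("blocks.0.attn1.query", scan_layers=False)
    == "diffusion_model.blocks.0.self_attn.q"
)

# Scanned blocks carry no index in the NNX path, so the helper returns a
# template with "{}" that the scanned merge formats per layer.
assert (
    translate_wan_nnx_path_to_diffusers_lora("blocks.ffn.act_fn.proj", scan_layers=True)
    == "diffusion_model.blocks.{}.ffn.0"
)

# Embedder paths are exact matches, independent of scan_layers.
assert (
    translate_wan_nnx_path_to_diffusers_lora("condition_embedder.text_embedder.linear_1")
    == "diffusion_model.text_embedding.0"
)

# Unrecognized paths fall through to None and are skipped by the merge functions.
assert translate_wan_nnx_path_to_diffusers_lora("some.unknown.path") is None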
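
For a plain Linear layer, the merge performed by merge_lora / _compute_and_add_single_jit reduces to W_new = W + (scale * alpha / rank) * (down.T @ up.T), with the kernel stored in Flax (in, out) layout and the LoRA factors in the checkpoint's (rank, in) / (out, rank) layout. A small NumPy sketch with illustrative shapes only:

import numpy as np

in_features, out_features, rank = 128, 256, 64
scale, alpha = 1.0, 64.0  # when alpha == rank the effective multiplier is just `scale`

kernel = np.zeros((in_features, out_features), dtype=np.float32)   # Flax kernel layout (in, out)
lora_down = np.random.randn(rank, in_features).astype(np.float32)  # checkpoint layout (rank, in)
lora_up = np.random.randn(out_features, rank).astype(np.float32)   # checkpoint layout (out, rank)

# down.T @ up.T yields an (in, out) delta that matches the Flax kernel layout.
delta = (lora_down.T @ lora_up.T) * (scale * alpha / rank)
merged = (kernel + delta).astype(kernel.dtype)
assert merged.shape == kernel.shape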
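
Finally, the checkpoint-key conventions parse_lora_dict understands: entries are grouped by their base module path, with .lora_down.weight / .lora_up.weight (or the .lora.down / _lora.down variants) plus an optional .alpha for low-rank pairs, and .diff / .diff_b for full weight and bias deltas. The concrete keys below are hypothetical examples chosen to line up with the translated keys above.

# Hypothetical checkpoint keys -> the bucket parse_lora_dict files them under.
example_keys = {
    "diffusion_model.blocks.0.self_attn.q.lora_down.weight": "down",  # low-rank down-projection
    "diffusion_model.blocks.0.self_attn.q.lora_up.weight": "up",      # low-rank up-projection
    "diffusion_model.blocks.0.self_attn.q.alpha": "alpha",            # scaling numerator
    "diffusion_model.blocks.0.norm3.diff": "diff",                    # full weight delta
    "diffusion_model.blocks.0.norm3.diff_b": "diff_b",                # bias delta
}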