diff --git a/loader.py b/loader.py
index 37ef133..7cefb11 100644
--- a/loader.py
+++ b/loader.py
@@ -10,7 +10,7 @@ from .dequant import is_quantized, dequantize_tensor
 
 
 IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
-TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3", "qwen3vl"}
+TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"}
 VIS_TYPE_LIST = {"clip-vision", "mmproj"}
 
 def get_orig_shape(reader, tensor_name):
@@ -199,6 +199,13 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=False):
     "output.weight": "lm_head.weight",
 }
 
+GEMMA3_SD_MAP = LLAMA_SD_MAP.copy()
+GEMMA3_SD_MAP.update({
+    "ffn_norm": "pre_feedforward_layernorm",
+    "post_ffw_norm": "post_feedforward_layernorm",
+    "post_attention_norm": "post_attention_layernorm",
+})
+
 CLIP_VISION_SD_MAP = {
     "mm.": "visual.merger.mlp.",
     "v.post_ln.": "visual.merger.ln_q.",
@@ -232,6 +239,28 @@ def llama_permute(raw_sd, n_head, n_head_kv):
         sd[k] = v
     return sd
 
+def gemma3_norm_corrections(sd):
+    # Reverse change from Gemma3Model modify_tensors in llama.cpp convert script
+    norm_patterns = [
+        "input_layernorm.weight",
+        "post_attention_layernorm.weight",
+        "pre_feedforward_layernorm.weight",
+        "post_feedforward_layernorm.weight",
+        "self_attn.q_norm.weight",
+        "self_attn.k_norm.weight",
+        "model.norm.weight"
+    ]
+    corrected = 0
+    for key in list(sd.keys()):
+        if any(p in key for p in norm_patterns):
+            if is_quantized(sd[key]):
+                sd[key] = dequantize_tensor(sd[key], dtype=torch.float32) - 1.0
+            else:
+                sd[key] = sd[key].float() - 1.0
+            corrected += 1
+    #logging.info(f"Gemma3: Applied -1 norm correction to {corrected} tensors")
+    return sd
+
 def strip_quant_suffix(name):
     pattern = r"[-_]?(?:ud-)?i?q[0-9]_[a-z0-9_\-]{1,8}$"
     match = re.search(pattern, name, re.IGNORECASE)
@@ -396,6 +425,48 @@ def gguf_tekken_tokenizer_loader(path, temb_shape):
     del reader
     return torch.ByteTensor(list(json.dumps(data).encode('utf-8')))
 
+def gguf_gemma3_tokenizer_loader(path):
+    # TODO: merge into gguf_tokenizer_loader
+    logging.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
+    try:
+        from sentencepiece import sentencepiece_model_pb2 as model
+    except ImportError:
+        raise ImportError("Please install sentencepiece and protobuf.\npip install sentencepiece protobuf")
+    spm = model.ModelProto()
+    reader = gguf.GGUFReader(path)
+
+    spm.normalizer_spec.name = "identity"
+    spm.normalizer_spec.add_dummy_prefix = False
+    spm.trainer_spec.model_type = 2
+    spm.trainer_spec.input_format = "tsv"
+    spm.trainer_spec.byte_fallback = True
+    spm.trainer_spec.max_sentence_length = 4192
+    spm.trainer_spec.bos_piece = ""
+
+    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
+    scores = get_list_field(reader, "tokenizer.ggml.scores", float)
+    toktype = get_list_field(reader, "tokenizer.ggml.token_type", int)
+
+    if not tokens or not scores or not toktype:
+        raise ValueError("Missing tokenizer metadata")
+
+    for idx in range(len(tokens)):
+        piece = spm.SentencePiece()
+        piece.piece = tokens[idx]
+        if idx == 3: # UNK position
+            piece.type = 2 # UNK Token
+            piece.score = 0.0 # UNK Score
+        else:
+            piece.type = toktype[idx]
+            piece.score = scores[idx]
+        spm.pieces.append(piece)
+
+    spm.trainer_spec.vocab_size = len(spm.pieces)
+    logging.info(f"Created tokenizer with vocab size of {len(spm.pieces)}")
+
+    del reader
+    return torch.ByteTensor(list(spm.SerializeToString()))
+
 def gguf_clip_loader(path):
     sd, extra = gguf_sd_loader(path, is_text_model=True)
     arch = extra.get("arch_str", None)
@@ -408,17 +479,23 @@ def gguf_clip_loader(path):
             logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
             sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
         sd = sd_map_replace(sd, T5_SD_MAP)
-    elif arch in {"llama", "qwen2vl", "qwen3", "qwen3vl"}:
+    elif arch in {"llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"}:
         # TODO: pass model_options["vocab_size"] to loader somehow
         temb_key = "token_embd.weight"
         if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
             if arch == "llama" and sd[temb_key].shape == (131072, 5120):
                 # non-standard Comfy-Org tokenizer
                 sd["tekken_model"] = gguf_tekken_tokenizer_loader(path, sd[temb_key].shape)
+            elif arch == "gemma3":
+                sd["spiece_model"] = gguf_gemma3_tokenizer_loader(path)
             # See note above for T5.
             logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
             sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
-        sd = sd_map_replace(sd, LLAMA_SD_MAP)
+        if arch == "gemma3":
+            sd = sd_map_replace(sd, GEMMA3_SD_MAP)
+            sd = gemma3_norm_corrections(sd)
+        else:
+            sd = sd_map_replace(sd, LLAMA_SD_MAP)
         if arch == "llama":
            sd = llama_permute(sd, 32, 8) # L3 / Mistral
         if arch == "qwen2vl":
@@ -427,4 +504,3 @@ def gguf_clip_loader(path):