diff --git a/loader.py b/loader.py index 1948027..a17c8ab 100644 --- a/loader.py +++ b/loader.py @@ -48,7 +48,26 @@ def get_list_field(reader, field_name, field_type): else: raise TypeError(f"Unknown field type {field_type}") -def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False, is_text_model=False): +def get_gguf_metadata(reader): + """Extract all simple metadata fields like safetensors""" + metadata = {} + for field_name in reader.fields: + try: + field = reader.get_field(field_name) + if len(field.types) == 1: # Simple scalar fields only + if field.types[0] == gguf.GGUFValueType.STRING: + metadata[field_name] = str(field.parts[field.data[-1]], "utf-8") + elif field.types[0] == gguf.GGUFValueType.INT32: + metadata[field_name] = int(field.parts[field.data[-1]]) + elif field.types[0] == gguf.GGUFValueType.F32: + metadata[field_name] = float(field.parts[field.data[-1]]) + elif field.types[0] == gguf.GGUFValueType.BOOL: + metadata[field_name] = bool(field.parts[field.data[-1]]) + except Exception: + continue + return metadata + +def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=False): """ Read state dict as fake tensors """ @@ -136,9 +155,12 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal max_key = max(qsd.keys(), key=lambda k: qsd[k].numel()) state_dict[max_key].is_largest_weight = True - if return_arch: - return (state_dict, arch_str) - return state_dict + # extra info to return + extra = { + "arch_str": arch_str, + "metadata": get_gguf_metadata(reader) + } + return (state_dict, extra) # for remapping llama.cpp -> original key names T5_SD_MAP = { @@ -246,7 +268,7 @@ def gguf_mmproj_loader(path): logging.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.") target = os.path.join(root, target[0]) - vsd = gguf_sd_loader(target, is_text_model=True) + vsd, _ = gguf_sd_loader(target, is_text_model=True) # concat 4D to 5D if "v.patch_embd.weight.1" in vsd: @@ -375,7 
+397,8 @@ def gguf_tekken_tokenizer_loader(path, temb_shape): return torch.ByteTensor(list(json.dumps(data).encode('utf-8'))) def gguf_clip_loader(path): - sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True) + sd, extra = gguf_sd_loader(path, is_text_model=True) + arch = extra.get("arch_str", None) if arch in {"t5", "t5encoder"}: temb_key = "token_embd.weight" if temb_key in sd and sd[temb_key].shape == (256384, 4096): diff --git a/nodes.py b/nodes.py index ff5aaf0..50d3a60 100644 --- a/nodes.py +++ b/nodes.py @@ -165,9 +165,9 @@ def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_de # init model unet_path = folder_paths.get_full_path("unet", unet_name) - sd = gguf_sd_loader(unet_path) + sd, extra = gguf_sd_loader(unet_path) model = comfy.sd.load_diffusion_model_state_dict( - sd, model_options={"custom_operations": ops} + sd, model_options={"custom_operations": ops}, metadata=extra.get("metadata", {}) ) if model is None: logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))