Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,26 @@ def get_list_field(reader, field_name, field_type):
else:
raise TypeError(f"Unknown field type {field_type}")

def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False, is_text_model=False):
def get_gguf_metadata(reader):
    """Extract all simple scalar metadata fields, like safetensors metadata.

    Walks every field of a ``gguf.GGUFReader`` and collects the ones that
    hold a single scalar value (string / int32 / float32 / bool) into a flat
    ``{name: python_value}`` dict. Compound fields (arrays, multi-typed
    fields) are skipped, as is any field that fails to decode — extraction
    is deliberately best-effort so one malformed field cannot abort loading.

    Args:
        reader: a ``gguf.GGUFReader`` whose ``fields`` mapping holds the
            parsed GGUF key/value metadata.

    Returns:
        dict mapping field name to its decoded Python value.
    """
    metadata = {}
    for field_name in reader.fields:
        try:
            field = reader.get_field(field_name)
            # Simple scalar fields only: compound fields (e.g. arrays)
            # carry more than one entry in field.types.
            if len(field.types) != 1:
                continue
            ftype = field.types[0]
            # For scalars, the last data index points at the value part.
            raw = field.parts[field.data[-1]]
            if ftype == gguf.GGUFValueType.STRING:
                metadata[field_name] = str(raw, "utf-8")
            elif ftype == gguf.GGUFValueType.INT32:
                metadata[field_name] = int(raw)
            # FIX: the original tested gguf.GGUFValueType.F32, which does
            # not exist in the gguf package (the member is FLOAT32). The
            # resulting AttributeError was swallowed by a bare `except:`,
            # so float fields were silently never collected.
            elif ftype == gguf.GGUFValueType.FLOAT32:
                metadata[field_name] = float(raw)
            elif ftype == gguf.GGUFValueType.BOOL:
                metadata[field_name] = bool(raw)
        except Exception:
            # Narrowed from a bare `except:` (which also caught
            # KeyboardInterrupt/SystemExit); still best-effort per field.
            continue
    return metadata

def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=False):
"""
Read state dict as fake tensors
"""
Expand Down Expand Up @@ -136,9 +155,12 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
state_dict[max_key].is_largest_weight = True

if return_arch:
return (state_dict, arch_str)
return state_dict
# extra info to return
extra = {
"arch_str": arch_str,
"metadata": get_gguf_metadata(reader)
}
return (state_dict, extra)

# for remapping llama.cpp -> original key names
T5_SD_MAP = {
Expand Down Expand Up @@ -246,7 +268,7 @@ def gguf_mmproj_loader(path):

logging.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.")
target = os.path.join(root, target[0])
vsd = gguf_sd_loader(target, is_text_model=True)
vsd, _ = gguf_sd_loader(target, is_text_model=True)

# concat 4D to 5D
if "v.patch_embd.weight.1" in vsd:
Expand Down Expand Up @@ -375,7 +397,8 @@ def gguf_tekken_tokenizer_loader(path, temb_shape):
return torch.ByteTensor(list(json.dumps(data).encode('utf-8')))

def gguf_clip_loader(path):
sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
sd, extra = gguf_sd_loader(path, is_text_model=True)
arch = extra.get("arch_str", None)
if arch in {"t5", "t5encoder"}:
temb_key = "token_embd.weight"
if temb_key in sd and sd[temb_key].shape == (256384, 4096):
Expand Down
5 changes: 3 additions & 2 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_de

# init model
unet_path = folder_paths.get_full_path("unet", unet_name)
sd = gguf_sd_loader(unet_path)
sd, extra = gguf_sd_loader(unet_path)
model = comfy.sd.load_diffusion_model_state_dict(
sd, model_options={"custom_operations": ops}
sd, model_options={"custom_operations": ops}, metadata=extra.get("metadata", {})
)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
Expand Down Expand Up @@ -319,3 +319,4 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4, type="stable
"QuadrupleCLIPLoaderGGUF": QuadrupleCLIPLoaderGGUF,
"UnetLoaderGGUFAdvanced": UnetLoaderGGUFAdvanced,
}