loader.py: 96 changes (75 additions & 21 deletions)
@@ -10,7 +10,7 @@
from .dequant import is_quantized, dequantize_tensor

IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "gemma2"}
VIS_TYPE_LIST = {"clip-vision", "mmproj"}

def get_orig_shape(reader, tensor_name):
@@ -170,6 +170,13 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False
"output.weight": "lm_head.weight",
}

GEMMA_SD_MAP = LLAMA_SD_MAP.copy()
GEMMA_SD_MAP.update({
    "ffn_norm": "pre_feedforward_layernorm",
    "post_ffw_norm": "post_feedforward_layernorm",
    "post_attention_norm": "post_attention_layernorm",
})
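
# A minimal sketch of how a map like GEMMA_SD_MAP is applied, assuming
# sd_map_replace() does plain substring replacement on each key (the
# "blk." -> "model.layers." rename is inherited from LLAMA_SD_MAP):
#
#   k = "blk.0.ffn_norm.weight"
#   for s, d in GEMMA_SD_MAP.items():
#       k = k.replace(s, d)
#   # k == "model.layers.0.pre_feedforward_layernorm.weight"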

CLIP_VISION_SD_MAP = {
    "mm.": "visual.merger.mlp.",
    "v.post_ln.": "visual.merger.ln_q.",
@@ -287,26 +294,32 @@ def gguf_tokenizer_loader(path, temb_shape):

    reader = gguf.GGUFReader(path)

if get_field(reader, "tokenizer.ggml.model", str) == "t5":
model_str = get_field(reader, "tokenizer.ggml.model", str)
if model_str == "t5":
if temb_shape == (256384, 4096): # probably UMT5
spm.trainer_spec.model_type == 1 # Unigram (do we have a T5 w/ BPE?)
spm.trainer_spec.max_sentence_length = 4096
else:
raise NotImplementedError("Unknown model, can't set tokenizer!")
elif model_str == "llama":
if temb_shape == (256000, 2304): # probably gemma
spm.trainer_spec.model_type == 2 # BPE
# TODO: something is missing, can't match 1:1
spm.trainer_spec.max_sentence_length = 0
spm.trainer_spec.max_sentencepiece_length = 16
spm.trainer_spec.split_digits = True
spm.trainer_spec.allow_whitespace_only_pieces = True
else:
raise NotImplementedError("Unknown model, can't set tokenizer!")

    spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool)
    spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool)
    spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool) or False
    spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool) or False

    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
    scores = get_list_field(reader, "tokenizer.ggml.scores", float)
    toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)

    for idx, (token, score, toktype) in enumerate(zip(tokens, scores, toktypes)):
        # # These aren't present in the original?
        # if toktype == 5 and idx >= temb_shape[0]%1000):
        #     continue

    for token, score, toktype in zip(tokens, scores, toktypes):
        piece = spm.SentencePiece()
        piece.piece = token
        piece.score = score
@@ -315,39 +328,80 @@ def gguf_tokenizer_loader(path, temb_shape):

    # unsure if any of these are correct
    spm.trainer_spec.byte_fallback = True
    spm.trainer_spec.vocab_size = len(tokens) # split off unused?
    spm.trainer_spec.max_sentence_length = 4096
    spm.trainer_spec.eos_id = get_field(reader, "tokenizer.ggml.eos_token_id", int)
    spm.trainer_spec.pad_id = get_field(reader, "tokenizer.ggml.padding_token_id", int)
    spm.trainer_spec.vocab_size = len(tokens)

    # map special token IDs
    tok_map = {
        "bos_id": "tokenizer.ggml.bos_token_id",
        "eos_id": "tokenizer.ggml.eos_token_id",
        "pad_id": "tokenizer.ggml.padding_token_id",
        "unk_id": "tokenizer.ggml.unknown_token_id",
    }
    for sp, gg in tok_map.items():
        val = get_field(reader, gg, int)
        if val is not None:
            logging.debug(f"setting sp:{sp} to {val}")
            setattr(spm.trainer_spec, sp, val)

    # fix special tokens
    if model_str == "llama" and hasattr(spm.trainer_spec, "unk_id"):
        spm.pieces[spm.trainer_spec.unk_id].type = 2
    for p in tok_map.keys():
        if hasattr(spm.trainer_spec, p):
            val = spm.pieces[getattr(spm.trainer_spec, p)].piece
            setattr(spm.trainer_spec, p.replace("_id", "_piece"), val)

    if temb_shape == (256000, 2304):
        # for some reason the ggml tokenizer has these set to -1000 instead of 0...?
        for p in spm.pieces:
            if p.score == -1000 and p.type > 1:
                p.score = 0.0

logging.info(f"Created tokenizer with vocab size of {len(spm.pieces)}")
del reader
return torch.ByteTensor(list(spm.SerializeToString()))
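
# The ByteTensor above is just a serialized sentencepiece ModelProto, so it can
# be sanity-checked by round-tripping it through SentencePieceProcessor. A
# minimal sketch; the .gguf path and shape are hypothetical gemma-2-2b values:
#
#   import sentencepiece
#   blob = gguf_tokenizer_loader("gemma-2-2b.Q8_0.gguf", (256000, 2304))
#   tok = sentencepiece.SentencePieceProcessor()
#   tok.LoadFromSerializedProto(bytes(blob.tolist()))
#   print(tok.EncodeAsIds("hello world"))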

def dequantize_temb(sd, temb_key="token_embd.weight"):
    # TODO: dequantizing token embed here is janky but otherwise we OOM due to tensor being massive.
    if temb_key in sd and is_quantized(sd[temb_key]):
        logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
        sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
    return sd

def gguf_clip_loader(path):
    sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
    temb_key = "token_embd.weight"
    if arch in {"t5", "t5encoder"}:
        temb_key = "token_embd.weight"
        if temb_key in sd and sd[temb_key].shape == (256384, 4096):
            # non-standard Comfy-Org tokenizer
            sd["spiece_model"] = gguf_tokenizer_loader(path, sd[temb_key].shape)
            # TODO: dequantizing token embed here is janky but otherwise we OOM due to tensor being massive.
            logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
        sd = dequantize_temb(sd, temb_key)
        sd = sd_map_replace(sd, T5_SD_MAP)
    elif arch in {"llama", "qwen2vl"}:
        # TODO: pass model_options["vocab_size"] to loader somehow
        temb_key = "token_embd.weight"
        if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
            # See note above for T5.
            logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
        sd = dequantize_temb(sd, temb_key)
        sd = sd_map_replace(sd, LLAMA_SD_MAP)
        if arch == "llama":
            sd = llama_permute(sd, 32, 8) # L3
        if arch == "qwen2vl":
            vsd = gguf_mmproj_loader(path)
            sd.update(vsd)
    elif arch in {"gemma2"}:
        if temb_key in sd:
            # non-standard Comfy-Org tokenizer
            sd["spiece_model"] = gguf_tokenizer_loader(path, sd[temb_key].shape)
            # # TODO: for verifying tokenizer accuracy, remove this
            # from safetensors.torch import load_file
            # sd["spiece_model"] = load_file(r"models\clip\gemma_2_2b_fp16.safetensors")["spiece_model"]
        sd = dequantize_temb(sd, temb_key)
        sd = sd_map_replace(sd, GEMMA_SD_MAP)
        # Reverse change from Gemma2Model.modify_tensors in convert_hf_to_gguf.py
        for k, v in sd.items():
            if k.endswith("norm.weight"):
                if is_quantized(v):
                    v = dequantize_tensor(v, torch.float16)
                sd[k] = v - 1.0
    else:
        pass
    return sd
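
# Background for the "- 1.0" in the gemma2 branch: HF's Gemma2 RMSNorm scales
# by (1 + weight), while the llama.cpp convention scales by the stored weight
# directly, so Gemma2Model.modify_tensors in convert_hf_to_gguf.py writes
# (weight + 1.0) into the GGUF and this loader has to subtract it again. A
# rough sketch of the equivalence (illustrative only; eps and dtype handling
# differ in the real implementations):
#
#   import torch
#
#   def rms_normalize(x, eps=1e-6):
#       return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
#
#   x, w = torch.randn(4, 8), torch.randn(8)  # w: HF-style norm weight
#   w_gguf = w + 1.0                          # what convert_hf_to_gguf.py stores
#   hf_out = rms_normalize(x) * (1.0 + w)     # HF Gemma2RMSNorm convention
#   gguf_out = rms_normalize(x) * w_gguf      # llama-style: scale by weight
#   assert torch.allclose(hf_out, gguf_out)   # so (w_gguf - 1.0) recovers w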