diff --git a/loader.py b/loader.py
index 032c739..40ae8af 100644
--- a/loader.py
+++ b/loader.py
@@ -10,7 +10,7 @@ from .dequant import is_quantized, dequantize_tensor
 
 IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
-TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}
+TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "gemma2"}
 VIS_TYPE_LIST = {"clip-vision", "mmproj"}
 
 def get_orig_shape(reader, tensor_name):
@@ -170,6 +170,13 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
     "output.weight": "lm_head.weight",
 }
 
+GEMMA_SD_MAP = LLAMA_SD_MAP.copy()
+GEMMA_SD_MAP.update({
+    "ffn_norm": "pre_feedforward_layernorm",
+    "post_ffw_norm": "post_feedforward_layernorm",
+    "post_attention_norm": "post_attention_layernorm",
+})
+
 CLIP_VISION_SD_MAP = {
     "mm.": "visual.merger.mlp.",
     "v.post_ln.": "visual.merger.ln_q.",
@@ -287,26 +294,32 @@ def gguf_tokenizer_loader(path, temb_shape):
 
     reader = gguf.GGUFReader(path)
 
-    if get_field(reader, "tokenizer.ggml.model", str) == "t5":
+    model_str = get_field(reader, "tokenizer.ggml.model", str)
+    if model_str == "t5":
         if temb_shape == (256384, 4096): # probably UMT5
             spm.trainer_spec.model_type == 1 # Unigram (do we have a T5 w/ BPE?)
+            spm.trainer_spec.max_sentence_length = 4096
         else:
             raise NotImplementedError("Unknown model, can't set tokenizer!")
+    elif model_str == "llama":
+        if temb_shape == (256000, 2304): # probably gemma
+            spm.trainer_spec.model_type = 2 # BPE
+            # TODO: something is missing, can't match 1:1
+            spm.trainer_spec.max_sentence_length = 0
+            spm.trainer_spec.max_sentencepiece_length = 16
+            spm.trainer_spec.split_digits = True
+            spm.trainer_spec.allow_whitespace_only_pieces = True
     else:
         raise NotImplementedError("Unknown model, can't set tokenizer!")
 
-    spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool)
-    spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool)
+    spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool) or False
+    spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool) or False
 
     tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
     scores = get_list_field(reader, "tokenizer.ggml.scores", float)
     toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)
 
-    for idx, (token, score, toktype) in enumerate(zip(tokens, scores, toktypes)):
-        # # These aren't present in the original?
-        # if toktype == 5 and idx >= temb_shape[0]%1000):
-        #     continue
-
+    for token, score, toktype in zip(tokens, scores, toktypes):
         piece = spm.SentencePiece()
         piece.piece = token
         piece.score = score
@@ -315,39 +328,80 @@ def gguf_tokenizer_loader(path, temb_shape):
 
     # unsure if any of these are correct
     spm.trainer_spec.byte_fallback = True
-    spm.trainer_spec.vocab_size = len(tokens) # split off unused?
-    spm.trainer_spec.max_sentence_length = 4096
-    spm.trainer_spec.eos_id = get_field(reader, "tokenizer.ggml.eos_token_id", int)
-    spm.trainer_spec.pad_id = get_field(reader, "tokenizer.ggml.padding_token_id", int)
+    spm.trainer_spec.vocab_size = len(tokens)
+
+    # map special token IDs
+    tok_map = {
+        "bos_id": "tokenizer.ggml.bos_token_id",
+        "eos_id": "tokenizer.ggml.eos_token_id",
+        "pad_id": "tokenizer.ggml.padding_token_id",
+        "unk_id": "tokenizer.ggml.unknown_token_id",
+    }
+    for sp, gg in tok_map.items():
+        val = get_field(reader, gg, int)
+        if val is not None:
+            logging.debug(f"setting sp:{sp} to {val}")
+            setattr(spm.trainer_spec, sp, val)
+
+    # fix special token type and map the *_id fields to their *_piece strings
+    if model_str == "llama" and spm.trainer_spec.HasField("unk_id"):
+        spm.pieces[spm.trainer_spec.unk_id].type = 2 # UNKNOWN
+        for p in tok_map.keys():
+            if spm.trainer_spec.HasField(p):
+                val = spm.pieces[getattr(spm.trainer_spec, p)].piece
+                setattr(spm.trainer_spec, p.replace("_id", "_piece"), val)
+
+    if temb_shape == (256000, 2304):
+        # for some reason the ggml tokenizer has these set to -1000 instead of 0...?
+        for p in spm.pieces:
+            if p.score == -1000 and p.type > 1:
+                p.score = 0.0
 
     logging.info(f"Created tokenizer with vocab size of {len(spm.pieces)}")
     del reader
     return torch.ByteTensor(list(spm.SerializeToString()))
 
+def dequantize_temb(sd, temb_key="token_embd.weight"):
+    # TODO: dequantizing the token embed here is janky, but otherwise we OOM due to the tensor being massive.
+    if temb_key in sd and is_quantized(sd[temb_key]):
+        logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
+        sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
+    return sd
+
 def gguf_clip_loader(path):
     sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
+    temb_key = "token_embd.weight"
     if arch in {"t5", "t5encoder"}:
-        temb_key = "token_embd.weight"
         if temb_key in sd and sd[temb_key].shape == (256384, 4096):
             # non-standard Comfy-Org tokenizer
             sd["spiece_model"] = gguf_tokenizer_loader(path, sd[temb_key].shape)
-            # TODO: dequantizing token embed here is janky but otherwise we OOM due to tensor being massive.
-            logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
-            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
+            sd = dequantize_temb(sd, temb_key)
         sd = sd_map_replace(sd, T5_SD_MAP)
     elif arch in {"llama", "qwen2vl"}:
         # TODO: pass model_options["vocab_size"] to loader somehow
-        temb_key = "token_embd.weight"
         if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
-            # See note above for T5.
-            logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
-            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
+            sd = dequantize_temb(sd, temb_key)
         sd = sd_map_replace(sd, LLAMA_SD_MAP)
         if arch == "llama":
             sd = llama_permute(sd, 32, 8) # L3
         if arch == "qwen2vl":
             vsd = gguf_mmproj_loader(path)
             sd.update(vsd)
+    elif arch in {"gemma2"}:
+        if temb_key in sd:
+            # non-standard Comfy-Org tokenizer
+            sd["spiece_model"] = gguf_tokenizer_loader(path, sd[temb_key].shape)
+            # # TODO: for verifying tokenizer accuracy, remove this
+            # from safetensors.torch import load_file
+            # sd["spiece_model"] = load_file(r"models\clip\gemma_2_2b_fp16.safetensors")["spiece_model"]
+        sd = dequantize_temb(sd, temb_key)
+        sd = sd_map_replace(sd, GEMMA_SD_MAP)
+        # reverse the +1 offset applied to norm weights by Gemma2Model.modify_tensors in convert_hf_to_gguf.py
+        for k, v in sd.items():
+            if k.endswith("norm.weight"):
+                if is_quantized(v):
+                    v = dequantize_tensor(v, torch.float16)
+                sd[k] = v - 1.0
     else:
         pass
     return sd
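
Note on the `sd[k] = v - 1.0` loop in the gemma2 branch: Gemma's RMSNorm scales by (1 + weight), so convert_hf_to_gguf.py bakes a +1 offset into every `*norm.weight` tensor when writing the GGUF file, and the loader has to subtract it again to recover the HF-style weight. A minimal sketch of the round trip (illustrative only, not part of the patch; the shape is hypothetical):

import torch

w_hf = torch.randn(2304)     # hypothetical HF-style norm.weight
w_gguf = w_hf + 1.0          # what convert_hf_to_gguf.py stores in the GGUF file
w_restored = w_gguf - 1.0    # what the gemma2 branch above recovers
assert torch.allclose(w_restored, w_hf)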
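
For reference, the `spiece_model` entry built by gguf_tokenizer_loader is the serialized sentencepiece ModelProto packed into a byte tensor. A minimal sketch of how a consumer could rebuild a tokenizer from it (assumes the `sentencepiece` package; `sd` stands in for the dict returned by gguf_clip_loader):

from sentencepiece import SentencePieceProcessor

proto_bytes = bytes(sd["spiece_model"].tolist())   # undo torch.ByteTensor(list(...))
tok = SentencePieceProcessor(model_proto=proto_bytes)
print(tok.encode("hello world"))                   # token IDs from the rebuilt model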