diff --git a/loader.py b/loader.py
index fd35e13..684655b 100644
--- a/loader.py
+++ b/loader.py
@@ -10,7 +10,7 @@ from .dequant import is_quantized, dequantize_tensor
 
 IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
-TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}
+TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "gemma2"}
 VIS_TYPE_LIST = {"clip-vision"}
 
 def get_orig_shape(reader, tensor_name):
@@ -170,6 +170,28 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
     "output.weight": "lm_head.weight",
 }
 
+GEMMA2_SD_MAP = {
+    "blk.": "model.layers.",
+    # Attention
+    ".attn_q.weight": ".self_attn.q_proj.weight",
+    ".attn_k.weight": ".self_attn.k_proj.weight",
+    ".attn_v.weight": ".self_attn.v_proj.weight",
+    ".attn_output.weight": ".self_attn.o_proj.weight",
+    # LayerNorm
+    ".attn_norm.weight": ".input_layernorm.weight",
+    ".post_attention_norm.weight": ".post_attention_layernorm.weight",
+    ".post_ffw_norm.weight": ".post_feedforward_layernorm.weight",
+    ".ffn_norm.weight": ".pre_feedforward_layernorm.weight", # Gemma2 safetensors only has pre_feedforward_layernorm
+    # MLP
+    ".ffn_up.weight": ".mlp.up_proj.weight",
+    ".ffn_down.weight": ".mlp.down_proj.weight",
+    ".ffn_gate.weight": ".mlp.gate_proj.weight",
+    # emb/out
+    "token_embd.weight": "model.embed_tokens.weight",
+    "output_norm.weight": "model.norm.weight",
+    "output.weight": "lm_head.weight",
+}
+
 CLIP_VISION_SD_MAP = {
     "mm.": "visual.merger.mlp.",
     "v.post_ln.": "visual.merger.ln_q.",
@@ -186,8 +208,10 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
 
 def sd_map_replace(raw_sd, key_map):
     sd = {}
     for k,v in raw_sd.items():
+        orig_k = k
         for s,d in key_map.items():
-            k = k.replace(s,d)
+            if s in k:
+                k = k.replace(s,d)
         sd[k] = v
     return sd
@@ -278,51 +302,48 @@ def gguf_mmproj_loader(path):
 
 def gguf_tokenizer_loader(path, temb_shape):
     # convert gguf tokenizer to spiece
-    logging.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
     try:
         from sentencepiece import sentencepiece_model_pb2 as model
     except ImportError:
         raise ImportError("Please make sure sentencepiece and protobuf are installed.\npip install sentencepiece protobuf")
 
-    spm = model.ModelProto()
-
+
     reader = gguf.GGUFReader(path)
-
-    if get_field(reader, "tokenizer.ggml.model", str) == "t5":
-        if temb_shape == (256384, 4096): # probably UMT5
-            spm.trainer_spec.model_type == 1 # Unigram (do we have a T5 w/ BPE?)
-        else:
-            raise NotImplementedError("Unknown model, can't set tokenizer!")
-    else:
-        raise NotImplementedError("Unknown model, can't set tokenizer!")
-
-    spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool)
-    spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool)
-
-    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
-    scores = get_list_field(reader, "tokenizer.ggml.scores", float)
-    toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)
-
-    for idx, (token, score, toktype) in enumerate(zip(tokens, scores, toktypes)):
-        # # These aren't present in the original?
-        # if toktype == 5 and idx >= temb_shape[0]%1000):
-        #     continue
-
-        piece = spm.SentencePiece()
-        piece.piece = token
-        piece.score = score
-        piece.type = toktype
-        spm.pieces.append(piece)
-
-    # unsure if any of these are correct
-    spm.trainer_spec.byte_fallback = True
-    spm.trainer_spec.vocab_size = len(tokens) # split off unused?
-    spm.trainer_spec.max_sentence_length = 4096
-    spm.trainer_spec.eos_id = get_field(reader, "tokenizer.ggml.eos_token_id", int)
-    spm.trainer_spec.pad_id = get_field(reader, "tokenizer.ggml.padding_token_id", int)
-
-    logging.info(f"Created tokenizer with vocab size of {len(spm.pieces)}")
+
+    proto_tensor = None
+    try:
+        for tensor in reader.tensors:
+            if tensor.name == "tokenizer.ggml.spiece_model_raw":
+                proto_tensor = torch.from_numpy(tensor.data)
+                break
+    except Exception as e:
+        logging.warning(f"Failed to read tokenizer.ggml.spiece_model_raw tensor: {e}")
+        proto_tensor = None
+    if proto_tensor is not None:
+        try:
+            proto_bytes = proto_tensor.cpu().numpy().tobytes()
+            spm = model.ModelProto()
+            spm.ParseFromString(proto_bytes)
+            vocab_size = len(spm.pieces)
+            logging.info(f"✓ Loaded complete sentencepiece proto from GGUF tensor: {vocab_size} pieces, {len(proto_bytes)} bytes")
+            logging.info(f"  unk_id={spm.trainer_spec.unk_id}, bos_id={spm.trainer_spec.bos_id}, "
+                         f"eos_id={spm.trainer_spec.eos_id}, pad_id={spm.trainer_spec.pad_id}")
+            if temb_shape[0] != vocab_size:
+                logging.warning(f"Proto vocab_size ({vocab_size}) != embedding shape[0] ({temb_shape[0]})")
+            del reader
+            return torch.ByteTensor(list(proto_bytes))
+        except Exception as e:
+            logging.warning(f"Failed to parse proto from int8 tensor: {e}")
+    spiece_tensor = reader.get_tensor("tokenizer.ggml.spiece_model_raw")
+    if spiece_tensor is not None:
+        del reader
+        return spiece_tensor
+    raw_proto_field = get_field(reader, "tokenizer.ggml.spiece_model_raw", str)
+    if raw_proto_field is not None:
+        proto_bytes = raw_proto_field.encode('latin1')
+        del reader
+        return torch.ByteTensor(list(proto_bytes))
     del reader
-    return torch.ByteTensor(list(spm.SerializeToString()))
+    raise NotImplementedError("No sentencepiece proto found in GGUF metadata!")
 
 def gguf_clip_loader(path):
     sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
@@ -348,6 +369,35 @@ def gguf_clip_loader(path):
     if arch == "qwen2vl":
         vsd = gguf_mmproj_loader(path)
         sd.update(vsd)
+    elif arch == "gemma2":
+        temb_key = "token_embd.weight"
+        # Load tokenizer from GGUF metadata
+        if temb_key in sd:
+            try:
+                spm_tensor = gguf_tokenizer_loader(path, sd[temb_key].shape)
+                if spm_tensor is not None:
+                    sd["spiece_model"] = spm_tensor
+            except NotImplementedError as e:
+                logging.error(f"[Gemma2] Failed to load tokenizer: {e}")
+                raise
+            if sd[temb_key].shape[0] >= (64 * 1024):
+                # Dequantize token embeddings to prevent OOM
+                logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
+                sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
+        sd = sd_map_replace(sd, GEMMA2_SD_MAP)
+        # Gemma2_2B has 8 attention heads and 4 key-value heads
+        sd = llama_permute(sd, 8, 4)
+        fix_keys = {}
+        for k in list(sd.keys()):
+            if k.startswith("model.layers."):
+                if (
+                    ("layernorm" in k or "mlp." in k or "proj" in k)
in k or "proj" in k) + and not k.endswith(".weight") + and not k.endswith(".bias") + ): + fix_keys[k+".weight"] = sd[k] + del sd[k] + sd.update(fix_keys) else: pass return sd diff --git a/tools/convert_gemma2.py b/tools/convert_gemma2.py new file mode 100644 index 0000000..afcd884 --- /dev/null +++ b/tools/convert_gemma2.py @@ -0,0 +1,187 @@ +import os +import argparse +import logging +from safetensors.torch import load_file +import torch +import gguf +from tqdm import tqdm + +# Gemma2 key mapping +KEY_MAP = { + # embedding + "model.embed_tokens.weight": "token_embd.weight", + # norm + "model.norm.weight": "output_norm.weight", + # spiece + "spiece_model": "tokenizer.ggml.spiece_model_raw", +} + +# Layer parameter mapping +LAYER_KEY_MAP = { + # LayerNorm + "input_layernorm.weight": "attn_norm.weight", + "post_attention_layernorm.weight": "post_attention_norm.weight", + "post_feedforward_layernorm.weight": "post_ffw_norm.weight", + "pre_feedforward_layernorm.weight": "ffn_norm.weight", + # MLP + "mlp.down_proj.weight": "ffn_down.weight", + "mlp.gate_proj.weight": "ffn_gate.weight", + "mlp.up_proj.weight": "ffn_up.weight", + # Attention + "self_attn.k_proj.weight": "attn_k.weight", + "self_attn.o_proj.weight": "attn_output.weight", + "self_attn.q_proj.weight": "attn_q.weight", + "self_attn.v_proj.weight": "attn_v.weight", +} + + +def parse_args(): + parser = argparse.ArgumentParser(description="Convert Gemma2 safetensors to GGUF with precision preservation") + parser.add_argument("--src", required=True, help="Source safetensors file") + parser.add_argument("--dst", help="Output GGUF file") + parser.add_argument("--quantize", "--quant", "-q", + choices=["f32", "f16", "bf16", "q8_0", "q4_0", "q4_1", "q5_0", "q5_1", "q2_k", "q3_k", "q4_k", "q5_k", "q6_k"], + help="Quantization type") + args = parser.parse_args() + if not os.path.isfile(args.src): + parser.error("Input file does not exist!") + return args + + +def map_key(key): + # Direct mapping + if key in KEY_MAP: + return KEY_MAP[key] + # Layer parameter mapping + import re + m = re.match(r"model.layers.(\d+)\.(.+)", key) + if m: + layer_idx, subkey = m.groups() + if subkey in LAYER_KEY_MAP: + return f"blk.{layer_idx}.{LAYER_KEY_MAP[subkey]}" + return key # Keep others as-is + + +def get_quantization_type(quant_str): + quant_map = { + "f32": gguf.GGMLQuantizationType.F32, + "f16": gguf.GGMLQuantizationType.F16, + "bf16": gguf.GGMLQuantizationType.BF16, + "q8_0": gguf.GGMLQuantizationType.Q8_0, + "q4_0": gguf.GGMLQuantizationType.Q4_0, + "q4_1": gguf.GGMLQuantizationType.Q4_1, + "q5_0": gguf.GGMLQuantizationType.Q5_0, + "q5_1": gguf.GGMLQuantizationType.Q5_1, + "q2_k": gguf.GGMLQuantizationType.Q2_K, + "q3_k": gguf.GGMLQuantizationType.Q3_K, + "q4_k": gguf.GGMLQuantizationType.Q4_K, + "q5_k": gguf.GGMLQuantizationType.Q5_K, + "q6_k": gguf.GGMLQuantizationType.Q6_K, + } + return quant_map.get(quant_str.lower()) + + +def should_quantize_tensor(key, quant_type): + """Determine if a tensor should be quantized + Rules: + - token_embd (embedding) kept at F16 (quantization severely impacts quality) + - norm layers kept at F32 (quantization affects stability) + - other weights (attn/mlp) use target quantization + """ + # Embedding always kept at F16 + if key == "token_embd.weight": + return False, gguf.GGMLQuantizationType.F16 + + # Norm layers kept at F32 + norm_suffixes = [ + "attn_norm.weight", + "post_attention_norm.weight", + "post_ffw_norm.weight", + "ffn_norm.weight", + "output_norm.weight" + ] + if any(key.endswith(suffix) for suffix in 
+        return False, gguf.GGMLQuantizationType.F32
+
+    # Other layers (attn/mlp) use target quantization
+    return True, quant_type
+
+
+def main():
+    args = parse_args()
+    state_dict = load_file(args.src)
+
+    if args.quantize:
+        quant_type = get_quantization_type(args.quantize)
+        ftype_name = args.quantize.upper()
+    else:
+        dtypes = [v.dtype for v in state_dict.values() if hasattr(v, 'dtype')]
+        main_dtype = max(set(dtypes), key=dtypes.count) if dtypes else torch.float16
+        if main_dtype == torch.float32:
+            ftype_name = "F32"
+            quant_type = gguf.GGMLQuantizationType.F32
+        elif main_dtype == torch.bfloat16:
+            ftype_name = "BF16"
+            quant_type = gguf.GGMLQuantizationType.BF16
+        else:
+            ftype_name = "F16"
+            quant_type = gguf.GGMLQuantizationType.F16
+
+    dst = args.dst or f"{os.path.splitext(args.src)[0]}-{ftype_name}.gguf"
+    if os.path.isfile(dst):
+        input(f"Output file {dst} exists, press Enter to overwrite or Ctrl+C to cancel...")
+
+    writer = gguf.GGUFWriter(path=None, arch="gemma2")
+    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+
+    print(f"Target quantization: {ftype_name}")
+    print(f"Output file: {dst}")
+
+    for key, value in tqdm(state_dict.items(), desc="Converting"):
+        new_key = map_key(key)
+
+        # Special handling for spiece_model
+        if key == "spiece_model":
+            arr = value.cpu().numpy().astype("int8")
+            writer.add_tensor(new_key, arr, raw_dtype=gguf.GGMLQuantizationType.I8)
+            tqdm.write(f"{key} -> {new_key} (spiece_model, {arr.shape[0]} bytes, I8)")
+            continue
+
+        if not hasattr(value, 'dtype'):
+            tqdm.write(f"Skipping non-tensor: {key}")
+            continue
+
+        arr = value.to(torch.float32).cpu().numpy() if value.dtype == torch.bfloat16 else value.cpu().numpy() # numpy has no bf16
+
+        # Determine if quantization needed + get target precision
+        should_quant, target_qtype = should_quantize_tensor(new_key, quant_type)
+
+        # Apply quantization or keep original precision
+        if should_quant and target_qtype not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16]:
+            quantized_arr = gguf.quants.quantize(arr, target_qtype)
+            writer.add_tensor(new_key, quantized_arr, raw_dtype=target_qtype)
+            tqdm.write(f"{key} -> {new_key}, {value.dtype} -> {target_qtype.name}, shape={arr.shape}")
+        else:
+            if target_qtype == gguf.GGMLQuantizationType.F32:
+                arr = arr.astype('float32')
+            elif target_qtype == gguf.GGMLQuantizationType.BF16:
+                # BF16 requires special handling
+                pass # gguf.quants.quantize handles this
+            else: # F16
+                arr = arr.astype('float16')
+
+            quantized_arr = gguf.quants.quantize(arr, target_qtype)
+            writer.add_tensor(new_key, quantized_arr, raw_dtype=target_qtype)
+            tqdm.write(f"{key} -> {new_key}, {value.dtype} -> {target_qtype.name}, shape={arr.shape}")
+
+    print("Writing GGUF file...")
+    writer.write_header_to_file(path=dst)
+    writer.write_kv_data_to_file()
+    writer.write_tensors_to_file(progress=True)
+    writer.close()
+    print(f"Conversion complete: {dst}")
+    print(f"Quantization type: {ftype_name}")
+    print(f"File size: {os.path.getsize(dst) / (1024**3):.2f} GB")
+
+if __name__ == "__main__":
+    main()
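
Example invocation of the new converter (a sketch; the checkpoint filename below is a placeholder, and the output name follows the script's default --dst pattern):

    python tools/convert_gemma2.py --src gemma2_2b_it.safetensors --quant q8_0
    # -> writes gemma2_2b_it-Q8_0.gguf: attn/mlp weights in Q8_0, norm weights in F32, token_embd in F16

The resulting file is then handled by gguf_clip_loader(), which recognizes arch == "gemma2" via TXT_ARCH_LIST, stores the embedded sentencepiece proto under sd["spiece_model"], remaps keys through GEMMA2_SD_MAP, and permutes the Q/K projections with llama_permute(sd, 8, 4).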