132 changes: 91 additions & 41 deletions loader.py
@@ -10,7 +10,7 @@
from .dequant import is_quantized, dequantize_tensor

IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "gemma2"}
VIS_TYPE_LIST = {"clip-vision"}

def get_orig_shape(reader, tensor_name):
@@ -170,6 +170,28 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
"output.weight": "lm_head.weight",
}

GEMMA2_SD_MAP = {
"blk.": "model.layers.",
# Attention
".attn_q.weight": ".self_attn.q_proj.weight",
".attn_k.weight": ".self_attn.k_proj.weight",
".attn_v.weight": ".self_attn.v_proj.weight",
".attn_output.weight": ".self_attn.o_proj.weight",
# LayerNorm
".attn_norm.weight": ".input_layernorm.weight",
".post_attention_norm.weight": ".post_attention_layernorm.weight",
".post_ffw_norm.weight": ".post_feedforward_layernorm.weight",
".ffn_norm.weight": ".pre_feedforward_layernorm.weight", # Gemma2 safetensors only has pre_feedforward_layernorm
# MLP
".ffn_up.weight": ".mlp.up_proj.weight",
".ffn_down.weight": ".mlp.down_proj.weight",
".ffn_gate.weight": ".mlp.gate_proj.weight",
# emb/out
"token_embd.weight": "model.embed_tokens.weight",
"output_norm.weight": "model.norm.weight",
"output.weight": "lm_head.weight",
}
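# For example, once sd_map_replace applies this map:
#   "blk.0.attn_q.weight"   -> "model.layers.0.self_attn.q_proj.weight"
#   "blk.0.ffn_gate.weight" -> "model.layers.0.mlp.gate_proj.weight"
#   "token_embd.weight"     -> "model.embed_tokens.weight"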

CLIP_VISION_SD_MAP = {
"mm.": "visual.merger.mlp.",
"v.post_ln.": "visual.merger.ln_q.",
@@ -186,8 +208,10 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
def sd_map_replace(raw_sd, key_map):
sd = {}
for k,v in raw_sd.items():
orig_k = k
for s,d in key_map.items():
k = k.replace(s,d)
if s in k:
k = k.replace(s,d)
sd[k] = v
return sd

@@ -278,51 +302,48 @@ def gguf_mmproj_loader(path):

def gguf_tokenizer_loader(path, temb_shape):
# convert gguf tokenizer to spiece
logging.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
try:
from sentencepiece import sentencepiece_model_pb2 as model
except ImportError:
raise ImportError("Please make sure sentencepiece and protobuf are installed.\npip install sentencepiece protobuf")
spm = model.ModelProto()


reader = gguf.GGUFReader(path)

if get_field(reader, "tokenizer.ggml.model", str) == "t5":
if temb_shape == (256384, 4096): # probably UMT5
spm.trainer_spec.model_type = 1 # Unigram (do we have a T5 w/ BPE?)
else:
raise NotImplementedError("Unknown model, can't set tokenizer!")
else:
raise NotImplementedError("Unknown model, can't set tokenizer!")

spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool)
spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool)

tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
scores = get_list_field(reader, "tokenizer.ggml.scores", float)
toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)

for idx, (token, score, toktype) in enumerate(zip(tokens, scores, toktypes)):
# # These aren't present in the original?
# if toktype == 5 and idx >= temb_shape[0]%1000):
# continue

piece = spm.SentencePiece()
piece.piece = token
piece.score = score
piece.type = toktype
spm.pieces.append(piece)

# unsure if any of these are correct
spm.trainer_spec.byte_fallback = True
spm.trainer_spec.vocab_size = len(tokens) # split off unused?
spm.trainer_spec.max_sentence_length = 4096
spm.trainer_spec.eos_id = get_field(reader, "tokenizer.ggml.eos_token_id", int)
spm.trainer_spec.pad_id = get_field(reader, "tokenizer.ggml.padding_token_id", int)

logging.info(f"Created tokenizer with vocab size of {len(spm.pieces)}")

proto_tensor = None
try:
for tensor in reader.tensors:
if tensor.name == "tokenizer.ggml.spiece_model_raw":
proto_tensor = torch.from_numpy(tensor.data)
break
except Exception as e:
logging.warning(f"Failed to read tokenizer.ggml.spiece_model_raw tensor: {e}")
proto_tensor = None
if proto_tensor is not None:
try:
proto_bytes = proto_tensor.cpu().numpy().tobytes()
spm = model.ModelProto()
spm.ParseFromString(proto_bytes)
vocab_size = len(spm.pieces)
logging.info(f"✓ Loaded complete sentencepiece proto from GGUF tensor: {vocab_size} pieces, {len(proto_bytes)} bytes")
logging.info(f" unk_id={spm.trainer_spec.unk_id}, bos_id={spm.trainer_spec.bos_id}, "
f"eos_id={spm.trainer_spec.eos_id}, pad_id={spm.trainer_spec.pad_id}")
if temb_shape[0] != vocab_size:
logging.warning(f"Proto vocab_size ({vocab_size}) != embedding shape[0] ({temb_shape[0]})")
del reader
return torch.ByteTensor(list(proto_bytes))
except Exception as e:
logging.warning(f"Failed to parse proto from int8 tensor: {e}")
spiece_tensor = reader.get_tensor("tokenizer.ggml.spiece_model_raw")
if spiece_tensor is not None:
del reader
return spiece_tensor
raw_proto_field = get_field(reader, "tokenizer.ggml.spiece_model_raw", str)
if raw_proto_field is not None:
proto_bytes = raw_proto_field.encode('latin1')
del reader
return torch.ByteTensor(list(proto_bytes))
del reader
return torch.ByteTensor(list(spm.SerializeToString()))
raise NotImplementedError("No sentencepiece proto found in GGUF metadata!")

def gguf_clip_loader(path):
sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
@@ -348,6 +369,35 @@ def gguf_clip_loader(path):
if arch == "qwen2vl":
vsd = gguf_mmproj_loader(path)
sd.update(vsd)
elif arch == "gemma2":
temb_key = "token_embd.weight"
# Load tokenizer from GGUF metadata
if temb_key in sd:
try:
spm_tensor = gguf_tokenizer_loader(path, sd[temb_key].shape)
if spm_tensor is not None:
sd["spiece_model"] = spm_tensor
except NotImplementedError as e:
logging.error(f"[Gemma2] Failed to load tokenizer: {e}")
raise
if sd[temb_key].shape[0] >= (64 * 1024):
# Dequantize token embeddings to prevent OOM
logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
sd = sd_map_replace(sd, GEMMA2_SD_MAP)
# Gemma2_2B has 8 attention heads and 4 key-value heads
sd = llama_permute(sd, 8, 4)
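# (llama_permute, defined earlier in this file, is assumed to reverse llama.cpp's q/k head
#  interleave so q_proj/k_proj weights match the HF rotary layout.)
# Defensive pass below: re-append a missing ".weight" suffix on layernorm/mlp/proj keys
# so they line up with HF parameter names.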
fix_keys = {}
for k in list(sd.keys()):
if k.startswith("model.layers."):
if (
("layernorm" in k or "mlp." in k or "proj" in k)
and not k.endswith(".weight")
and not k.endswith(".bias")
):
fix_keys[k+".weight"] = sd[k]
del sd[k]
sd.update(fix_keys)
else:
pass
return sd
187 changes: 187 additions & 0 deletions tools/convert_gemma2.py
@@ -0,0 +1,187 @@
import os
import argparse
import logging
import re
from safetensors.torch import load_file
import torch
import gguf
from tqdm import tqdm

# Gemma2 key mapping
KEY_MAP = {
# embedding
"model.embed_tokens.weight": "token_embd.weight",
# norm
"model.norm.weight": "output_norm.weight",
# spiece
"spiece_model": "tokenizer.ggml.spiece_model_raw",
}

# Layer parameter mapping
LAYER_KEY_MAP = {
# LayerNorm
"input_layernorm.weight": "attn_norm.weight",
"post_attention_layernorm.weight": "post_attention_norm.weight",
"post_feedforward_layernorm.weight": "post_ffw_norm.weight",
"pre_feedforward_layernorm.weight": "ffn_norm.weight",
# MLP
"mlp.down_proj.weight": "ffn_down.weight",
"mlp.gate_proj.weight": "ffn_gate.weight",
"mlp.up_proj.weight": "ffn_up.weight",
# Attention
"self_attn.k_proj.weight": "attn_k.weight",
"self_attn.o_proj.weight": "attn_output.weight",
"self_attn.q_proj.weight": "attn_q.weight",
"self_attn.v_proj.weight": "attn_v.weight",
}


def parse_args():
parser = argparse.ArgumentParser(description="Convert Gemma2 safetensors to GGUF with precision preservation")
parser.add_argument("--src", required=True, help="Source safetensors file")
parser.add_argument("--dst", help="Output GGUF file")
parser.add_argument("--quantize", "--quant", "-q",
choices=["f32", "f16", "bf16", "q8_0", "q4_0", "q4_1", "q5_0", "q5_1", "q2_k", "q3_k", "q4_k", "q5_k", "q6_k"],
help="Quantization type")
args = parser.parse_args()
if not os.path.isfile(args.src):
parser.error("Input file does not exist!")
return args
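# Example invocation (hypothetical paths):
#   python tools/convert_gemma2.py --src gemma2-2b-it.safetensors --quantize q8_0
#   python tools/convert_gemma2.py --src gemma2-2b-it.safetensors   # keeps the source precision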


def map_key(key):
# Direct mapping
if key in KEY_MAP:
return KEY_MAP[key]
# Layer parameter mapping
m = re.match(r"model\.layers\.(\d+)\.(.+)", key)
if m:
layer_idx, subkey = m.groups()
if subkey in LAYER_KEY_MAP:
return f"blk.{layer_idx}.{LAYER_KEY_MAP[subkey]}"
return key # Keep others as-is
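# Expected behavior, derived from the maps above (illustrative keys):
#   map_key("model.embed_tokens.weight")              -> "token_embd.weight"
#   map_key("model.layers.3.self_attn.q_proj.weight") -> "blk.3.attn_q.weight"
#   map_key("lm_head.weight")                         -> "lm_head.weight"  (unmapped, kept as-is)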


def get_quantization_type(quant_str):
quant_map = {
"f32": gguf.GGMLQuantizationType.F32,
"f16": gguf.GGMLQuantizationType.F16,
"bf16": gguf.GGMLQuantizationType.BF16,
"q8_0": gguf.GGMLQuantizationType.Q8_0,
"q4_0": gguf.GGMLQuantizationType.Q4_0,
"q4_1": gguf.GGMLQuantizationType.Q4_1,
"q5_0": gguf.GGMLQuantizationType.Q5_0,
"q5_1": gguf.GGMLQuantizationType.Q5_1,
"q2_k": gguf.GGMLQuantizationType.Q2_K,
"q3_k": gguf.GGMLQuantizationType.Q3_K,
"q4_k": gguf.GGMLQuantizationType.Q4_K,
"q5_k": gguf.GGMLQuantizationType.Q5_K,
"q6_k": gguf.GGMLQuantizationType.Q6_K,
}
return quant_map.get(quant_str.lower())


def should_quantize_tensor(key, quant_type):
"""Determine if a tensor should be quantized
Rules:
- token_embd (embedding) kept at F16 (quantization severely impacts quality)
- norm layers kept at F32 (quantization affects stability)
- other weights (attn/mlp) use target quantization
"""
# Embedding always kept at F16
if key == "token_embd.weight":
return False, gguf.GGMLQuantizationType.F16

# Norm layers kept at F32
norm_suffixes = [
"attn_norm.weight",
"post_attention_norm.weight",
"post_ffw_norm.weight",
"ffn_norm.weight",
"output_norm.weight"
]
if any(key.endswith(suffix) for suffix in norm_suffixes):
return False, gguf.GGMLQuantizationType.F32

# Other layers (attn/mlp) use target quantization
return True, quant_type
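# For instance, with a q6_k target (illustrative):
#   should_quantize_tensor("token_embd.weight", Q6_K)      -> (False, F16)
#   should_quantize_tensor("blk.0.attn_norm.weight", Q6_K) -> (False, F32)
#   should_quantize_tensor("blk.0.ffn_up.weight", Q6_K)    -> (True, Q6_K)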


def main():
args = parse_args()
state_dict = load_file(args.src)

if args.quantize:
quant_type = get_quantization_type(args.quantize)
ftype_name = args.quantize.upper()
else:
dtypes = [v.dtype for v in state_dict.values() if hasattr(v, 'dtype')]
main_dtype = max(set(dtypes), key=dtypes.count) if dtypes else torch.float16
if main_dtype == torch.float32:
ftype_name = "F32"
quant_type = gguf.GGMLQuantizationType.F32
elif main_dtype == torch.bfloat16:
ftype_name = "BF16"
quant_type = gguf.GGMLQuantizationType.BF16
else:
ftype_name = "F16"
quant_type = gguf.GGMLQuantizationType.F16

dst = args.dst or f"{os.path.splitext(args.src)[0]}-{ftype_name}.gguf"
if os.path.isfile(dst):
input(f"Output file {dst} exists, press Enter to overwrite or Ctrl+C to cancel...")

writer = gguf.GGUFWriter(path=None, arch="gemma2")
writer.add_quantization_version(gguf.GGML_QUANT_VERSION)

print(f"Target quantization: {ftype_name}")
print(f"Output file: {dst}")

for key, value in tqdm(state_dict.items(), desc="Converting"):
new_key = map_key(key)

# Special handling for spiece_model
if key == "spiece_model":
arr = value.cpu().numpy().astype("int8")
writer.add_tensor(new_key, arr, raw_dtype=gguf.GGMLQuantizationType.I8)
tqdm.write(f"{key} -> {new_key} (spiece_model, {arr.shape[0]} bytes, I8)")
continue

if not hasattr(value, 'dtype'):
tqdm.write(f"Skipping non-tensor: {key}")
continue

# numpy has no bfloat16 dtype, so upcast bf16 tensors to float32 first;
# gguf.quants.quantize can still produce F16/BF16/quantized blocks from float32 input
arr = value.cpu().float().numpy() if value.dtype == torch.bfloat16 else value.cpu().numpy()

# Determine if quantization needed + get target precision
should_quant, target_qtype = should_quantize_tensor(new_key, quant_type)

# Apply quantization or keep original precision
if should_quant and target_qtype not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16]:
quantized_arr = gguf.quants.quantize(arr, target_qtype)
writer.add_tensor(new_key, quantized_arr, raw_dtype=target_qtype)
tqdm.write(f"{key} -> {new_key}, {value.dtype} -> {target_qtype.name}, shape={arr.shape}")
else:
if target_qtype == gguf.GGMLQuantizationType.F32:
arr = arr.astype('float32')
elif target_qtype == gguf.GGMLQuantizationType.BF16:
# BF16 requires special handling
pass # gguf.quants.quantize handles this
else: # F16
arr = arr.astype('float16')

quantized_arr = gguf.quants.quantize(arr, target_qtype)
writer.add_tensor(new_key, quantized_arr, raw_dtype=target_qtype)
tqdm.write(f"{key} -> {new_key}, {value.dtype} -> {target_qtype.name}, shape={arr.shape}")

print("Writing GGUF file...")
writer.write_header_to_file(path=dst)
writer.write_kv_data_to_file()
writer.write_tensors_to_file(progress=True)
writer.close()
print(f"Conversion complete: {dst}")
print(f"Quantization type: {ftype_name}")
print(f"File size: {os.path.getsize(dst) / (1024**3):.2f} GB")

if __name__ == "__main__":
main()