Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions tools/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def parse_args():
parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET")
parser.add_argument("--src", required=True, help="Source model ckpt file.")
parser.add_argument("--dst", help="Output unet gguf file.")
parser.add_argument("--clearsrc", action="store_true", help="Delete source file after conversion before saving to disk, for low space environment. RISKY.")
args = parser.parse_args()

if not os.path.isfile(args.src):
Expand Down Expand Up @@ -274,7 +275,7 @@ def handle_tensors(writer, state_dict, model_arch):

writer.add_tensor(new_name, data, raw_dtype=data_qtype)

def convert_file(path, dst_path=None, interact=True, overwrite=False):
def convert_file(path, dst_path=None, interact=True, overwrite=False, clear_source = False):
# load & run model detection logic
state_dict = load_state_dict(path)
model_arch = detect_arch(state_dict)
Expand Down Expand Up @@ -307,11 +308,16 @@ def convert_file(path, dst_path=None, interact=True, overwrite=False):
raise OSError("Output exists and overwriting is disabled!")

# handle actual file
writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
writer = gguf.GGUFWriter(path=None, arch=model_arch.arch, use_temp_file=True) # Cache to file.
writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
if ftype_gguf is not None:
writer.add_file_type(ftype_gguf)

if clear_source:
try:
os.remove(path)
logging.info(f"Deleted source file: {path}")
except Exception as e:
logging.warning(f"Failed to delete source file: {e}")
handle_tensors(writer, state_dict, model_arch)
writer.write_header_to_file(path=dst_path)
writer.write_kv_data_to_file()
Expand All @@ -327,4 +333,4 @@ def convert_file(path, dst_path=None, interact=True, overwrite=False):

if __name__ == "__main__":
args = parse_args()
convert_file(args.src, args.dst)
convert_file(args.src, args.dst, clear_source = args.clearsrc)