diff --git a/tools/convert.py b/tools/convert.py index fa12655..385beb1 100644 --- a/tools/convert.py +++ b/tools/convert.py @@ -155,6 +155,7 @@ def parse_args(): parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET") parser.add_argument("--src", required=True, help="Source model ckpt file.") parser.add_argument("--dst", help="Output unet gguf file.") + parser.add_argument("--clearsrc", action="store_true", help="Delete source file after conversion before saving to disk, for low space environment. RISKY.") args = parser.parse_args() if not os.path.isfile(args.src): @@ -274,7 +275,7 @@ def handle_tensors(writer, state_dict, model_arch): writer.add_tensor(new_name, data, raw_dtype=data_qtype) -def convert_file(path, dst_path=None, interact=True, overwrite=False): +def convert_file(path, dst_path=None, interact=True, overwrite=False, clear_source = False): # load & run model detection logic state_dict = load_state_dict(path) model_arch = detect_arch(state_dict) @@ -307,11 +308,16 @@ def convert_file(path, dst_path=None, interact=True, overwrite=False): raise OSError("Output exists and overwriting is disabled!") # handle actual file - writer = gguf.GGUFWriter(path=None, arch=model_arch.arch) + writer = gguf.GGUFWriter(path=None, arch=model_arch.arch, use_temp_file=True) # Cache to file. writer.add_quantization_version(gguf.GGML_QUANT_VERSION) if ftype_gguf is not None: writer.add_file_type(ftype_gguf) - + if clear_source: + try: + os.remove(path) + logging.info(f"Deleted source file: {path}") + except Exception as e: + logging.warning(f"Failed to delete source file: {e}") handle_tensors(writer, state_dict, model_arch) writer.write_header_to_file(path=dst_path) writer.write_kv_data_to_file() @@ -327,4 +333,4 @@ def convert_file(path, dst_path=None, interact=True, overwrite=False): if __name__ == "__main__": args = parse_args() - convert_file(args.src, args.dst) + convert_file(args.src, args.dst, clear_source = args.clearsrc)