From c35439688f366c6864a09a3b1e98e7232bc3756b Mon Sep 17 00:00:00 2001 From: SLAPaper Date: Sat, 27 Jan 2024 04:05:09 +0800 Subject: [PATCH 1/3] feat: add dtype selection to PixArt --- PixArt/loader.py | 10 +++++++--- PixArt/nodes.py | 27 ++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/PixArt/loader.py b/PixArt/loader.py index e84fe14..36830d5 100644 --- a/PixArt/loader.py +++ b/PixArt/loader.py @@ -23,10 +23,11 @@ def __init__(self, model_conf): def model_type(self, state_dict, prefix=""): return comfy.model_base.ModelType.EPS -def load_pixart(model_path, model_conf): +def load_pixart(model_path, model_conf, target_dtype: torch.dtype | None=None): state_dict = comfy.utils.load_torch_file(model_path) state_dict = state_dict.get("model", state_dict) + # prefix for prefix in ["model.diffusion_model.",]: if any(True for x in state_dict if x.startswith(prefix)): @@ -36,8 +37,11 @@ def load_pixart(model_path, model_conf): if "adaln_single.linear.weight" in state_dict: state_dict = convert_state_dict(state_dict) # Diffusers - parameters = comfy.utils.calculate_parameters(state_dict) - unet_dtype = model_management.unet_dtype(model_params=parameters) + if target_dtype is None: + parameters = comfy.utils.calculate_parameters(state_dict) + unet_dtype = model_management.unet_dtype(model_params=parameters) + else: + unet_dtype = target_dtype model_conf = EXM_PixArt(model_conf) # convert to object model = comfy.model_base.BaseModel( diff --git a/PixArt/nodes.py b/PixArt/nodes.py index 184e6c3..fe88905 100644 --- a/PixArt/nodes.py +++ b/PixArt/nodes.py @@ -9,6 +9,14 @@ from .loader import load_pixart from .sampler import sample_pixart +dtypes = [ + "default", + "auto (comfy)", + "float32", + "float16", + "bfloat16", +] + class PixArtCheckpointLoader: @classmethod def INPUT_TYPES(s): @@ -16,6 +24,7 @@ def INPUT_TYPES(s): "required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), "model": (list(pixart_conf.keys()),), + "dtype": (dtypes,), } } RETURN_TYPES = ("MODEL",) @@ -24,12 +33,28 @@ def INPUT_TYPES(s): CATEGORY = "ExtraModels/PixArt" TITLE = "PixArt Checkpoint Loader" - def load_checkpoint(self, ckpt_name, model): + def load_checkpoint(self, ckpt_name, model, dtype: str): ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) model_conf = pixart_conf[model] + + target_dtype: torch.dtype | None = None + if dtype == "default": + target_dtype = torch.float16 + elif dtype == 'auto (comfy)': + target_dtype = None + elif dtype == 'float32': + target_dtype = torch.float32 + elif dtype == 'float16': + target_dtype = torch.float16 + elif dtype == 'bfloat16': + target_dtype = torch.bfloat16 + else: + raise ValueError(f"Invalid dtype: {dtype}") + model = load_pixart( model_path = ckpt_path, model_conf = model_conf, + target_dtype = target_dtype, ) return (model,) From d7cecb4a1d5a488c951bb9d8c0adcb257355d6ae Mon Sep 17 00:00:00 2001 From: SLAPaper Date: Sat, 27 Jan 2024 07:16:43 +0800 Subject: [PATCH 2/3] format: remove type annotation --- PixArt/loader.py | 2 +- PixArt/nodes.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/PixArt/loader.py b/PixArt/loader.py index 36830d5..7f90d93 100644 --- a/PixArt/loader.py +++ b/PixArt/loader.py @@ -23,7 +23,7 @@ def __init__(self, model_conf): def model_type(self, state_dict, prefix=""): return comfy.model_base.ModelType.EPS -def load_pixart(model_path, model_conf, target_dtype: torch.dtype | None=None): +def load_pixart(model_path, model_conf, target_dtype): state_dict = comfy.utils.load_torch_file(model_path) state_dict = state_dict.get("model", state_dict) diff --git a/PixArt/nodes.py b/PixArt/nodes.py index fe88905..2a3e72c 100644 --- a/PixArt/nodes.py +++ b/PixArt/nodes.py @@ -33,11 +33,11 @@ def INPUT_TYPES(s): CATEGORY = "ExtraModels/PixArt" TITLE = "PixArt Checkpoint Loader" - def load_checkpoint(self, ckpt_name, model, dtype: str): + def load_checkpoint(self, ckpt_name, model, dtype): ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) model_conf = pixart_conf[model] - target_dtype: torch.dtype | None = None + target_dtype = None if dtype == "default": target_dtype = torch.float16 elif dtype == 'auto (comfy)': From 0b1c6f2c41295ffe8f4736bacdd6439c09175423 Mon Sep 17 00:00:00 2001 From: SLAPaper Date: Sat, 27 Jan 2024 07:18:27 +0800 Subject: [PATCH 3/3] format: remove reduntant empty lines --- PixArt/loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/PixArt/loader.py b/PixArt/loader.py index 7f90d93..7f3df5c 100644 --- a/PixArt/loader.py +++ b/PixArt/loader.py @@ -27,7 +27,6 @@ def load_pixart(model_path, model_conf, target_dtype): state_dict = comfy.utils.load_torch_file(model_path) state_dict = state_dict.get("model", state_dict) - # prefix for prefix in ["model.diffusion_model.",]: if any(True for x in state_dict if x.startswith(prefix)):