diff --git a/Sana/nodes.py b/Sana/nodes.py index abce9ba..cf02e18 100644 --- a/Sana/nodes.py +++ b/Sana/nodes.py @@ -96,9 +96,13 @@ class SanaTextEncode: @classmethod def INPUT_TYPES(s): return { - "required": { + "optional": { + "chi_prompt_string": ("STRING", {"forceInput": True}) + }, + "required": { "text": ("STRING", {"multiline": True}), "GEMMA": ("GEMMA",), + "chi": ("BOOLEAN", {"default": True}) } } @@ -107,16 +111,21 @@ def INPUT_TYPES(s): CATEGORY = "ExtraModels/Sana" TITLE = "Sana Text Encode" - def encode(self, text, GEMMA=None): + def encode(self, text, GEMMA=None, chi=True, chi_prompt_string=None): tokenizer = GEMMA["tokenizer"] text_encoder = GEMMA["text_encoder"] with torch.no_grad(): - chi_prompt = "\n".join(preset_te_prompt) + if chi_prompt_string is None and chi == True: + chi_prompt = "\n".join(preset_te_prompt) + elif chi_prompt_string is not None and chi == True: + chi_prompt = chi_prompt_string + else: + chi_prompt = "" full_prompt = chi_prompt + text num_chi_tokens = len(tokenizer.encode(chi_prompt)) max_length = num_chi_tokens + 300 - 2 - + tokens = tokenizer( [full_prompt], max_length=max_length,