Commit 5e48f46

fix the prefix_token_len bug (#12845)
1 parent a748a83

File tree

1 file changed (+3, -2)

src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py

Lines changed: 3 additions & 2 deletions
@@ -306,7 +306,9 @@ def _encode_prompt(self, prompt, image):

        prefix_tokens = self.tokenizer(text, add_special_tokens=False)["input_ids"]
        suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"]
-       prefix_len = len(prefix_tokens)
+
+       vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>")
+       prefix_len = prefix_tokens.index(vision_start_token_id)
        suffix_len = len(suffix_tokens)

        prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype)
@@ -660,7 +662,6 @@ def __call__(
                if image_latents is not None:
                    latent_model_input = torch.cat([latents, image_latents], dim=1)

-               # latent_model_input = torch.cat([latent_model_input] * 2) if self.do_classifier_free_guidance else latent_model_input
                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
                with self.transformer.cache_context("cond"):
                    noise_pred_text = self.transformer(
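
For context on the first hunk: the encoded prompt prefix apparently ends with vision placeholder tokens, so `len(prefix_tokens)` counted those placeholders as part of the text prefix, while `prefix_tokens.index(vision_start_token_id)` stops at the first `<|vision_start|>` token and yields only the text tokens that precede the image. A minimal sketch of the difference, using toy token ids (the ids below are assumptions for illustration, not the tokenizer's real vocabulary):

# Minimal sketch of the prefix_len fix. VISION_START_ID and IMAGE_PAD_ID
# are hypothetical ids standing in for "<|vision_start|>" and "<|image_pad|>".
VISION_START_ID = 151652
IMAGE_PAD_ID = 151655

# Encoded prompt prefix: three text tokens, then the vision placeholder block.
prefix_tokens = [9906, 1917, 11, VISION_START_ID, IMAGE_PAD_ID, IMAGE_PAD_ID]

buggy_prefix_len = len(prefix_tokens)                    # 6: counts the vision tokens too
fixed_prefix_len = prefix_tokens.index(VISION_START_ID)  # 3: only the text before the image

assert buggy_prefix_len == 6
assert fixed_prefix_len == 3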
