Skip to content

Commit c044a40

Browse files
committed
support longcat-image-edit
Fix base rope offset for ref images
1 parent 9f225e4 commit c044a40

File tree

2 files changed

+131
-61
lines changed

2 files changed

+131
-61
lines changed

conditioner.hpp

Lines changed: 125 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,7 +1690,7 @@ struct LLMEmbedder : public Conditioner {
16901690
std::string current_part;
16911691

16921692
for (char c : curr_text) {
1693-
if (c == '\'') {
1693+
if (c == '"') {
16941694
if (!current_part.empty()) {
16951695
parts.push_back(current_part);
16961696
current_part.clear();
@@ -1711,7 +1711,7 @@ struct LLMEmbedder : public Conditioner {
17111711
for (const auto& part : parts) {
17121712
if (part.empty())
17131713
continue;
1714-
if (part[0] == '\'' && part.back() == '\'') {
1714+
if (part[0] == '"' && part.back() == '"') {
17151715
std::string quoted_content = part.substr(1, part.size() - 2);
17161716
for (char ch : quoted_content) {
17171717
std::string char_str(1, ch);
@@ -1747,68 +1747,139 @@ struct LLMEmbedder : public Conditioner {
17471747
bool spell_quotes = false;
17481748
std::set<int> out_layers;
17491749
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
1750-
LOG_INFO("QwenImageEditPlusPipeline");
1751-
prompt_template_encode_start_idx = 64;
1752-
int image_embed_idx = 64 + 6;
1753-
1754-
int min_pixels = 384 * 384;
1755-
int max_pixels = 560 * 560;
1756-
std::string placeholder = "<|image_pad|>";
1757-
std::string img_prompt;
1758-
1759-
for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
1760-
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
1761-
double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
1762-
int height = image.height;
1763-
int width = image.width;
1764-
int h_bar = static_cast<int>(std::round(height / factor)) * factor;
1765-
int w_bar = static_cast<int>(std::round(width / factor)) * factor;
1766-
1767-
if (static_cast<double>(h_bar) * w_bar > max_pixels) {
1768-
double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
1769-
h_bar = std::max(static_cast<int>(factor),
1770-
static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
1771-
w_bar = std::max(static_cast<int>(factor),
1772-
static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
1773-
} else if (static_cast<double>(h_bar) * w_bar < min_pixels) {
1774-
double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
1775-
h_bar = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
1776-
w_bar = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
1750+
if (sd_version_is_longcat(version)) {
1751+
LOG_INFO("LongCatEditPipeline");
1752+
prompt_template_encode_start_idx = 67;
1753+
// prompt_template_encode_end_idx = 5;
1754+
int image_embed_idx = 36 + 6;
1755+
1756+
int min_pixels = 384 * 384;
1757+
int max_pixels = 560 * 560;
1758+
std::string placeholder = "<|image_pad|>";
1759+
std::string img_prompt;
1760+
1761+
1762+
// Only one image is officially supported by the model; it is unclear how it handles multiple images
1763+
for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
1764+
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
1765+
double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
1766+
int height = image.height;
1767+
int width = image.width;
1768+
int h_bar = static_cast<int>(std::round(height / factor)) * factor;
1769+
int w_bar = static_cast<int>(std::round(width / factor)) * factor;
1770+
1771+
if (static_cast<double>(h_bar) * w_bar > max_pixels) {
1772+
double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
1773+
h_bar = std::max(static_cast<int>(factor),
1774+
static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
1775+
w_bar = std::max(static_cast<int>(factor),
1776+
static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
1777+
} else if (static_cast<double>(h_bar) * w_bar < min_pixels) {
1778+
double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
1779+
h_bar = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
1780+
w_bar = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
1781+
}
1782+
1783+
LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
1784+
1785+
sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
1786+
free(image.data);
1787+
image.data = nullptr;
1788+
1789+
ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
1790+
sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false);
1791+
free(resized_image.data);
1792+
resized_image.data = nullptr;
1793+
1794+
ggml_tensor* image_embed = nullptr;
1795+
llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
1796+
image_embeds.emplace_back(image_embed_idx, image_embed);
1797+
image_embed_idx += 1 + image_embed->ne[1] + 6;
1798+
1799+
img_prompt += "<|vision_start|>";
1800+
int64_t num_image_tokens = image_embed->ne[1];
1801+
img_prompt.reserve(num_image_tokens * placeholder.size());
1802+
for (int j = 0; j < num_image_tokens; j++) {
1803+
img_prompt += placeholder;
1804+
}
1805+
img_prompt += "<|vision_end|>";
17771806
}
17781807

1779-
LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
1808+
max_length = 512;
1809+
spell_quotes = true;
1810+
prompt = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n";
1811+
prompt += img_prompt;
1812+
1813+
prompt_attn_range.first = static_cast<int>(prompt.size());
1814+
prompt += conditioner_params.text;
1815+
prompt_attn_range.second = static_cast<int>(prompt.size());
1816+
1817+
prompt += "<|im_end|>\n<|im_start|>assistant\n";
1818+
1819+
} else {
1820+
LOG_INFO("QwenImageEditPlusPipeline");
1821+
prompt_template_encode_start_idx = 64;
1822+
int image_embed_idx = 64 + 6;
1823+
1824+
int min_pixels = 384 * 384;
1825+
int max_pixels = 560 * 560;
1826+
std::string placeholder = "<|image_pad|>";
1827+
std::string img_prompt;
1828+
1829+
for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
1830+
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
1831+
double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
1832+
int height = image.height;
1833+
int width = image.width;
1834+
int h_bar = static_cast<int>(std::round(height / factor)) * factor;
1835+
int w_bar = static_cast<int>(std::round(width / factor)) * factor;
1836+
1837+
if (static_cast<double>(h_bar) * w_bar > max_pixels) {
1838+
double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
1839+
h_bar = std::max(static_cast<int>(factor),
1840+
static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
1841+
w_bar = std::max(static_cast<int>(factor),
1842+
static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
1843+
} else if (static_cast<double>(h_bar) * w_bar < min_pixels) {
1844+
double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
1845+
h_bar = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
1846+
w_bar = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
1847+
}
1848+
1849+
LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
17801850

1781-
sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
1782-
free(image.data);
1783-
image.data = nullptr;
1851+
sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
1852+
free(image.data);
1853+
image.data = nullptr;
17841854

1785-
ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
1786-
sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false);
1787-
free(resized_image.data);
1788-
resized_image.data = nullptr;
1855+
ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
1856+
sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false);
1857+
free(resized_image.data);
1858+
resized_image.data = nullptr;
17891859

1790-
ggml_tensor* image_embed = nullptr;
1791-
llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
1792-
image_embeds.emplace_back(image_embed_idx, image_embed);
1793-
image_embed_idx += 1 + image_embed->ne[1] + 6;
1860+
ggml_tensor* image_embed = nullptr;
1861+
llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
1862+
image_embeds.emplace_back(image_embed_idx, image_embed);
1863+
image_embed_idx += 1 + image_embed->ne[1] + 6;
17941864

1795-
img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652]
1796-
int64_t num_image_tokens = image_embed->ne[1];
1797-
img_prompt.reserve(num_image_tokens * placeholder.size());
1798-
for (int j = 0; j < num_image_tokens; j++) {
1799-
img_prompt += placeholder;
1865+
img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652]
1866+
int64_t num_image_tokens = image_embed->ne[1];
1867+
img_prompt.reserve(num_image_tokens * placeholder.size());
1868+
for (int j = 0; j < num_image_tokens; j++) {
1869+
img_prompt += placeholder;
1870+
}
1871+
img_prompt += "<|vision_end|>";
18001872
}
1801-
img_prompt += "<|vision_end|>";
1802-
}
18031873

1804-
prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
1805-
prompt += img_prompt;
1874+
prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
1875+
prompt += img_prompt;
18061876

1807-
prompt_attn_range.first = static_cast<int>(prompt.size());
1808-
prompt += conditioner_params.text;
1809-
prompt_attn_range.second = static_cast<int>(prompt.size());
1877+
prompt_attn_range.first = static_cast<int>(prompt.size());
1878+
prompt += conditioner_params.text;
1879+
prompt_attn_range.second = static_cast<int>(prompt.size());
18101880

1811-
prompt += "<|im_end|>\n<|im_start|>assistant\n";
1881+
prompt += "<|im_end|>\n<|im_start|>assistant\n";
1882+
}
18121883
} else if (sd_version_is_flux2(version)) {
18131884
prompt_template_encode_start_idx = 0;
18141885
out_layers = {10, 20, 30};

rope.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ namespace Rope {
9393
return txt_ids;
9494
}
9595

96-
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_img_ids(int h,
96+
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_img_ids(int h,
9797
int w,
9898
int patch_size,
9999
int bs,
@@ -107,7 +107,6 @@ namespace Rope {
107107

108108
std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
109109
std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);
110-
111110
for (int i = 0; i < h_len; ++i) {
112111
for (int j = 0; j < w_len; ++j) {
113112
img_ids[i * w_len + j][0] = index;
@@ -181,10 +180,10 @@ namespace Rope {
181180
const std::vector<ggml_tensor*>& ref_latents,
182181
bool increase_ref_index,
183182
float ref_index_scale,
184-
int base_offset = 0) {
183+
int base_offset = 0) {
185184
std::vector<std::vector<float>> ids;
186-
uint64_t curr_h_offset = base_offset;
187-
uint64_t curr_w_offset = base_offset;
185+
uint64_t curr_h_offset = 0;
186+
uint64_t curr_w_offset = 0;
188187
int index = start_index;
189188
for (ggml_tensor* ref : ref_latents) {
190189
uint64_t h_offset = 0;
@@ -203,8 +202,8 @@ namespace Rope {
203202
bs,
204203
axes_dim_num,
205204
static_cast<int>(index * ref_index_scale),
206-
h_offset,
207-
w_offset);
205+
h_offset + base_offset,
206+
w_offset + base_offset);
208207
ids = concat_ids(ids, ref_ids, bs);
209208

210209
if (increase_ref_index) {

0 commit comments

Comments
 (0)