Skip to content

Commit a907fe2

Browse files
committed
correct rope offset for image tokens
stuff
1 parent 37c5e3e commit a907fe2

File tree

3 files changed

+18
-13
lines changed

3 files changed

+18
-13
lines changed

ggml_extend.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2238,15 +2238,15 @@ class SplitLinear : public Linear {
22382238
forward_params.linear.scale = scale;
22392239
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
22402240
}
2241-
auto x0 = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
2241+
auto out = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
22422242
for (int i = 1; i < out_features_vec.size(); i++) {
2243-
auto wi = params["weight." + std::to_string(i)];
2244-
auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
2245-
auto xi = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
2246-
x0 = ggml_concat(ctx->ggml_ctx, x0, xi, 0);
2243+
auto wi = params["weight." + std::to_string(i)];
2244+
auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
2245+
auto curr_out = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
2246+
out = ggml_concat(ctx->ggml_ctx, out, curr_out, 0);
22472247
}
22482248

2249-
return x0;
2249+
return out;
22502250
}
22512251
};
22522252

rope.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,11 @@ namespace Rope {
180180
int start_index,
181181
const std::vector<ggml_tensor*>& ref_latents,
182182
bool increase_ref_index,
183-
float ref_index_scale) {
183+
float ref_index_scale,
184+
int base_offset = 0) {
184185
std::vector<std::vector<float>> ids;
185-
uint64_t curr_h_offset = 0;
186-
uint64_t curr_w_offset = 0;
186+
uint64_t curr_h_offset = base_offset;
187+
uint64_t curr_w_offset = base_offset;
187188
int index = start_index;
188189
for (ggml_tensor* ref : ref_latents) {
189190
uint64_t h_offset = 0;
@@ -227,15 +228,15 @@ namespace Rope {
227228
bool increase_ref_index,
228229
float ref_index_scale,
229230
bool is_longcat) {
230-
int start_index = is_longcat ? 1 : 0;
231+
int x_index = is_longcat ? 1 : 0;
231232

232233
auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims);
233234
int offset = is_longcat ? context_len : 0;
234-
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, start_index, offset, offset);
235+
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, x_index, offset, offset);
235236

236237
auto ids = concat_ids(txt_ids, img_ids, bs);
237238
if (ref_latents.size() > 0) {
238-
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, start_index + 1, ref_latents, increase_ref_index, ref_index_scale);
239+
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, x_index + 1, ref_latents, increase_ref_index, ref_index_scale, offset);
239240
ids = concat_ids(ids, refs_ids, bs);
240241
}
241242
return ids;

stable-diffusion.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,9 @@ class StableDiffusionGGML {
456456
sd_ctx_params->chroma_use_dit_mask);
457457
} else if (sd_version_is_longcat(version)) {
458458
bool enable_vision = false;
459+
if (!vae_decode_only) {
460+
enable_vision = true;
461+
}
459462
cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
460463
offload_params_to_cpu,
461464
tensor_storage_map,
@@ -850,7 +853,7 @@ class StableDiffusionGGML {
850853
flow_shift = 1.15f;
851854
}
852855
}
853-
if(sd_version_is_longcat(version)) {
856+
if (sd_version_is_longcat(version)) {
854857
flow_shift = 3.0f;
855858
}
856859
}
@@ -2244,6 +2247,7 @@ class StableDiffusionGGML {
22442247
sd_version_is_qwen_image(version) ||
22452248
sd_version_is_wan(version) ||
22462249
sd_version_is_flux2(version) ||
2250+
sd_version_is_longcat(version) ||
22472251
version == VERSION_CHROMA_RADIANCE) {
22482252
latent = vae_output;
22492253
} else if (version == VERSION_SD1_PIX2PIX) {

0 commit comments

Comments
 (0)