@@ -180,10 +180,11 @@ namespace Rope {
180180 int start_index,
181181 const std::vector<ggml_tensor*>& ref_latents,
182182 bool increase_ref_index,
183- float ref_index_scale) {
183+ float ref_index_scale,
184+ int base_offset = 0 ) {
184185 std::vector<std::vector<float >> ids;
185- uint64_t curr_h_offset = 0 ;
186- uint64_t curr_w_offset = 0 ;
186+ uint64_t curr_h_offset = base_offset ;
187+ uint64_t curr_w_offset = base_offset ;
187188 int index = start_index;
188189 for (ggml_tensor* ref : ref_latents) {
189190 uint64_t h_offset = 0 ;
@@ -227,15 +228,15 @@ namespace Rope {
227228 bool increase_ref_index,
228229 float ref_index_scale,
229230 bool is_longcat) {
230- int start_index = is_longcat ? 1 : 0 ;
231+ int x_index = is_longcat ? 1 : 0 ;
231232
232233 auto txt_ids = is_longcat ? gen_longcat_txt_ids (bs, context_len, axes_dim_num) : gen_flux_txt_ids (bs, context_len, axes_dim_num, txt_arange_dims);
233234 int offset = is_longcat ? context_len : 0 ;
234- auto img_ids = gen_flux_img_ids (h, w, patch_size, bs, axes_dim_num, start_index , offset, offset);
235+ auto img_ids = gen_flux_img_ids (h, w, patch_size, bs, axes_dim_num, x_index , offset, offset);
235236
236237 auto ids = concat_ids (txt_ids, img_ids, bs);
237238 if (ref_latents.size () > 0 ) {
238- auto refs_ids = gen_refs_ids (patch_size, bs, axes_dim_num, start_index + 1 , ref_latents, increase_ref_index, ref_index_scale);
239+ auto refs_ids = gen_refs_ids (patch_size, bs, axes_dim_num, x_index + 1 , ref_latents, increase_ref_index, ref_index_scale, offset );
239240 ids = concat_ids (ids, refs_ids, bs);
240241 }
241242 return ids;
0 commit comments