
Commit e3b35dd

vulkan: Extend rope fusions to allow mrope (#18264)
Extend the test-backend-ops tests as well.
1 parent 6ce863c commit e3b35dd

4 files changed: +82 -25 lines changed

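At the graph level, the fusion this commit extends covers the ROPE -> VIEW -> SET_ROWS chain that scatters roped rows directly into selected rows of a destination tensor. A minimal sketch of that pattern for mrope, mirroring the test_rope_set_rows graph added below (the 128x32 head shape, the scaling defaults, and the helper name are illustrative assumptions, not values taken from this commit):

    // Sketch: the ggml graph shape that the Vulkan backend can now fuse for mrope.
    static ggml_tensor * build_mrope_set_rows(ggml_context * ctx, int n_tokens) {
        ggml_tensor * a   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 32, n_tokens, 1);
        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens * 4);  // mrope: 4 position ids per token

        int sections[4] = { 128/3, 128/3, 128/3, 0 };  // one rope section per position-id stream
        ggml_tensor * rope = ggml_rope_multi(ctx, a, pos, /*freq_factors*/ nullptr, 128, sections,
                                             GGML_ROPE_TYPE_MROPE, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f);

        // View the result as one row per token, then scatter those rows into dst at row_idxs;
        // with this commit the Vulkan backend can fuse the whole chain into the rope_multi dispatch.
        ggml_tensor * view     = ggml_view_2d(ctx, rope, 128 * 32, n_tokens, rope->nb[2], 0);
        ggml_tensor * dst      = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 128 * 32, n_tokens);
        ggml_tensor * row_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
        return ggml_set_rows(ctx, dst, view, row_idxs);
    }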

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 8 additions & 3 deletions
@@ -731,7 +731,7 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
-    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
+    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16;
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
@@ -4077,6 +4077,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+       ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     } else {
        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -4085,6 +4086,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+       ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     }
 
     for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
@@ -8680,6 +8682,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_rope_multi_f32;
         }
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_rope_multi_f32_f16;
+        }
         if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
             return ctx->device->pipeline_rope_multi_f16;
         }
@@ -13076,9 +13081,9 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
         return false;
     }
 
-    // Only norm/neox shaders have the fusion code
+    // Only norm/neox/mrope shaders have the fusion code
    const int mode = ((const int32_t *) rope->op_params)[2];
-    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_MROPE) {
         return false;
     }
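Note on the ggml_vk_can_fuse_rope_set_rows change at the end of this hunk: the rope mode constants form a small bit set in which the vision variant also carries the mrope bit, so the equality comparison above still keeps vision rope unfused (rope_vision.comp gains no fusion path in this commit). A sketch of the relevant values as I understand them from ggml.h; they are not part of this diff:

    // Assumed values from ggml.h (not introduced by this commit):
    #define GGML_ROPE_TYPE_NORMAL 0
    #define GGML_ROPE_TYPE_NEOX   2
    #define GGML_ROPE_TYPE_MROPE  8
    #define GGML_ROPE_TYPE_VISION 24   // mrope bit (8) plus an extra bit
    // Hence `mode != GGML_ROPE_TYPE_MROPE` rejects vision rope here, while the test code
    // further down uses `mode & GGML_ROPE_TYPE_MROPE` to match both mrope and vision.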

ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl

Lines changed: 11 additions & 4 deletions
@@ -49,8 +49,8 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
     uint idst = i1*ne0 + i0;
     const uint ix = rope_a_coord(i0, i01, i02, p);
 
-    // Fusion optimization: ROPE + VIEW + SET_ROWS..
-    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
         idst = i01*ne0 + i0;
         idst += rope_data_i[i02].x * p.set_rows_stride;
@@ -91,7 +91,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
     uint idst = i1*ne0 + i0/2;
     const uint ix = rope_a_coord(i0/2, i01, i02, p);
 
-    // Fusion optimization: ROPE + VIEW + SET_ROWS..
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
         idst = i01*ne0 + i0/2;
@@ -132,9 +132,16 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
     const uint i01 = i1 % ne1;
     const uint i02 = i1 / ne1;
 
-    const uint idst = i1*ne0 + i0/2;
+    uint idst = i1*ne0 + i0/2;
     const uint ix = rope_a_coord(i0/2, i01, i02, p);
 
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
+    if (p.set_rows_stride != 0) {
+        idst = i01*ne0 + i0/2;
+        idst += rope_data_i[i02].x * p.set_rows_stride;
+    }
+
     if (i0 >= p.n_dims) {
         rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
         rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
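The rope_multi hunk gives the multi-section shader the same destination-index handling that rope_norm and rope_neox already had. The same arithmetic written out in C for clarity (a sketch only; names follow the shader, and row_index stands in for the rope_data_i[...].x row id supplied by the fused SET_ROWS operand):

    const uint32_t i01 = i1 % ne1;                 // row within the batch
    const uint32_t i02 = i1 / ne1;                 // batch (token) index
    uint32_t idst = i1 * ne0 + i0 / 2;             // unfused: contiguous rope output
    if (set_rows_stride != 0) {                    // fused ROPE + VIEW + SET_ROWS
        idst  = i01 * ne0 + i0 / 2;                // offset inside one output row
        idst += row_index[i02] * set_rows_stride;  // destination row chosen by set_rows
    }

When the rope is not part of a fusion, set_rows_stride stays zero and the original addressing is untouched.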

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
@@ -927,6 +927,8 @@ void process_shaders() {
     string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
     string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
     string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
+    string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
+    string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 
     string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
     string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});

tests/test-backend-ops.cpp

Lines changed: 61 additions & 18 deletions
@@ -2329,11 +2329,13 @@ struct test_set_rows : public test_case {
 struct test_rope_set_rows : public test_case {
     const ggml_type type;
     const ggml_type type_idx;
-    const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> ne_a;
     int mode;
+    const int n_ctx{512};
+    const int n_dims{128};
 
     std::string vars() override {
-        return VARS_TO_STR4(type, type_idx, ne, mode);
+        return VARS_TO_STR4(type, type_idx, ne_a, mode);
     }
 
     std::string op_desc(ggml_tensor * t) override {
@@ -2345,24 +2347,51 @@ struct test_rope_set_rows : public test_case {
 
     test_rope_set_rows(ggml_type type,
             ggml_type type_idx,
-            std::array<int64_t, 4> ne,
+            std::array<int64_t, 4> ne_a,
             int mode)
-        : type(type), type_idx(type_idx), ne(ne), mode(mode) {}
+        : type(type), type_idx(type_idx), ne_a(ne_a), mode(mode) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1);
-        ggml_set_name(src, "src");
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne_a[0], ne_a[1], ne_a[2], 1);
+        ggml_set_name(a, "a");
 
-        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+        const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
+        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
-        ggml_tensor * rope = ggml_rope(ctx, src, pos, ne[0], mode);
+        ggml_tensor * pos;
+        if (is_mrope || is_vision) {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
+        } else {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+        }
+        ggml_set_name(pos, "pos");
+
+        float fs = 1.4245f;
+        float ef = 0.7465f;
+        float af = 1.4245f;
+        ggml_tensor * freq = nullptr;
+
+        ggml_tensor * rope = nullptr;
+        if (is_mrope) {
+            if (is_vision) {
+                GGML_ASSERT(n_dims/4 > 0);
+                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
+                rope = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            } else {
+                GGML_ASSERT(n_dims/3 > 0);
+                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
+                rope = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            }
+        } else {
+            rope = ggml_rope(ctx, a, pos, ne_a[0], mode);
+        }
 
-        ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0);
+        ggml_tensor * view = ggml_view_2d(ctx, rope, ne_a[0] * ne_a[1], ne_a[2], rope->nb[2], 0);
 
-        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0] * ne[1], ne[2] * ne[3], 1, 1);
+        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne_a[0] * ne_a[1], ne_a[2] * ne_a[3], 1, 1);
         ggml_set_name(dst, "dst");
 
-        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, ne[2], 1, 1);
+        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, ne_a[2], 1, 1);
         ggml_set_name(row_idxs, "row_idxs");
 
         ggml_tensor * out = ggml_set_rows(ctx, dst, view, row_idxs);
@@ -2373,14 +2402,26 @@ struct test_rope_set_rows : public test_case {
 
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
+            if (strcmp(t->name, "row_idxs") == 0) {
                 if (ggml_is_view_op(t->op)) {
                     continue;
                 }
-
-                init_set_rows_row_ids(t, ne[2]);
+                init_set_rows_row_ids(t, ne_a[2]);
+            } else if (t->type == GGML_TYPE_I32) {
+                // pos
+                const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
+                std::vector<int> data(num_pos_ids);
+                for (int i = 0; i < num_pos_ids; i++) {
+                    data[i] = rand() % n_ctx;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));
             } else {
-                init_tensor_uniform(t);
+                if (t->ne[0] == n_dims/2) {
+                    // frequency factors in the range [0.9f, 1.1f]
+                    init_tensor_uniform(t, 0.9f, 1.1f);
+                } else {
+                    init_tensor_uniform(t);
+                }
             }
         }
     }
@@ -6854,10 +6895,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
-    for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX }) {
+    for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION }) {
         for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
-            test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 1, 100 }, mode));
-            test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 512, 1 }, mode));
+            for (int ne2 : {1, 8, 512}) {
+                test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, ne2, 1 }, mode));
+                test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, ne2, 3 }, mode));
+            }
         }
     }
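The initialize_tensors change captures the main data-layout difference the new modes bring: mrope and vision rope consume four position ids per token instead of one, and any frequency-factor tensor (ne[0] == n_dims/2) is initialized close to 1.0. A standalone sketch of the position-id fill, mirroring the test code above (the function name and parameters are illustrative, not part of the commit):

    #include <cstdlib>
    #include <vector>
    #include "ggml.h"
    #include "ggml-backend.h"

    // Sketch only: mrope/vision rope take 4 position ids per token (one per rope section),
    // normal/neox take 1. n_ctx just bounds the random positions, as in the test.
    static void fill_rope_positions(ggml_tensor * pos, int mode, int n_tokens, int n_ctx) {
        const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? n_tokens * 4 : n_tokens;
        std::vector<int32_t> data(num_pos_ids);
        for (int i = 0; i < num_pos_ids; ++i) {
            data[i] = rand() % n_ctx;
        }
        ggml_backend_tensor_set(pos, data.data(), 0, num_pos_ids * sizeof(int32_t));
    }

In real models the four id streams typically carry temporal and spatial coordinates, which is also why the vision case in build_graph only populates the first two rope sections.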
