@@ -856,6 +856,15 @@ struct vk_subbuffer {
856856 }
857857};
858858
859+ // vk_event backs the event-related backend interfaces: 'event' is consumed by
860+ // event_wait on the GPU timeline and 'fence' by event_synchronize on the host.
861+ // Polling the event from the host wouldn't be sufficient to wait for the
862+ // submitted command buffers to complete, and would lead to validation errors.
863+ struct vk_event {
864+ vk::Event event;
865+ vk::Fence fence;
866+ };
867+
859868struct vk_semaphore {
860869 vk::Semaphore s;
861870 uint64_t value;
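
A note on the split above: the `vk::Event` is what `event_wait` consumes on the GPU timeline, while the `vk::Fence` is what `event_synchronize` blocks on from the host. A minimal sketch of the host-side distinction (illustrative only; `host_wait_sketch` is not part of this patch):

```cpp
#include <vulkan/vulkan.hpp>

// Why the fence is needed: vkGetEventStatus() can report VK_EVENT_SET while
// the submission that set the event is still in flight, so polling the event
// is not a valid host-side wait for command-buffer completion.
static void host_wait_sketch(vk::Device dev, vk::Event ev, vk::Fence fence) {
    // Insufficient (and flagged by validation layers when used as a
    // completion wait): this only observes the event object itself.
    // while (dev.getEventStatus(ev) != vk::Result::eEventSet) { }
    (void)ev;

    // Correct: the fence passed at queue submission signals only once the
    // whole submission has retired.
    (void)dev.waitForFences({ fence }, VK_TRUE, UINT64_MAX);
}
```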
@@ -2544,6 +2553,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
25442553 );
25452554}
25462555
2556+ static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
2557+ VK_LOG_DEBUG("ggml_vk_set_event()");
2558+
2559+ ctx->s->buffer.setEvent(
2560+ event,
2561+ ctx->p->q->stage_flags
2562+ );
2563+ }
2564+
25472565static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
25482566 VK_LOG_DEBUG("ggml_vk_wait_events()");
25492567 if (events.empty()) {
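
For reference, `ggml_vk_set_event` records the signal half of a split barrier, and `ggml_vk_wait_events` records the matching wait. In raw vulkan.hpp terms the pair is roughly the following (illustrative sketch; `stage_flags` stands in for the queue's `ctx->p->q->stage_flags`):

```cpp
#include <vulkan/vulkan.hpp>

// Producer/consumer halves of the split barrier. No memory barriers are
// attached here; the backend issues those separately via ggml_vk_sync_buffers().
static void record_event_pair_sketch(vk::CommandBuffer signal_cb,
                                     vk::CommandBuffer wait_cb,
                                     vk::Event event,
                                     vk::PipelineStageFlags stage_flags) {
    signal_cb.setEvent(event, stage_flags);      // what ggml_vk_set_event records
    wait_cb.waitEvents({ event },                // what ggml_vk_wait_events records
                       stage_flags, stage_flags, // src / dst stage masks
                       {}, {}, {});              // no global/buffer/image barriers
}
```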
@@ -6089,13 +6107,8 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
60896107 }
60906108}
60916109
6092- static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
6110+ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
60936111 VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
6094- // Buffer is already mapped
6095- if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
6096- std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl;
6097- GGML_ABORT("fatal error");
6098- }
60996112 // Check if src is pinned memory
61006113 vk_buffer buf = nullptr;
61016114 size_t buf_offset = 0;
@@ -6120,12 +6133,13 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
61206133
61216134 ggml_vk_sync_buffers(nullptr, subctx);
61226135 subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
6123- return;
6136+ return true;
61246137 }
61256138 VK_LOG_DEBUG("STAGING");
61266139
61276140 if (!sync_staging) {
6128- GGML_ABORT("Asynchronous write to non-pinned memory not supported");
6141+ // copy was not handled; the caller needs to fall back to a staging copy
6142+ return false;
61296143 }
61306144
61316145 // Staging buffer required
@@ -6149,9 +6163,10 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
61496163 deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
61506164 }
61516165 }
6166+ return true;
61526167}
61536168
6154- static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
6169+ static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
61556170 VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
61566171 return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
61576172}
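
The write helpers now report whether they handled the copy instead of aborting: `false` means the source is neither pinned nor staged, and the caller must supply its own staging path. A hedged sketch of the new contract (`fall_back_to_staging` is a hypothetical placeholder; the real fallback lives in ggml_backend_vk_set_tensor_async further down in this diff):

```cpp
// Sketch of the new caller contract under the bool return value.
static void write_with_fallback_sketch(vk_context subctx, vk_buffer &dst,
                                       size_t offset, const void *src, size_t size) {
    if (!ggml_vk_buffer_write_async(subctx, dst, offset, src, size)) {
        // src was not pinned and sync_staging was false: nothing was
        // recorded, so route the copy through a sync staging buffer.
        fall_back_to_staging(subctx, dst, offset, src, size); // hypothetical helper
    }
}
```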
@@ -6170,7 +6185,8 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
61706185
61716186 vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
61726187 ggml_vk_ctx_begin(dst->device, subctx);
6173- ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
6188+ bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
6189+ GGML_ASSERT(ret);
61746190 ggml_vk_ctx_end(subctx);
61756191
61766192 for (auto& cpy : subctx->in_memcpys) {
@@ -12671,7 +12687,23 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
1267112687
1267212688 vk_buffer buf = buf_ctx->dev_buffer;
1267312689
12674- ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
12690+ auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
12691+
12692+ bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size);
12693+
12694+ if (!ret) {
12695+ ggml_vk_ensure_sync_staging_buffer(ctx, size);
12696+ ggml_vk_sync_buffers(nullptr, transfer_ctx);
12697+
12698+ vk::BufferCopy buffer_cpy;
12699+ buffer_cpy.srcOffset = 0;
12700+ buffer_cpy.dstOffset = dst_offset;
12701+ buffer_cpy.size = size;
12702+
12703+ transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
12704+ deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys);
12705+ ggml_vk_synchronize(ctx);
12706+ }
1267512707}
1267612708
1267712709static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
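
The fallback above leans on the backend's deferred-memcpy mechanism: the host-to-staging copy is appended to `transfer_ctx->in_memcpys` and executed only when the context is flushed, which is why it can be recorded after the `copyBuffer`. A standalone sketch of that pattern (the struct and helper names are assumptions mirroring the backend, not its actual code):

```cpp
#include <cstring>
#include <vector>

// Assumed shape of a deferred host copy: recorded now, executed at
// synchronize/submit time, so recording order need not match execution order.
struct deferred_cpy { void *dst; const void *src; std::size_t n; };

static void flush_deferred(std::vector<deferred_cpy> &cpys) {
    for (const auto &c : cpys) {
        std::memcpy(c.dst, c.src, c.n); // runs just before queue submission
    }
    cpys.clear();
}
```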
@@ -13678,11 +13710,58 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
1367813710 }
1367913711}
1368013712
13713+ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
13714+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13715+ vk_event *vkev = (vk_event *)event->context;
13716+
13717+ vk_context transfer_ctx;
13718+
13719+ if (ctx->transfer_ctx.expired()) {
13720+ // Initialize new transfer context
13721+ transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13722+ ctx->transfer_ctx = transfer_ctx;
13723+ ggml_vk_ctx_begin(ctx->device, transfer_ctx);
13724+ } else {
13725+ transfer_ctx = ctx->transfer_ctx.lock();
13726+ }
13727+
13728+ // the backend interface doesn't have an explicit reset, so reset both the
13729+ // event and the fence here before recording the set command and resubmitting
13730+ ctx->device->device.resetEvent(vkev->event);
13731+ ctx->device->device.resetFences({ vkev->fence });
13732+
13733+ ggml_vk_set_event(transfer_ctx, vkev->event);
13734+
13735+ ggml_vk_ctx_end(transfer_ctx);
13736+
13737+ ggml_vk_submit(transfer_ctx, {vkev->fence});
13738+ ctx->submit_pending = true;
13739+ ctx->transfer_ctx.reset();
13740+ }
13741+
13742+ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
13743+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13744+ vk_event *vkev = (vk_event *)event->context;
13745+
13746+ vk_context transfer_ctx;
13747+
13748+ if (ctx->transfer_ctx.expired()) {
13749+ // Initialize new transfer context
13750+ transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13751+ ctx->transfer_ctx = transfer_ctx;
13752+ ggml_vk_ctx_begin(ctx->device, transfer_ctx);
13753+ } else {
13754+ transfer_ctx = ctx->transfer_ctx.lock();
13755+ }
13756+
13757+ ggml_vk_wait_events(transfer_ctx, {vkev->event});
13758+ }
13759+
1368113760// TODO: enable async and synchronize
1368213761static ggml_backend_i ggml_backend_vk_interface = {
1368313762 /* .get_name = */ ggml_backend_vk_name,
1368413763 /* .free = */ ggml_backend_vk_free,
13685- /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
13764+ /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
1368613765 /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
1368713766 /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
1368813767 /* .synchronize = */ ggml_backend_vk_synchronize,
@@ -13691,8 +13770,8 @@ static ggml_backend_i ggml_backend_vk_interface = {
1369113770 /* .graph_plan_update = */ NULL,
1369213771 /* .graph_plan_compute = */ NULL,
1369313772 /* .graph_compute = */ ggml_backend_vk_graph_compute,
13694- /* .event_record = */ NULL,
13695- /* .event_wait = */ NULL,
13773+ /* .event_record = */ ggml_backend_vk_event_record,
13774+ /* .event_wait = */ ggml_backend_vk_event_wait,
1369613775 /* .graph_optimize = */ ggml_vk_graph_optimize,
1369713776};
1369813777
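
With set_tensor_async, event_record, and event_wait wired up, the generic ggml-backend event API becomes usable on Vulkan. A hedged usage sketch (the `ggml_backend_event_*` calls are the public API from ggml-backend.h; the two-backend flow itself is illustrative):

```cpp
#include "ggml-backend.h"

// Cross-backend ordering: backend_b's queue waits on work recorded by
// backend_a, and the host can block on the same event via its fence.
static void event_flow_sketch(ggml_backend_dev_t dev,
                              ggml_backend_t backend_a,
                              ggml_backend_t backend_b) {
    ggml_backend_event_t ev = ggml_backend_event_new(dev); // event_new: fence pre-signaled
    ggml_backend_event_record(ev, backend_a); // event_record: reset, vkCmdSetEvent, submit
    ggml_backend_event_wait(backend_b, ev);   // event_wait: vkCmdWaitEvents on b's queue
    ggml_backend_event_synchronize(ev);       // event_synchronize: host waits on the fence
    ggml_backend_event_free(ev);              // event_free: destroy fence and event
}
```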
@@ -13867,10 +13946,10 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
1386713946 props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
1386813947 ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
1386913948 props->caps = {
13870- /* .async = */ false,
13949+ /* .async = */ true,
1387113950 /* .host_buffer = */ true,
1387213951 /* .buffer_from_host_ptr = */ false,
13873- /* .events = */ false,
13952+ /* .events = */ true,
1387413953 };
1387513954}
1387613955
@@ -14402,6 +14481,46 @@ static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml
1440214481 UNUSED(dev);
1440314482}
1440414483
14484+ static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
14485+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14486+ auto device = ggml_vk_get_device(ctx->device);
14487+
14488+ vk_event *vkev = new (std::nothrow) vk_event;
14489+ if (!vkev) {
14490+ return nullptr;
14491+ }
14492+
14493+ // The event/fence is expected to initially be in the signaled state.
14494+ vkev->event = device->device.createEvent({});
14495+ vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});
14496+ device->device.setEvent(vkev->event);
14497+
14498+ return new ggml_backend_event {
14499+ /* .device = */ dev,
14500+ /* .context = */ vkev,
14501+ };
14502+ }
14503+
14504+ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
14505+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14506+ auto device = ggml_vk_get_device(ctx->device);
14507+
14508+ vk_event *vkev = (vk_event *)event->context;
14509+
14510+ device->device.destroyFence(vkev->fence);
14511+ device->device.destroyEvent(vkev->event);
14512+ delete vkev;
14513+ delete event;
14514+ }
14515+
14516+ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
14517+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14518+ auto device = ggml_vk_get_device(ctx->device);
14519+ vk_event *vkev = (vk_event *)event->context;
14520+
14521+ VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
14522+ }
14523+
1440514524static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
1440614525 /* .get_name = */ ggml_backend_vk_device_get_name,
1440714526 /* .get_description = */ ggml_backend_vk_device_get_description,
@@ -14415,9 +14534,9 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
1441514534 /* .supports_op = */ ggml_backend_vk_device_supports_op,
1441614535 /* .supports_buft = */ ggml_backend_vk_device_supports_buft,
1441714536 /* .offload_op = */ ggml_backend_vk_device_offload_op,
14418- /* .event_new = */ NULL,
14419- /* .event_free = */ NULL,
14420- /* .event_synchronize = */ NULL,
14537+ /* .event_new = */ ggml_backend_vk_device_event_new,
14538+ /* .event_free = */ ggml_backend_vk_device_event_free,
14539+ /* .event_synchronize = */ ggml_backend_vk_device_event_synchronize,
1442114540};
1442214541
1442314542static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {