@@ -1185,6 +1185,14 @@ struct vk_staging_memcpy {
11851185 size_t n;
11861186};
11871187
// Record of a host-side memset that has been deferred: the fill of `n` bytes
// at `dst` with `val` is replayed later (see vk_context_struct::memsets),
// mirroring the vk_staging_memcpy mechanism above.
// NOTE(review): memset() replicates only the low byte of `val` across the
// range -- confirm callers only rely on byte-uniform patterns.
struct vk_staging_memset {
    vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst{_dst}, val{_val}, n{_n} {}

    void * dst;   // destination host pointer
    uint32_t val; // fill value
    size_t n;     // number of bytes to fill
};
11881196struct vk_context_struct {
11891197 vk_submission * s;
11901198 std::vector<vk_sequence> seqs;
@@ -1193,6 +1201,7 @@ struct vk_context_struct {
11931201
11941202 std::vector<vk_staging_memcpy> in_memcpys;  // deferred host copies, replayed on the CPU right before submission
11951203 std::vector<vk_staging_memcpy> out_memcpys; // deferred host copies; presumably replayed after the GPU work completes -- confirm at call sites
1204+ std::vector<vk_staging_memset> memsets;     // deferred host memsets (UMA fast path), replayed alongside in_memcpys before submission
11961205
11971206 vk_command_pool * p {};
11981207};
@@ -5196,6 +5205,14 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
51965205 }
51975206}
51985207
5208+ static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector<vk_staging_memset>* memsets = nullptr) {
5209+ if (memsets == nullptr) {
5210+ memset(dst, val, size);
5211+ } else {
5212+ memsets->emplace_back(dst, val, size);
5213+ }
5214+ }
5215+
51995216static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
52005217 if (device->sync_staging == nullptr || device->sync_staging->size < size) {
52015218 VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
@@ -5391,6 +5408,10 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
53915408 memcpy(cpy.dst, cpy.src, cpy.n);
53925409 }
53935410
// Replay the deferred host-side memsets (queued by ggml_vk_buffer_memset_async
// on the UMA path) before the command buffer is submitted.
5411+ for (auto& mset : subctx->memsets) {
5412+ memset(mset.dst, mset.val, mset.n);
5413+ }
5414+
53945415 ggml_vk_submit(subctx, dst->device->fence);
53955416 VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
53965417 dst->device->device.resetFences({ dst->device->fence });
@@ -5530,12 +5551,25 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
55305551static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
55315552 VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
55325553
5554+ if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
5555+ dst->device->uma) {
5556+ deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets);
5557+ return;
5558+ }
5559+
5560+ // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
55335561 ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
55345562}
55355563
55365564static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
55375565 VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
55385566
// Fast path: on UMA devices the host-visible buffer is directly addressable,
// so fill it with the CPU instead of building and submitting a command buffer.
// NOTE(review): memset repeats only the low byte of `c`, whereas the GPU
// fillBuffer path repeats the full 32-bit word -- equivalent only for
// byte-uniform values (e.g. 0); confirm callers pass only such values.
5567+ if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
5568+ dst->device->uma) {
5569+ memset((uint8_t*)dst->ptr + offset, c, size);
5570+ return;
5571+ }
5572+
// Slow path: record the fill in a temporary context on the transfer queue
// (submission/wait continues past this hunk).
55395573 std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
55405574 vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
55415575 ggml_vk_ctx_begin(dst->device, subctx);
@@ -11170,6 +11204,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1117011204 memcpy(cpy.dst, cpy.src, cpy.n);
1117111205 }
1117211206
// Replay the deferred host-side memsets before the graph's command buffer is
// submitted (mirrors the in_memcpys replay above).
11207+ for (auto& mset : subctx->memsets) {
11208+ memset(mset.dst, mset.val, mset.n);
11209+ }
11210+
1117311211 if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) {
1117411212 ggml_vk_submit(subctx, ctx->almost_ready_fence);
1117511213 ctx->almost_ready_fence_pending = true;
@@ -11192,6 +11230,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1119211230 }
1119311231 subctx->in_memcpys.clear();
1119411232 subctx->out_memcpys.clear();
// Clear the already-replayed memsets so the context can be reused safely.
11233+ subctx->memsets.clear();
1119511234 }
1119611235
1119711236 return true;
0 commit comments