88
99#include < executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010
11+ #include < executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112#include < executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1213
1314#include < executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
1920
2021namespace vkcompute {
2122
23+ enum class Conv2dMethod : uint8_t {
24+ Depthwise,
25+ Pointwise,
26+ SlidingWindow,
27+ Transposed,
28+ };
29+
2230void resize_conv2d_node (
2331 ComputeGraph* graph,
2432 const std::vector<ArgGroup>& args,
@@ -114,13 +122,6 @@ ValueRef prepack_biases(
114122 return v;
115123}
116124
117- enum class Conv2dMethod : uint8_t {
118- Depthwise,
119- Pointwise,
120- SlidingWindow,
121- Transposed,
122- };
123-
124125vkapi::ShaderInfo get_conv2d_shader (
125126 ComputeGraph& graph,
126127 const ValueRef out,
@@ -327,6 +328,108 @@ utils::uvec3 create_conv2d_global_wg_size(
327328 }
328329}
329330
331+ // Custom global workgroup size function for conv2d
332+ utils::uvec3 conv2d_global_wg_size (
333+ ComputeGraph* graph,
334+ const vkapi::ShaderInfo& shader,
335+ const std::vector<ArgGroup>& args,
336+ const std::vector<ValueRef>& resize_args) {
337+ const ValueRef out = args.at (0 ).refs .at (0 );
338+ const ValueRef weight_data = resize_args.at (0 );
339+
340+ // Determine method from shader name
341+ Conv2dMethod method;
342+ if (shader.kernel_name .find (" conv2d_dw" ) != std::string::npos) {
343+ method = Conv2dMethod::Depthwise;
344+ } else if (
345+ shader.kernel_name .find (" conv2d_pw" ) != std::string::npos ||
346+ (shader.kernel_name .find (" conv2d" ) != std::string::npos &&
347+ shader.kernel_name .find (" conv_transpose2d" ) == std::string::npos)) {
348+ // Check if it's pointwise by examining weight sizes
349+ const auto & weight_sizes = graph->get_tref (weight_data)->sizes ;
350+ if (weight_sizes.at (2 ) == 1 && weight_sizes.at (3 ) == 1 ) {
351+ method = Conv2dMethod::Pointwise;
352+ } else {
353+ method = Conv2dMethod::SlidingWindow;
354+ }
355+ } else if (shader.kernel_name .find (" conv_transpose2d" ) != std::string::npos) {
356+ method = Conv2dMethod::Transposed;
357+ } else {
358+ method = Conv2dMethod::SlidingWindow;
359+ }
360+
361+ // Determine stride_equals_dilation from shader name
362+ bool stride_equals_dilation =
363+ shader.kernel_name .find (" _sned" ) == std::string::npos;
364+
365+ utils::uvec3 wg_size = create_conv2d_global_wg_size (
366+ *graph, method, out, weight_data, stride_equals_dilation);
367+
368+ if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) {
369+ wg_size = {wg_size[0 ] * wg_size[1 ], wg_size[2 ], 1 };
370+ }
371+
372+ return wg_size;
373+ }
374+
375+ // Custom local workgroup size function for conv2d
376+ utils::uvec3 conv2d_local_wg_size (
377+ ComputeGraph* graph,
378+ const vkapi::ShaderInfo& shader,
379+ const utils::uvec3& global_workgroup_size,
380+ const std::vector<ArgGroup>& args,
381+ const std::vector<ValueRef>& resize_args) {
382+ (void )args;
383+ (void )resize_args;
384+
385+ // Determine method from shader name
386+ Conv2dMethod method;
387+ if (shader.kernel_name .find (" conv2d_dw" ) != std::string::npos) {
388+ method = Conv2dMethod::Depthwise;
389+ } else if (
390+ shader.kernel_name .find (" conv2d_pw" ) != std::string::npos ||
391+ (shader.kernel_name .find (" conv2d" ) != std::string::npos &&
392+ shader.kernel_name .find (" conv_transpose2d" ) == std::string::npos)) {
393+ method = Conv2dMethod::Pointwise;
394+ } else {
395+ method = Conv2dMethod::SlidingWindow;
396+ }
397+
398+ if (method == Conv2dMethod::Pointwise) {
399+ uint32_t local_wg_size_y = 1 ;
400+ if (global_workgroup_size[1 ] % 8 == 0 ) {
401+ local_wg_size_y = 8 ;
402+ } else if (global_workgroup_size[1 ] % 4 == 0 ) {
403+ local_wg_size_y = 4 ;
404+ } else if (global_workgroup_size[1 ] % 2 == 0 ) {
405+ local_wg_size_y = 2 ;
406+ }
407+ return {64 / local_wg_size_y, local_wg_size_y, 1 };
408+ } else if (method == Conv2dMethod::Depthwise) {
409+ return {64 , 1 , 1 };
410+ } else {
411+ return graph->create_local_wg_size (global_workgroup_size);
412+ }
413+ }
414+
415+ // Custom global workgroup size function for conv1d
416+ utils::uvec3 conv1d_global_wg_size (
417+ ComputeGraph* graph,
418+ const vkapi::ShaderInfo& shader,
419+ const std::vector<ArgGroup>& args,
420+ const std::vector<ValueRef>& resize_args) {
421+ (void )shader;
422+ (void )resize_args;
423+ const ValueRef out = args.at (0 ).refs .at (0 );
424+
425+ return {// out length
426+ graph->size_at <uint32_t >(-1 , out),
427+ // out channels
428+ static_cast <uint32_t >(graph->size_at <int64_t >(-2 , out)),
429+ // out batches
430+ utils::div_up_4 (graph->size_at <uint32_t >(-3 , out))};
431+ }
432+
330433void add_conv2d_node (
331434 ComputeGraph& graph,
332435 const ValueRef in,
@@ -486,11 +589,11 @@ void add_conv2d_node(
486589 };
487590 }
488591
489- graph.execute_nodes ().emplace_back (new DispatchNode (
592+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
490593 graph,
491594 shader,
492- wg_size ,
493- local_wg_size ,
595+ conv2d_global_wg_size ,
596+ conv2d_local_wg_size ,
494597 // Inputs and Outputs
495598 {{out, vkapi::kWrite }, {{in, arg_weight, arg_bias}, vkapi::kRead }},
496599 // Shader params buffers
@@ -560,15 +663,6 @@ void add_conv1d_node(
560663 const int32_t out_group_size =
561664 static_cast <int64_t >(out_channels / groups_val);
562665
563- const utils::uvec3 global_size = {
564- // out length
565- graph.size_at <uint32_t >(-1 , out),
566- // out channels
567- static_cast <uint32_t >(out_channels),
568- // out batches
569- utils::div_up_4 (graph.size_at <uint32_t >(-3 , out))};
570- const utils::uvec3 local_size = graph.create_local_wg_size (global_size);
571-
572666 Kernel1dParams kernel_params = {
573667 kernel_size,
574668 stride_size,
@@ -587,11 +681,11 @@ void add_conv1d_node(
587681
588682 add_dtype_suffix (kernel_name, graph.dtype_of (out));
589683
590- graph.execute_nodes ().emplace_back (new DispatchNode (
684+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
591685 graph,
592686 VK_KERNEL_FROM_STR (kernel_name),
593- global_size ,
594- local_size ,
687+ conv1d_global_wg_size ,
688+ default_pick_local_wg_size ,
595689 // Inputs and Outputs
596690 {{out, vkapi::kWrite }, {{in, arg_weight, arg_bias}, vkapi::kRead }},
597691 // Shader params buffers
0 commit comments