@@ -33,6 +33,25 @@ void resize_reduce_node(
3333 graph->virtual_resize (out, new_sizes);
3434}
3535
36+ void resize_reduce2d_node (
37+ ComputeGraph* graph,
38+ const std::vector<ArgGroup>& args,
39+ const std::vector<ValueRef>& resize_args) {
40+ vTensorPtr out = graph->get_tensor (args[0 ].refs [0 ]);
41+ vTensorPtr in = graph->get_tensor (args[1 ].refs [0 ]);
42+
43+ // Extract the dimensions to reduce over
44+ const std::vector<int64_t > dims_list =
45+ graph->extract_int_or_symint_list (resize_args.at (0 ));
46+ int32_t reduce_dim1_nchw = dims_list[0 ];
47+ int32_t reduce_dim2_nchw = dims_list[1 ];
48+
49+ std::vector<int64_t > new_sizes = in->sizes ();
50+ new_sizes.at (normalize (reduce_dim1_nchw, new_sizes.size ())) = 1 ;
51+ new_sizes.at (normalize (reduce_dim2_nchw, new_sizes.size ())) = 1 ;
52+ out->virtual_resize (new_sizes);
53+ }
54+
3655utils::uvec3 reduce_global_wg_size (
3756 ComputeGraph* graph,
3857 const vkapi::ShaderInfo& shader,
@@ -138,15 +157,101 @@ void add_reduce_node(
138157 resize_reduce_node));
139158}
140159
160+ void add_reduce2d_node (
161+ ComputeGraph& graph,
162+ const ValueRef in,
163+ const ValueRef dims_ref,
164+ const ValueRef out,
165+ const std::string& op_name) {
166+ VK_CHECK_COND (
167+ !graph.is_buffer_storage (in) && !graph.is_buffer_storage (out),
168+ " Vulkan reduction only supports texture storage" );
169+
170+ const int64_t ndim = graph.dim_of (in);
171+
172+ // Extract the two dimensions to reduce over
173+ const std::vector<int64_t > dims_list =
174+ graph.extract_int_or_symint_list (dims_ref);
175+ VK_CHECK_COND (
176+ dims_list.size () == 2 , " reduce2d requires exactly 2 dimensions" );
177+
178+ int32_t reduce_dim1 = normalize (dims_list[0 ], ndim);
179+ int32_t reduce_dim2 = normalize (dims_list[1 ], ndim);
180+
181+ // Convert to WHCN format
182+ reduce_dim1 = nchw_dim_to_whcn_dim (reduce_dim1, ndim);
183+ reduce_dim2 = nchw_dim_to_whcn_dim (reduce_dim2, ndim);
184+
185+ // Check that none of the reduction dims are packed
186+ VK_CHECK_COND (graph.packed_dim_of (in) != reduce_dim1);
187+ VK_CHECK_COND (graph.packed_dim_of (in) != reduce_dim2);
188+ VK_CHECK_COND (graph.packed_dim_of (out) != reduce_dim1);
189+ VK_CHECK_COND (graph.packed_dim_of (out) != reduce_dim2);
190+
191+ // Check that the concat dim is not one of the reduction dims
192+ if (graph.dim_of (in) == 4 && graph.size_at <int >(0 , in) > 1 ) {
193+ VK_CHECK_COND (graph.concat_dim_of (in) != reduce_dim1);
194+ VK_CHECK_COND (graph.concat_dim_of (in) != reduce_dim2);
195+ VK_CHECK_COND (graph.concat_dim_of (out) != reduce_dim1);
196+ VK_CHECK_COND (graph.concat_dim_of (out) != reduce_dim2);
197+ }
198+
199+ std::string kernel_name = op_name + " 2d" ; // Add "2d" suffix
200+ kernel_name.reserve (kShaderNameReserve );
201+ add_dtype_suffix (kernel_name, graph.dtype_of (out));
202+
203+ // Calculate group_dim for specialization constants (use remaining dimension)
204+ int32_t group_dim = 0 ;
205+ for (int i = 0 ; i < 3 ; i++) {
206+ if (i != reduce_dim1 && i != reduce_dim2) {
207+ group_dim = i;
208+ break ;
209+ }
210+ }
211+
212+ const ValueRef reduce_dim1_whcn_ref =
213+ graph.get_or_add_value_for_int (reduce_dim1);
214+ const ValueRef reduce_dim2_whcn_ref =
215+ graph.get_or_add_value_for_int (reduce_dim2);
216+ const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int (group_dim);
217+
218+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
219+ graph,
220+ VK_KERNEL_FROM_STR (kernel_name),
221+ reduce_global_wg_size,
222+ reduce_local_wg_size,
223+ // Inputs and Outputs
224+ {{out, vkapi::kWrite }, {in, vkapi::kRead }},
225+ // Shader params buffers
226+ {graph.logical_limits_ubo (in), graph.sizes_ubo (in)},
227+ // Push Constants
228+ {},
229+ // Specialization Constants
230+ {graph.packed_dim_of (out), reduce_dim1, reduce_dim2, group_dim},
231+ // Resize Args
232+ {dims_ref,
233+ reduce_dim1_whcn_ref,
234+ reduce_dim2_whcn_ref,
235+ group_dim_whcn_ref},
236+ // Resizing Logic
237+ resize_reduce2d_node));
238+ }
239+
141240#define DEFINE_REDUCE_FN (op_name, out_arg_idx ) \
142241 void op_name (ComputeGraph& graph, const std::vector<ValueRef>& args) { \
143242 const std::vector<int64_t > dims_list = \
144243 graph.extract_int_or_symint_list (args[1 ]); \
145- VK_CHECK_COND (dims_list.size () == 1 ); \
146- const int64_t dim_val = dims_list.at (0 ); \
147- const ValueRef dim_ref = graph.get_or_add_value_for_int (dim_val); \
148- return add_reduce_node ( \
149- graph, args[0 ], dim_ref, args[out_arg_idx], #op_name); \
244+ if (dims_list.size () == 1 ) { \
245+ const int64_t dim_val = dims_list.at (0 ); \
246+ const ValueRef dim_ref = graph.get_or_add_value_for_int (dim_val); \
247+ return add_reduce_node ( \
248+ graph, args[0 ], dim_ref, args[out_arg_idx], #op_name); \
249+ } \
250+ if (dims_list.size () == 2 ) { \
251+ return add_reduce2d_node ( \
252+ graph, args[0 ], args[1 ], args[out_arg_idx], #op_name); \
253+ } \
254+ VK_CHECK_COND (false , " Only 1 or 2 dimensions supported" ); \
150255 }
151256
152257DEFINE_REDUCE_FN (sum, 4 )
0 commit comments