diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 19594002cf2..044f22a1f08 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -532,7 +532,7 @@ def register_reduce_op(features: OpFeatures):
     def check_reduce_node(node: torch.fx.Node) -> bool:
         dim_list = node.args[1]
-        if isinstance(dim_list, list) and len(dim_list) != 1:
+        if isinstance(dim_list, list) and len(dim_list) > 2:
            return False

        keepdim = node.args[2]
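The relaxed check above only gates on the length of the dim list, so reductions over one or two dims can now reach the Vulkan delegate. A hedged sketch at the PyTorch level (the shapes and the delegation outcomes are illustrative assumptions; the remaining checks in check_reduce_node, such as the keepdim argument, still apply):

# Illustration only: the partitioner check returns False for dim lists longer
# than two, so those nodes stay on the non-delegated path.
import torch

x = torch.randn(2, 8, 16, 16)

torch.sum(x, dim=[2], keepdim=True)         # one reduction dim: accepted before and after
torch.mean(x, dim=[2, 3], keepdim=True)     # two reduction dims: newly accepted
torch.amax(x, dim=[1, 2, 3], keepdim=True)  # three reduction dims: still rejected
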
diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl
index a4ed494fe6d..4cebde42a34 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl
@@ -10,6 +10,11 @@

 #define PRECISION ${PRECISION}

+// Define tile sizes for optimal memory access patterns
+#define TILE_SIZE_M 8 // Tile size for output rows
+#define TILE_SIZE_N 8 // Tile size for output columns
+#define TILE_SIZE_K 8 // Tile size for K dimension
+
 $if MAT2_IS_TRANSPOSED:
   #define MAT2_IS_TRANSPOSED

@@ -31,6 +36,7 @@ $if HAS_BIAS:

 #include "indexing_utils.h"

+// Workgroup size matches tile dimensions
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

 ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
@@ -50,6 +56,11 @@ $if HAS_BIAS:
   const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout);
   const lowp int bias_packed_dim = unhash_packed_dim(bias_layout);

+// Shared memory for tiling - use smaller tiles to avoid memory issues
+shared vec4 mat1_tile[TILE_SIZE_M][TILE_SIZE_K];
+shared vec4 mat2_tile[TILE_SIZE_K][TILE_SIZE_N];
+shared float out_tile[TILE_SIZE_M][TILE_SIZE_N];
+
 #ifdef HAS_BIAS
 vec4 get_bias_texel_W_packed(ivec3 logical_pos) {
   ivec3 bias_pos = ivec3(0);
@@ -71,11 +82,20 @@ vec4 get_bias_texel_W_packed(ivec3 logical_pos) {
 }
 #endif // HAS_BIAS

-vec4 matmul_naive_k_dim_packed(const ivec3 out_lpos) {
+void matmul_tiled_k_dim_packed(ivec3 lpos) {
+  // Get local thread ID and workgroup size
+  const uint local_idx = gl_LocalInvocationID.x;
+  const uint local_idy = gl_LocalInvocationID.y;
+  const uint workgroup_size_x = gl_WorkGroupSize.x;
+  const uint workgroup_size_y = gl_WorkGroupSize.y;
+  const uint workgroup_id_x = gl_WorkGroupID.x;
+  const uint workgroup_id_y = gl_WorkGroupID.y;
+
+  // Initialize position for reading from mat1
   ivec3 mat1_pos;
   mat1_pos[mat1_axis_map.x] = 0;
-  mat1_pos[mat1_axis_map.y] = out_lpos.y;
-  mat1_pos[mat1_axis_map.z] = out_lpos.z;
+  mat1_pos[mat1_axis_map.y] = lpos.y;
+  mat1_pos[mat1_axis_map.z] = lpos.z;
 #ifdef MAT2_IS_TRANSPOSED
   const int mat2_k_axis = mat2_axis_map.x;
   const int mat2_row_axis = mat2_axis_map.y;
@@ -84,31 +104,89 @@ vec4 matmul_naive_k_dim_packed(const ivec3 out_lpos) {
   const int mat2_row_axis = mat2_axis_map.x;
 #endif // MAT2_IS_TRANSPOSED

-  vec4 texel = vec4(0);
+  // Initialize position for reading from mat2
+  ivec3 mat2_pos;
+  mat2_pos[mat2_k_axis] = 0;
+  mat2_pos[mat2_row_axis] = lpos.x;
+#ifndef MAT2_IS_TRANSPOSED
+  mat2_pos[mat2_axis_map.z] = lpos.z;
+#else
+  mat2_pos[mat2_axis_map.z] = 0;
+#endif // MAT2_IS_TRANSPOSED
+
+  float sum = 0;
   const int K = divup4(mat1_sizes.x);

-  for (int i = 0; i < K; ++i) {
-    const vec4 mat1_tex = texelFetch(mat1_tensor, mat1_pos, 0);
+  // Process K dimension in chunks that fit in shared memory
+  const int chunk_size = min(TILE_SIZE_K, K);
+  const int num_chunks = (K + chunk_size - 1) / chunk_size;
+
+  for (int chunk = 0; chunk < num_chunks; ++chunk) {
+    // Calculate start position for this chunk
+    const int k_start = chunk * chunk_size;
+    const int k_end = min(k_start + chunk_size, K);
+
+    // Load mat1 data into shared memory
+    int k_idx = k_start + int(local_idx);
+    int row_idx = mat1_pos[mat1_axis_map.y];
+    if (k_idx < mat1_sizes[mat1_axis_map.x] && row_idx < mat1_sizes[mat1_axis_map.y]) {
+      ivec3 pos = mat1_pos;
+      pos[mat1_axis_map.x] = k_idx;
+      mat1_tile[local_idx][local_idy] = texelFetch(mat1_tensor, pos, 0);
+    }
+    else {
+      mat1_tile[local_idx][local_idy] = vec4(0.0);
+    }

-    vec4 sums;
-    for (int r = 0; r < 4; ++r) {
-      // On-demand construction of mat2_pos appears to provide the lowest
-      // latency. Surprisingly, this doesn't translate to mat1_pos.
-      ivec3 mat2_pos = ivec3(0);
-      mat2_pos[mat2_k_axis] = i;
-      mat2_pos[mat2_row_axis] = out_lpos.x * 4 + r;
-#ifndef MAT2_IS_TRANSPOSED
-      mat2_pos[mat2_axis_map.z] = out_lpos.z;
-#endif // MAT2_IS_TRANSPOSED
-      sums[r] = dot(mat1_tex, texelFetch(mat2_tensor, mat2_pos, 0));
+    // Load mat2 data into shared memory
+    k_idx = k_start + int(local_idy);
+    int col_idx = mat2_pos[mat2_row_axis];
+    if (col_idx < mat2_sizes[mat2_row_axis] && k_idx < mat2_sizes[mat2_k_axis]) {
+      ivec3 pos = mat2_pos;
+      pos[mat2_k_axis] = k_idx;
+      mat2_tile[local_idy][local_idx] = texelFetch(mat2_tensor, pos, 0);
     }
+    else {
+      mat2_tile[local_idy][local_idx] = vec4(0.0);
+    }
+
+    // Ensure all threads finish loading before computation
+    barrier();

-    texel += sums;
+    // Compute
+    for (int i = 0; i < k_end - k_start; ++i) {
+      const vec4 mat1_tex = mat1_tile[i][local_idy];
+      const vec4 mat2_tex = mat2_tile[i][local_idx];

-    mat1_pos[mat1_axis_map.x]++;
+      sum += dot(mat1_tex, mat2_tex);
+    }
+
+    // Ensure all threads finish using shared memory before next chunk
+    if (chunk < num_chunks - 1) {
+      barrier();
+    }
   }

-  return texel;
+  // Because the out matrix is M x N/4, we need to use out_tile
+  // to grab the out texels of other threads and condense into vec4s
+  out_tile[local_idy][local_idx] = sum;
+
+  barrier();
+
+  if (local_idx % 4 == 0) {
+    vec4 texel = vec4(out_tile[local_idy][local_idx + 0],
+                      out_tile[local_idy][local_idx + 1],
+                      out_tile[local_idy][local_idx + 2],
+                      out_tile[local_idy][local_idx + 3]);
+    lpos.x /= 4;
+#ifdef HAS_BIAS
+    vec4 bias_texel = get_bias_texel_W_packed(lpos);
+    texel = beta * bias_texel + alpha * texel;
+#endif // HAS_BIAS
+
+    write_texel_lpos(out_tensor, lpos, texel, out_axis_map);
+  }
 }

 vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) {
@@ -152,31 +230,36 @@ vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) {
 }

 void main() {
-  const ivec3 out_lpos = ivec3(gl_GlobalInvocationID);
-  if (any(greaterThanEqual(out_lpos, out_limits))) {
-    return;
-  }
+  ivec3 out_lpos = ivec3(gl_GlobalInvocationID);

   vec4 texel = vec4(0);

 #ifdef MAT2_IS_TRANSPOSED
   if (mat2_packed_dim == W_DIM) {
-    texel = matmul_naive_k_dim_packed(out_lpos);
+    matmul_tiled_k_dim_packed(out_lpos);
+    return;
   } else {
+    if (any(greaterThanEqual(out_lpos, out_limits))) {
+      return;
+    }
     texel = matmul_naive_k_dim_packed_row_dim_packed(out_lpos);
   }
 #else
   if (mat2_packed_dim == W_DIM) {
+    if (any(greaterThanEqual(out_lpos, out_limits))) {
+      return;
+    }
     texel = matmul_naive_k_dim_packed_row_dim_packed(out_lpos);
   } else {
-    texel = matmul_naive_k_dim_packed(out_lpos);
+    matmul_tiled_k_dim_packed(out_lpos);
+    return;
   }
 #endif // MAT2_IS_TRANSPOSED

 #ifdef HAS_BIAS
-  vec4 bias_texel = get_bias_texel_W_packed(out_lpos);
-  texel = beta * bias_texel + alpha * texel;
+  vec4 bias_texel = get_bias_texel_W_packed(out_lpos);
+  texel = beta * bias_texel + alpha * texel;
 #endif // HAS_BIAS

-  write_texel_lpos(out_tensor, out_lpos, texel, out_axis_map);
+  write_texel_lpos(out_tensor, out_lpos, texel, out_axis_map);
 }
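The tiled matmul above cooperatively stages TILE_SIZE_K texels of each operand in shared memory per iteration, accumulates partial dot products, and only then condenses four neighbouring column sums into one output texel. A rough NumPy model of the chunked accumulation (tile sizes, shapes and the texel grouping are illustrative assumptions, not the shader's exact indexing):

# Model of the chunked K accumulation: each TILE_SIZE_K-texel chunk contributes
# a partial dot product, and the partial sums add up to the full matmul result.
import numpy as np

M, K, N = 8, 64, 8
TILE_SIZE_K = 8                      # texels per chunk; each texel covers 4 elements of K
mat1 = np.random.randn(M, K).astype(np.float32)
mat2 = np.random.randn(K, N).astype(np.float32)

out = np.zeros((M, N), dtype=np.float32)
k_texels = (K + 3) // 4              # divup4(mat1_sizes.x)
chunk = min(TILE_SIZE_K, k_texels)
for k0 in range(0, k_texels, chunk):
    k_lo = 4 * k0
    k_hi = min(4 * (k0 + chunk), K)
    out += mat1[:, k_lo:k_hi] @ mat2[k_lo:k_hi, :]   # one chunk's partial dot products

assert np.allclose(out, mat1 @ mat2, rtol=1e-4, atol=1e-4)

Since each invocation now produces one scalar column sum rather than a whole vec4, the dispatch change in Linear.cpp further below scales the global X extent by 4 and pins the local workgroup to 8x8.
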
diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/reduce2d.glsl
new file mode 100644
index 00000000000..ba2a0be4b3a
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/reduce2d.glsl
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
+
+${define_active_storage_type(STORAGE)}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}
+
+${layout_declare_ubo(B, "ivec3", "tin_limits")}
+${layout_declare_ubo(B, "ivec4", "tin_sizes")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+layout(constant_id = 3) const int packed_dim = 0;
+layout(constant_id = 4) const int reduce_dim1 = 0;
+layout(constant_id = 5) const int reduce_dim2 = 1;
+layout(constant_id = 6) const int group_dim = 2;
+
+// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
+// threads that will co-operate to compute one reduction output. There may be
+// multiple groups computing distinct reduction outputs within one work group.
+#define NWORKERS 4
+
+// Sets an upper limit on the total size of a work group based on how many
+// elements are allocated in the shared memory array below. Each thread in the
+// work group will write into its assigned element in the shared array.
+#define MAX_NTHREADS 16
+
+shared vec4 shared_vecs[MAX_NTHREADS];
+
+#include "indexing_utils.h"
+
+int tid_to_smi(const ivec2 tid) {
+  return tid.x + tid.y * NWORKERS;
+}
+
+// Initializing the accumulator accepts the first value in the reduction row,
+// since some reduction operations (i.e. amax, amin) prefer to initialize with
+// a data point instead of a static value.
+#define INIT_ACCUM(first_val) ${INIT_ACCUM}
+#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
+// Useful for operators such as mean which want to perform a final calculation
+// with the accumulator.
+#define POSTPROCESS(accum) ${POSTPROCESS}
+
+void reduce_2d(const ivec2 tid, ivec3 scan_pos) {
+  // shared memory index of this thread
+  const int smi = tid_to_smi(tid);
+
+  scan_pos[reduce_dim1] = 0;
+  scan_pos[reduce_dim2] = 0;
+  vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));
+
+  // First dimension reduction
+  scan_pos[reduce_dim1] = tid.x;
+  for (int i = tid.x; i < tin_sizes[reduce_dim1];
+       i += NWORKERS, scan_pos[reduce_dim1] += NWORKERS) {
+    // Second dimension reduction
+    scan_pos[reduce_dim2] = 0;
+    for (int j = 0; j < tin_sizes[reduce_dim2]; j++, scan_pos[reduce_dim2]++) {
+      accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
+    }
+  }
+
+  // Write partial output to shared memory and synchronize
+  shared_vecs[smi] = accum;
+  barrier();
+
+  // Main thread aggregates results
+  if (tid.x == 0) {
+    // Iterate over the partial outputs to obtain the overall output
+    int group_i = tid.y * NWORKERS;
+    accum = shared_vecs[group_i++];
+    for (int i = 1; i < NWORKERS; i++, group_i++) {
+      accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
+    }
+
+    // Determine if there are any padding elements in the final texel of the
+    // packed dimension
+    const int nspill = mod4(tin_sizes[packed_dim]);
+    // Detect if this thread is working on the final texels of the packed
+    // dimension, which may have padding elements
+    const bool is_last_texel =
+        scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);
+
+    // Explicitly set padding elements to 0
+    if (is_last_texel && nspill > 0) {
+      [[unroll]] for (int i = nspill; i < 4; i++) {
+        accum[i] = 0;
+      }
+    }
+    scan_pos[reduce_dim1] = 0;
+    scan_pos[reduce_dim2] = 0;
+    write_texel(tout, scan_pos, POSTPROCESS(accum));
+  }
+}
+
+void main() {
+  ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
+  scan_pos[reduce_dim1] = 0;
+  scan_pos[reduce_dim2] = 0;
+
+  const ivec2 tid = ivec2(
+      gl_LocalInvocationID[reduce_dim1],
+      gl_LocalInvocationID[group_dim]);
+
+  if (any(greaterThanEqual(scan_pos, tin_limits))) {
+    return;
+  }
+
+  reduce_2d(tid, scan_pos);
+}
\ No newline at end of file
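A plain-Python model of the cooperation pattern in reduce_2d, with an illustrative input shape: NWORKERS partial accumulators stride over the first reduction dim (each sweeping the full second reduction dim), and the tid.x == 0 thread folds the partials together. Max stands in for the UPDATE_ACCUM macro (the amax2d variant):

import numpy as np

NWORKERS = 4
x = np.random.randn(4, 16, 12).astype(np.float32)    # reduce over dims 1 and 2

partials = []
for tid in range(NWORKERS):
    acc = x[:, 0, 0].copy()                           # INIT_ACCUM from the first loaded value
    for i in range(tid, x.shape[1], NWORKERS):        # stride over reduce_dim1
        for j in range(x.shape[2]):                   # full sweep of reduce_dim2
            acc = np.maximum(acc, x[:, i, j])         # UPDATE_ACCUM for amax2d
    partials.append(acc)

result = partials[0]
for p in partials[1:]:                                # the tid.x == 0 aggregation step
    result = np.maximum(result, p)

assert np.allclose(result, x.max(axis=(1, 2)))
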
diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/reduce2d.yaml
new file mode 100644
index 00000000000..fdc5eb9f105
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/reduce2d.yaml
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+reduce2d:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: texture3d
+    INIT_ACCUM: VEC4_T(0)
+    UPDATE_ACCUM: accum + new_val
+    POSTPROCESS: accum
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: sum2d
+    - NAME: mean2d
+      POSTPROCESS: (accum / (tin_sizes[reduce_dim1] * tin_sizes[reduce_dim2]))
+    - NAME: amax2d
+      INIT_ACCUM: first_val
+      UPDATE_ACCUM: max(accum, new_val)
+      POSTPROCESS: accum
+    - NAME: amin2d
+      INIT_ACCUM: first_val
+      UPDATE_ACCUM: min(accum, new_val)
+      POSTPROCESS: accum
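As a reference for what the four generated variants are expected to compute, a hedged mapping to eager PyTorch over two reduction dims (dims and shapes are illustrative):

import torch

x = torch.randn(2, 8, 16, 16)
dims = (2, 3)

sum2d = torch.sum(x, dim=dims, keepdim=True)
mean2d = sum2d / (x.shape[2] * x.shape[3])          # the mean2d POSTPROCESS division
amax2d = torch.amax(x, dim=dims, keepdim=True)
amin2d = torch.amin(x, dim=dims, keepdim=True)

assert torch.allclose(mean2d, torch.mean(x, dim=dims, keepdim=True), atol=1e-6)
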
diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
index 14ed9c84a32..b70c64d75ae 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
@@ -106,12 +106,19 @@ void add_addmm_naive_texture_node(
   add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
   add_dtype_suffix(kernel_name, graph.dtype_of(out));

+  // Define workgroup size for shared memory implementation
+  // Use 8x8 workgroups to match the TILE_SIZE_M and TILE_SIZE_N in the shader
+  utils::uvec3 local_wg_size = {8, 8, 1};
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
+
+  // Modify global workgroup size to match M x N, not M x N/4
+  global_wg_size[0] *= 4;
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
       global_wg_size,
-      graph.create_local_wg_size(global_wg_size),
+      local_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}},
       // Shader params buffers
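A quick arithmetic sketch of the dispatch sizing above, under the assumption that logical_limits_of(out) reports an X extent of N/4 texels for a width-packed output:

M, N = 64, 64
logical_limits = (N // 4, M, 1)       # (x, y, z): width packed 4 elements per texel

global_wg = (logical_limits[0] * 4, logical_limits[1], logical_limits[2])
local_wg = (8, 8, 1)                  # matches TILE_SIZE_N x TILE_SIZE_M
num_groups = tuple(-(-g // l) for g, l in zip(global_wg, local_wg))

print(global_wg)   # (64, 64, 1): one invocation per output element
print(num_groups)  # (8, 8, 1) workgroups
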
diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
index c0fd442ec50..e9bba71f563 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
@@ -32,6 +32,24 @@ void resize_reduce_node(
   out->virtual_resize(new_sizes);
 }

+void resize_reduce2d_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
+  vTensorPtr in = graph->get_tensor(args[1].refs[0]);
+
+  // Extract the dimensions to reduce over
+  const std::vector<int64_t> dims_list =
+      graph->extract_int_or_symint_list(resize_args.at(0));
+  int32_t reduce_dim1_nchw = dims_list[0];
+  int32_t reduce_dim2_nchw = dims_list[1];
+
+  std::vector<int64_t> new_sizes = in->sizes();
+  new_sizes.at(normalize(reduce_dim1_nchw, new_sizes.size())) = 1;
+  new_sizes.at(normalize(reduce_dim2_nchw, new_sizes.size())) = 1;
+  out->virtual_resize(new_sizes);
+}
+
 utils::uvec3 reduce_global_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
@@ -137,15 +155,89 @@ void add_reduce_node(
       resize_reduce_node));
 }

+void add_reduce2d_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef dims_ref,
+    const ValueRef out,
+    const std::string& op_name) {
+  VK_CHECK_COND(
+      !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
+      "Vulkan reduction only supports texture storage");
+
+  const int64_t ndim = graph.dim_of(in);
+
+  // Extract the two dimensions to reduce over
+  const std::vector<int64_t> dims_list =
+      graph.extract_int_or_symint_list(dims_ref);
+  VK_CHECK_COND(
+      dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");
+
+  int32_t reduce_dim1 = normalize(dims_list[0], ndim);
+  int32_t reduce_dim2 = normalize(dims_list[1], ndim);
+
+  // Convert to WHCN format
+  reduce_dim1 = nchw_dim_to_whcn_dim(reduce_dim1, ndim);
+  reduce_dim2 = nchw_dim_to_whcn_dim(reduce_dim2, ndim);
+
+  // Check that the concat dim is not one of the reduction dims
+  if (graph.dim_of(in) == 4 && graph.size_at<int>(0, in) > 1) {
+    VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim1);
+    VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim2);
+    VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim1);
+    VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim2);
+  }
+
+  std::string kernel_name = op_name + "2d"; // Add "2d" suffix
+  kernel_name.reserve(kShaderNameReserve);
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  // Calculate group_dim for specialization constants (use remaining dimension)
+  int32_t group_dim = 0;
+  for (int i = 0; i < 3; i++) {
+    if (i != reduce_dim1 && i != reduce_dim2) {
+      group_dim = i;
+      break;
+    }
+  }
+
+  const ValueRef reduce_dim1_whcn_ref =
+      graph.get_or_add_value_for_int(reduce_dim1);
+  const ValueRef reduce_dim2_whcn_ref =
+      graph.get_or_add_value_for_int(reduce_dim2);
+  const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int(group_dim);
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      reduce_global_wg_size,
+      reduce_local_wg_size,
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Shader params buffers
+      {graph.logical_limits_ubo(in), graph.sizes_ubo(in)},
+      // Push Constants
+      {},
+      // Specialization Constants
+      {graph.packed_dim_of(out), reduce_dim1, reduce_dim2, group_dim},
+      // Resize Args
+      {dims_ref, reduce_dim1_whcn_ref, reduce_dim2_whcn_ref, group_dim_whcn_ref},
+      // Resizing Logic
+      resize_reduce2d_node));
+}
+
 #define DEFINE_REDUCE_FN(op_name, out_arg_idx)                              \
   void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {    \
     const std::vector<int64_t> dims_list =                                  \
-        graph.extract_int_or_symint_list(args[1]);                          \
-    VK_CHECK_COND(dims_list.size() == 1);                                   \
-    const int64_t dim_val = dims_list.at(0);                                \
-    const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val);       \
-    return add_reduce_node(                                                 \
-        graph, args[0], dim_ref, args[out_arg_idx], #op_name);              \
+        graph.extract_int_or_symint_list(args[1]);                          \
+    if (dims_list.size() == 1) {                                            \
+      const int64_t dim_val = dims_list.at(0);                              \
+      const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val);     \
+      return add_reduce_node(                                               \
+          graph, args[0], dim_ref, args[out_arg_idx], #op_name);            \
+    } else if (dims_list.size() == 2) {                                     \
+      return add_reduce2d_node(                                             \
+          graph, args[0], args[1], args[out_arg_idx], #op_name);            \
+    } else {                                                                \
+      VK_CHECK_COND(false, "Only 1 or 2 dimensions supported");             \
+    }                                                                       \
   }

 DEFINE_REDUCE_FN(sum, 4)
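Finally, a small Python sketch of the size computation performed by resize_reduce2d_node, mirroring the C++ above (the input sizes and negative dims are illustrative; the output keeps the reduced dims as size 1, matching a keepdim-style reduction):

def normalize(dim, ndim):
    # mirrors the C++ normalize(): map a possibly negative dim into [0, ndim)
    return dim % ndim

def reduce2d_out_sizes(in_sizes, dims):
    new_sizes = list(in_sizes)
    for d in dims:
        new_sizes[normalize(d, len(in_sizes))] = 1
    return new_sizes

print(reduce2d_out_sizes([2, 8, 16, 16], [-2, -1]))  # [2, 8, 1, 1]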