2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
@@ -532,7 +532,7 @@ def register_reduce_op(features: OpFeatures):

def check_reduce_node(node: torch.fx.Node) -> bool:
dim_list = node.args[1]
if isinstance(dim_list, list) and len(dim_list) != 1:
if isinstance(dim_list, list) and len(dim_list) > 2:
return False

keepdim = node.args[2]
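The relaxed check means reductions over one or two dimensions can now pass `check_reduce_node`, while longer dim lists still fall back. A minimal illustration (assuming this check gates whether the reduce op is delegated to the Vulkan backend; the surrounding partitioner flow is not shown in this diff):

```python
import torch

x = torch.randn(2, 8, 16, 16)

# One reduction dim: accepted before and after this change.
single = torch.amax(x, dim=[2], keepdim=True)

# Two reduction dims: newly accepted, matching the reduce2d shader added below.
double = torch.amax(x, dim=[2, 3], keepdim=True)

# Three reduction dims: len(dim_list) > 2, so check_reduce_node returns False
# and the op is not delegated to Vulkan (it still runs in eager mode here).
triple = torch.amax(x, dim=[1, 2, 3], keepdim=True)
```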
141 changes: 112 additions & 29 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl
@@ -10,6 +10,11 @@

#define PRECISION ${PRECISION}

// Define tile sizes for optimal memory access patterns
#define TILE_SIZE_M 8 // Tile size for output rows
#define TILE_SIZE_N 8 // Tile size for output columns
#define TILE_SIZE_K 8 // Tile size for K dimension

$if MAT2_IS_TRANSPOSED:
#define MAT2_IS_TRANSPOSED

@@ -31,6 +36,7 @@ $if HAS_BIAS:

#include "indexing_utils.h"

// Workgroup size matches tile dimensions
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
@@ -50,6 +56,11 @@ $if HAS_BIAS:
const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout);
const lowp int bias_packed_dim = unhash_packed_dim(bias_layout);

// Shared memory tiles for mat1, mat2, and partial outputs; kept small to stay
// within the shared memory budget of the work group
shared vec4 mat1_tile[TILE_SIZE_M][TILE_SIZE_K];
shared vec4 mat2_tile[TILE_SIZE_K][TILE_SIZE_N];
shared float out_tile[TILE_SIZE_M][TILE_SIZE_N];

#ifdef HAS_BIAS
vec4 get_bias_texel_W_packed(ivec3 logical_pos) {
ivec3 bias_pos = ivec3(0);
@@ -71,11 +82,20 @@ vec4 get_bias_texel_W_packed(ivec3 logical_pos) {
}
#endif // HAS_BIAS

vec4 matmul_naive_k_dim_packed(const ivec3 out_lpos) {
void matmul_tiled_k_dim_packed(ivec3 lpos) {
// Get local thread ID and workgroup size
const uint local_idx = gl_LocalInvocationID.x;
const uint local_idy = gl_LocalInvocationID.y;
const uint workgroup_size_x = gl_WorkGroupSize.x;
const uint workgroup_size_y = gl_WorkGroupSize.y;
const uint workgroup_id_x = gl_WorkGroupID.x;
const uint workgroup_id_y = gl_WorkGroupID.y;

// Initialize position for reading from mat1
ivec3 mat1_pos;
mat1_pos[mat1_axis_map.x] = 0;
mat1_pos[mat1_axis_map.y] = out_lpos.y;
mat1_pos[mat1_axis_map.z] = out_lpos.z;
mat1_pos[mat1_axis_map.y] = lpos.y;
mat1_pos[mat1_axis_map.z] = lpos.z;
#ifdef MAT2_IS_TRANSPOSED
const int mat2_k_axis = mat2_axis_map.x;
const int mat2_row_axis = mat2_axis_map.y;
@@ -84,31 +104,89 @@ vec4 matmul_naive_k_dim_packed(const ivec3 out_lpos) {
const int mat2_row_axis = mat2_axis_map.x;
#endif // MAT2_IS_TRANSPOSED

vec4 texel = vec4(0);
// Initialize position for reading from mat2
ivec3 mat2_pos;
mat2_pos[mat2_k_axis] = 0;
mat2_pos[mat2_row_axis] = lpos.x;
#ifndef MAT2_IS_TRANSPOSED
mat2_pos[mat2_axis_map.z] = lpos.z;
#else
mat2_pos[mat2_axis_map.z] = 0;
#endif // MAT2_IS_TRANSPOSED

float sum = 0;
const int K = divup4(mat1_sizes.x);

for (int i = 0; i < K; ++i) {
const vec4 mat1_tex = texelFetch(mat1_tensor, mat1_pos, 0);
// Process K dimension in chunks that fit in shared memory
const int chunk_size = min(TILE_SIZE_K, K);
const int num_chunks = (K + chunk_size - 1) / chunk_size;

for (int chunk = 0; chunk < num_chunks; ++chunk) {
// Calculate start position for this chunk
const int k_start = chunk * chunk_size;
const int k_end = min(k_start + chunk_size, K);

// Load mat1 data into shared memory
int k_idx = k_start + int(local_idx);
int row_idx = mat1_pos[mat1_axis_map.y];
if (k_idx < mat1_sizes[mat1_axis_map.x] && row_idx < mat1_sizes[mat1_axis_map.y]) {
ivec3 pos = mat1_pos;
pos[mat1_axis_map.x] = k_idx;
mat1_tile[local_idx][local_idy] = texelFetch(mat1_tensor, pos, 0);
}
else {
mat1_tile[local_idx][local_idy] = vec4(0.0);
}

vec4 sums;
for (int r = 0; r < 4; ++r) {
// On-demand construction of mat2_pos appears to provide the lowest
// latency. Surprisingly, this doesn't translate to mat1_pos.
ivec3 mat2_pos = ivec3(0);
mat2_pos[mat2_k_axis] = i;
mat2_pos[mat2_row_axis] = out_lpos.x * 4 + r;
#ifndef MAT2_IS_TRANSPOSED
mat2_pos[mat2_axis_map.z] = out_lpos.z;
#endif // MAT2_IS_TRANSPOSED
sums[r] = dot(mat1_tex, texelFetch(mat2_tensor, mat2_pos, 0));
// Load mat2 data into shared memory
k_idx = k_start + int(local_idy);
int col_idx = mat2_pos[mat2_row_axis];
if (col_idx < mat2_sizes[mat2_row_axis] && k_idx < mat2_sizes[mat2_k_axis]) {
ivec3 pos = mat2_pos;
pos[mat2_k_axis] = k_idx;
mat2_tile[local_idy][local_idx] = texelFetch(mat2_tensor, pos, 0);
}
else {
mat2_tile[local_idy][local_idx] = vec4(0.0);
}

// Ensure all threads finish loading before computation
barrier();

texel += sums;
// Compute
for (int i = 0; i < k_end - k_start; ++i) {
const vec4 mat1_tex = mat1_tile[i][local_idy];
const vec4 mat2_tex = mat2_tile[i][local_idx];

mat1_pos[mat1_axis_map.x]++;
sum += dot(mat1_tex, mat2_tex);
}

// Ensure all threads finish using shared memory before next chunk
if (chunk < num_chunks - 1) {
barrier();
}
}

return texel;
// Because the output matrix is M x N/4 texels, use out_tile to gather the
// scalar sums of neighboring threads and condense them into vec4 texels
out_tile[local_idy][local_idx] = sum;

barrier();

if (local_idx % 4 == 0) {
vec4 texel = vec4(out_tile[local_idy][local_idx + 0],
out_tile[local_idy][local_idx + 1],
out_tile[local_idy][local_idx + 2],
out_tile[local_idy][local_idx + 3]);
lpos.x /= 4;
#ifdef HAS_BIAS
vec4 bias_texel = get_bias_texel_W_packed(lpos);
texel = beta * bias_texel + alpha * texel;
#endif // HAS_BIAS

write_texel_lpos(out_tensor, lpos, texel, out_axis_map);
}
}

vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) {
@@ -152,31 +230,36 @@ vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) {
}

void main() {
const ivec3 out_lpos = ivec3(gl_GlobalInvocationID);
if (any(greaterThanEqual(out_lpos, out_limits))) {
return;
}
ivec3 out_lpos = ivec3(gl_GlobalInvocationID);

vec4 texel = vec4(0);

#ifdef MAT2_IS_TRANSPOSED
if (mat2_packed_dim == W_DIM) {
texel = matmul_naive_k_dim_packed(out_lpos);
matmul_tiled_k_dim_packed(out_lpos);
return;
} else {
if (any(greaterThanEqual(out_lpos, out_limits))) {
return;
}
texel = matmul_naive_k_dim_packed_row_dim_packed(out_lpos);
}
#else
if (mat2_packed_dim == W_DIM) {
if (any(greaterThanEqual(out_lpos, out_limits))) {
return;
}
texel = matmul_naive_k_dim_packed_row_dim_packed(out_lpos);
} else {
texel = matmul_naive_k_dim_packed(out_lpos);
matmul_tiled_k_dim_packed(out_lpos);
return;
}
#endif // MAT2_IS_TRANSPOSED

#ifdef HAS_BIAS
vec4 bias_texel = get_bias_texel_W_packed(out_lpos);
texel = beta * bias_texel + alpha * texel;
vec4 bias_texel = get_bias_texel_W_packed(out_lpos);
texel = beta * bias_texel + alpha * texel;
#endif // HAS_BIAS

write_texel_lpos(out_tensor, out_lpos, texel, out_axis_map);
write_texel_lpos(out_tensor, out_lpos, texel, out_axis_map);
}
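For reference, a rough NumPy sketch of the accumulation order the tiled kernel performs: the K dimension is processed in chunks (the shader's shared-memory tiles), a scalar dot product is accumulated per output element, and the result is scaled with alpha/beta as in addmm. Texel (vec4) packing, axis maps, and thread cooperation are deliberately omitted, so this is a semantic sketch rather than the shader's execution model:

```python
import numpy as np

def addmm_chunked_reference(bias, mat1, mat2, alpha=1.0, beta=1.0, tile_k=8):
    """Chunked-K accumulation mirroring matmul_tiled_k_dim_packed (packing omitted)."""
    M, K = mat1.shape
    K2, N = mat2.shape
    assert K == K2
    out = np.zeros((M, N), dtype=mat1.dtype)
    # Each iteration corresponds to one shared-memory chunk of size TILE_SIZE_K.
    for k_start in range(0, K, tile_k):
        k_end = min(k_start + tile_k, K)
        out += mat1[:, k_start:k_end] @ mat2[k_start:k_end, :]
    return beta * bias + alpha * out

# Sanity check: matches the plain matmul plus the bias term.
bias = np.random.rand(1, 6).astype(np.float32)
a = np.random.rand(4, 20).astype(np.float32)
b = np.random.rand(20, 6).astype(np.float32)
assert np.allclose(addmm_chunked_reference(bias, a, b), bias + a @ b, atol=1e-5)
```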
128 changes: 128 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/reduce2d.glsl
@@ -0,0 +1,128 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_active_storage_type(STORAGE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec3", "tin_limits")}
${layout_declare_ubo(B, "ivec4", "tin_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim1 = 0;
layout(constant_id = 5) const int reduce_dim2 = 1;
layout(constant_id = 6) const int group_dim = 2;

// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
// threads that will co-operate to compute one reduction output. There may be
// multiple groups computing distinct reduction outputs within one work group.
#define NWORKERS 4

// Sets an upper limit on the total size of a work group based on how many
// elements are allocated in the shared memory array below. Each thread in the
// work group will write into its assigned element in the shared array.
#define MAX_NTHREADS 16


shared vec4 shared_vecs[MAX_NTHREADS];

#include "indexing_utils.h"

int tid_to_smi(const ivec2 tid) {
return tid.x + tid.y * NWORKERS;
}

// The accumulator initializer accepts the first value in the reduction row,
// since some reduction operations (e.g. amax, amin) prefer to initialize with
// a data point instead of a static value.
#define INIT_ACCUM(first_val) ${INIT_ACCUM}
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
// Useful for operators such as mean, which perform a final calculation on
// the accumulator.
#define POSTPROCESS(accum) ${POSTPROCESS}

void reduce_2d(const ivec2 tid, ivec3 scan_pos) {
// shared memory index of this thread
const int smi = tid_to_smi(tid);

scan_pos[reduce_dim1] = 0;
scan_pos[reduce_dim2] = 0;
vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));

// First dimension reduction
scan_pos[reduce_dim1] = tid.x;
for (int i = tid.x; i < tin_sizes[reduce_dim1];
i += NWORKERS, scan_pos[reduce_dim1] += NWORKERS) {

// Second dimension reduction
scan_pos[reduce_dim2] = 0;
for (int j = 0; j < tin_sizes[reduce_dim2]; j++, scan_pos[reduce_dim2]++) {
accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
}
}

// Write partial output to shared memory and synchronize
shared_vecs[smi] = accum;
barrier();

// Main thread aggregates results
if (tid.x == 0) {
// Iterate over the partial outputs to obtain the overall output
int group_i = tid.y * NWORKERS;
accum = shared_vecs[group_i++];
for (int i = 1; i < NWORKERS; i++, group_i++) {
accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
}

// Determine if there are any padding elements in the final texel of the
// packed dimension
const int nspill = mod4(tin_sizes[packed_dim]);
// Detect if this thread is working on the final texels of the packed
// dimension, which may have padding elements
const bool is_last_texel =
scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);

// Explicitly set padding elements to 0
if (is_last_texel && nspill > 0) {
[[unroll]] for (int i = nspill; i < 4; i++) {
accum[i] = 0;
}
}
scan_pos[reduce_dim1] = 0;
scan_pos[reduce_dim2] = 0;
write_texel(tout, scan_pos, POSTPROCESS(accum));
}
}

void main() {
ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
scan_pos[reduce_dim1] = 0;
scan_pos[reduce_dim2] = 0;

const ivec2 tid = ivec2(
gl_LocalInvocationID[reduce_dim1],
gl_LocalInvocationID[group_dim]);

if (any(greaterThanEqual(scan_pos, tin_limits))) {
return;
}

reduce_2d(tid, scan_pos);
}
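A Python sketch of the cooperative scheme in reduce_2d, under the simplifying assumption of plain scalars instead of packed vec4 texels: NWORKERS "threads" each accumulate a strided slice of reduce_dim1 across all of reduce_dim2, and their partial results are then merged the way the shared_vecs/barrier step does:

```python
import numpy as np

NWORKERS = 4  # mirrors the shader's NWORKERS

def reduce2d_reference(x, reduce_dim1, reduce_dim2, update=np.maximum, init_zero=False):
    """Cooperative 2D reduction sketch: strided partial accumulation, then merge."""
    # INIT_ACCUM: either zeros (sum/mean) or the first element of the plane (amax/amin).
    first = np.take(np.take(x, [0], axis=reduce_dim1), [0], axis=reduce_dim2)
    init = np.zeros_like(first) if init_zero else first

    partials = []
    for worker in range(NWORKERS):
        accum = init.copy()
        # Strided walk over reduce_dim1; full scan over reduce_dim2 in the inner loop.
        for i in range(worker, x.shape[reduce_dim1], NWORKERS):
            row = np.take(x, [i], axis=reduce_dim1)
            for j in range(x.shape[reduce_dim2]):
                accum = update(accum, np.take(row, [j], axis=reduce_dim2))
        partials.append(accum)

    # The tid.x == 0 thread merges the partials written to shared memory.
    out = partials[0]
    for p in partials[1:]:
        out = update(out, p)
    return out

x = np.random.rand(3, 5, 7).astype(np.float32)
assert np.allclose(reduce2d_reference(x, 1, 2), np.amax(x, axis=(1, 2), keepdims=True))
```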
29 changes: 29 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/reduce2d.yaml
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

reduce2d:
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
INIT_ACCUM: VEC4_T(0)
UPDATE_ACCUM: accum + new_val
POSTPROCESS: accum
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: sum2d
- NAME: mean2d
POSTPROCESS: (accum / (tin_sizes[reduce_dim1] * tin_sizes[reduce_dim2]))
- NAME: amax2d
INIT_ACCUM: first_val
UPDATE_ACCUM: max(accum, new_val)
POSTPROCESS: accum
- NAME: amin2d
INIT_ACCUM: first_val
UPDATE_ACCUM: min(accum, new_val)
POSTPROCESS: accum
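The four generated variants correspond to the following eager-mode ops (illustrative only; dims (1, 2) stand in for whichever two dims the reduce is lowered with):

```python
import torch

x = torch.randn(3, 5, 7)

sum2d  = torch.sum(x,  dim=(1, 2), keepdim=True)   # UPDATE_ACCUM: accum + new_val
mean2d = torch.mean(x, dim=(1, 2), keepdim=True)   # POSTPROCESS divides by the reduced element count
amax2d = torch.amax(x, dim=(1, 2), keepdim=True)   # INIT_ACCUM: first_val, UPDATE_ACCUM: max
amin2d = torch.amin(x, dim=(1, 2), keepdim=True)   # INIT_ACCUM: first_val, UPDATE_ACCUM: min
```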
9 changes: 8 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Linear.cpp
@@ -106,12 +106,19 @@ void add_addmm_naive_texture_node(
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_dtype_suffix(kernel_name, graph.dtype_of(out));

// Define workgroup size for shared memory implementation
// Use 8x8 workgroups to match the TILE_SIZE_M and TILE_SIZE_N in the shader
utils::uvec3 local_wg_size = {8, 8, 1};

utils::uvec3 global_wg_size = graph.logical_limits_of(out);

// Modify global workgroup size to match M x N, not M x N/4
global_wg_size[0] *= 4;
graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
global_wg_size,
graph.create_local_wg_size(global_wg_size),
local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}},
// Shader params buffers
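A back-of-the-envelope sketch of the dispatch math this change sets up, assuming a W-packed output so that logical_limits_of(out) is roughly (N/4, M, batch) in texels; the helper name is hypothetical and only illustrates the arithmetic:

```python
import math

def addmm_dispatch_dims(out_limits, local_wg=(8, 8, 1)):
    """out_limits is assumed to be (N/4, M, batch) in texels. The x extent is
    scaled by 4 so each invocation owns one scalar output element, then rounded
    up into 8x8 workgroups matching TILE_SIZE_M / TILE_SIZE_N in the shader."""
    global_wg = (out_limits[0] * 4, out_limits[1], out_limits[2])
    num_groups = tuple(math.ceil(g / l) for g, l in zip(global_wg, local_wg))
    return global_wg, num_groups

# Example: a 64x64 output stored as 16x64 texels dispatches (64, 64, 1)
# invocations in (8, 8, 1) workgroups.
print(addmm_dispatch_dims((16, 64, 1)))  # ((64, 64, 1), (8, 8, 1))
```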