
Commit 465170f

pytorchbot and SS-JIA authored
[ET-VK] Enable buffer implementation of aten.linear (pytorch#6608)
* [ET-VK] Allow clone op to transfer between memory layouts and storage types

  Pull Request resolved: pytorch#6596

  ## Changes

  As title. Extend the functionality of the `aten.clone` operator to allow transitioning the storage type and memory layout between the input and the output tensor.

  ## Context

  This functionality will be used to transition input tensors to the optimal storage type and memory layout before entering the execution of an op. The transition nodes will be added by a memory metadata tagging pass that will be introduced in a subsequent diff.

  ghstack-source-id: 251229412
  @exported-using-ghexport
  Differential Revision: [D65277710](https://our.internmc.facebook.com/intern/diff/D65277710/)

* [ET-VK] Enable buffer implementation of `aten.linear`

  Pull Request resolved: pytorch#6597

  ## Changes

  As title. Extend the existing buffer implementation of `matmul` to support the linear operator as well.

  ghstack-source-id: 251229414
  @exported-using-ghexport
  Differential Revision: [D65277712](https://our.internmc.facebook.com/intern/diff/D65277712/)

---------

Co-authored-by: Stephen Jia <[email protected]>
1 parent f0af466 commit 465170f
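The identity that lets a matmul kernel serve `aten.linear` is that `linear(x, W)` computes `x @ W^T`, with the weight stored as `(N, K)`. A minimal sketch of the equivalence in eager PyTorch (an illustration, not part of this commit):

```python
import torch

x = torch.randn(3, 4)  # input:  (M, K)
W = torch.randn(5, 4)  # weight: (N, K), as aten.linear receives it

# linear(x, W) == x @ W.t(), so a matmul kernel that can read mat2 in
# transposed order can implement linear without materializing W.t().
assert torch.allclose(torch.nn.functional.linear(x, W), x @ W.t())
```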

File tree

4 files changed (+61, -16 lines)

backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl

Lines changed: 28 additions & 13 deletions
@@ -16,21 +16,23 @@ ${define_required_extensions(DTYPE)}

 layout(std430) buffer;

-${layout_declare_tensor(0, "w", "t_out", DTYPE, "buffer")}
-${layout_declare_tensor(1, "r", "t_mat1", DTYPE, "buffer")}
-${layout_declare_tensor(2, "r", "t_mat2", DTYPE, "buffer")}
-${layout_declare_ubo(3, "ivec4", "out_sizes")}
-${layout_declare_ubo(4, "ivec4", "out_strides")}
-${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
-${layout_declare_ubo(6, "ivec4", "mat1_strides")}
-${layout_declare_ubo(7, "ivec4", "mat2_sizes")}
-${layout_declare_ubo(8, "ivec4", "mat2_strides")}
-${layout_declare_ubo(9, "int", "out_numel")}
+${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_mat2", DTYPE, "buffer")}
+${layout_declare_ubo(B, "ivec4", "out_sizes")}
+${layout_declare_ubo(B, "ivec4", "out_strides")}
+${layout_declare_ubo(B, "ivec4", "mat1_sizes")}
+${layout_declare_ubo(B, "ivec4", "mat1_strides")}
+${layout_declare_ubo(B, "ivec4", "mat2_sizes")}
+${layout_declare_ubo(B, "ivec4", "mat2_strides")}
+${layout_declare_ubo(B, "int", "out_numel")}

 #include "indexing_utils.h"

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

+${layout_declare_spec_const(C, "int", "mat2_is_transposed", "0")}
+
 void main() {
   const ivec4 out_bufix = ivec4(
     gl_GlobalInvocationID.x,
@@ -44,15 +46,28 @@ void main() {

   int mat1_bufi = tidx_to_bufi(
       ivec4(0, out_bufix.y, out_bufix.z, out_bufix.w), mat1_strides);
-  int mat2_bufi = tidx_to_bufi(
-      ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides);
+  int mat2_bufi;
+  if (mat2_is_transposed > 0) {
+    mat2_bufi = tidx_to_bufi(
+        ivec4(0, out_bufix.x, 0, 0), mat2_strides);
+  } else {
+    mat2_bufi = tidx_to_bufi(
+        ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides);
+  }
+
+  int mat2_stride;
+  if (mat2_is_transposed > 0) {
+    mat2_stride = mat2_strides.x;
+  } else {
+    mat2_stride = mat2_strides.y;
+  }

   T sum = T(0.0);
   for (int i = 0; i < mat1_sizes.x; ++i) {
     sum += t_mat1[mat1_bufi] * t_mat2[mat2_bufi];

     mat1_bufi += mat1_strides.x;
-    mat2_bufi += mat2_strides.y;
+    mat2_bufi += mat2_stride;
   }

   const int out_bufi = tidx_to_bufi(out_bufix, out_strides);
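To make the stride selection concrete, here is a hedged plain-Python model of the shader's per-element reduction. The function name `dot_along_k` is hypothetical; strides are given as `(x, y, z, w)` tuples to mirror the GLSL `ivec4`s, and this is an illustration, not the runtime code:

```python
def dot_along_k(mat1, mat2, mat1_strides, mat2_strides, K, row, col, transposed):
    """Compute one output element by walking flat buffers via strides."""
    # Starting offsets, mirroring tidx_to_bufi in the shader.
    mat1_bufi = row * mat1_strides[1]
    # When mat2 is transposed (the aten.linear case), the output column
    # selects a row of the (N, K) weight, so it is scaled by the y-stride.
    mat2_bufi = col * (mat2_strides[1] if transposed else mat2_strides[0])
    # Consecutive K elements lie along mat2's width when transposed,
    # along its height otherwise.
    mat2_step = mat2_strides[0] if transposed else mat2_strides[1]

    total = 0.0
    for _ in range(K):
        total += mat1[mat1_bufi] * mat2[mat2_bufi]
        mat1_bufi += mat1_strides[0]
        mat2_bufi += mat2_step
    return total
```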

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 3 additions & 0 deletions
@@ -279,9 +279,12 @@ void linear(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   ValueRef weight = prepack_standard(
       graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked);
   ValueRef mat2_is_transposed = graph.add_scalar(true);
+
   if (graph.val_is_none(bias)) {
     return add_matmul_node(graph, input, weight, out, mat2_is_transposed);
   } else {
+    // Buffer implementation does not yet support biases
+    VK_CHECK_COND(!graph.is_buffer_storage(out));
     return add_addmm_node(
         graph,
         bias,
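In other words, the bias-free path now works for both storage types, while the `addmm` path remains guarded to texture storage. A hedged Python restatement of that dispatch (the function name and string returns are purely illustrative):

```python
def lower_linear(has_bias: bool, out_is_buffer: bool) -> str:
    if not has_bias:
        # Both texture and buffer storage take the matmul path.
        return "add_matmul_node(..., mat2_is_transposed=True)"
    if out_is_buffer:
        # Mirrors the VK_CHECK_COND: buffer addmm is not yet implemented.
        raise RuntimeError("aten.linear with bias unsupported for buffer storage")
    return "add_addmm_node(..., mat2_is_transposed=True)"
```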

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 6 additions & 1 deletion
@@ -77,6 +77,11 @@ void add_matmul_naive_buffer_node(
       graph.size_at<uint32_t>(-2, out),
       graph.size_at<uint32_t>(-3, out) * graph.size_at<uint32_t>(-4, out)};

+  int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef &&
+                                graph.get_bool(mat2_is_transposed))
+      ? 1
+      : 0;
+
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
@@ -96,7 +101,7 @@ void add_matmul_naive_buffer_node(
       graph.numel_ubo(out),
       },
       // Specialization Constants
-      {},
+      {mat2_is_transposed_val},
       // Resizing Logic
       resize_matmul_node,
       {mat2_is_transposed}));
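The flag is a `ValueRef` that may be absent, so the specialization-constant value folds both "missing" and "false" to 0, preserving the shader's declared default. A hedged Python model of that ternary (`K_DUMMY_VALUE_REF` is an illustrative stand-in for `kDummyValueRef`):

```python
K_DUMMY_VALUE_REF = None  # stand-in for kDummyValueRef

def mat2_is_transposed_val(mat2_is_transposed) -> int:
    # A missing flag or an explicit False both leave the shader default (0).
    if mat2_is_transposed is not K_DUMMY_VALUE_REF and mat2_is_transposed:
        return 1
    return 0

assert mat2_is_transposed_val(None) == 0   # no flag provided
assert mat2_is_transposed_val(False) == 0
assert mat2_is_transposed_val(True) == 1
```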

backends/vulkan/test/op_tests/cases.py

Lines changed: 24 additions & 2 deletions
@@ -126,8 +126,7 @@ def get_addmm_inputs():
     ]


-@register_test_suite("aten.linear.default")
-def get_linear_inputs():
+def get_linear_texture_inputs():
     MKN_list = common_MKN_list

     inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list]
@@ -142,9 +141,32 @@ def get_linear_inputs():
         "utils::kWidthPacked",
         "utils::kChannelsPacked",
     ]
+    test_suite.test_name_suffix = "texture"
+    return test_suite
+
+
+def get_linear_buffer_inputs():
+    MKN_list = common_MKN_list
+
+    inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list]
+    inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list]
+
+    test_suite = VkTestSuite(inputs_list)
+    test_suite.dtypes = ["at::kFloat"]
+    test_suite.layouts = [
+        "utils::kWidthPacked",
+        "utils::kChannelsPacked",
+    ]
+    test_suite.storage_types = ["utils::kBuffer"]
+    test_suite.test_name_suffix = "buffer"
     return test_suite


+@register_test_suite("aten.linear.default")
+def get_linear_test_suites():
+    return [get_linear_texture_inputs(), get_linear_buffer_inputs()]
+
+
 @register_test_suite("aten._weight_int8pack_mm.default")
 def get_weight_int8pack_mm_inputs():
     MKN_list = common_MKN_list
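For reference, here is a hedged illustration of the concrete `(input, weight, bias)` shape triples the buffer suite generates, assuming `common_MKN_list` contains an entry like `(M, K, N) = (32, 64, 16)` (the values are hypothetical):

```python
M, K, N = 32, 64, 16  # hypothetical entry of common_MKN_list

buffer_cases = [
    ((M, K), (N, K), None),     # 2D input against a (N, K) weight, no bias
    ((3, M, K), (N, K), None),  # batched 3D input, same weight
]
# Each case runs under the kWidthPacked and kChannelsPacked layouts with
# kBuffer storage, and the generated test names carry the "buffer" suffix.
```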
