@@ -19,16 +19,18 @@ layout(std430) buffer;
 
 #include "indexing_utils.h"
 
-${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
+${layout_declare_tensor(B, "rw", "t_out", DTYPE, "texture3d")}
 
 $for i in range(NUM_INPUTS):
-  ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "texture3d")}
+  ${layout_declare_tensor(B, "r", "t_inp" + str(i), DTYPE, "texture3d")}
+
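+// Assumed semantics: t_concat_offset holds a single int giving the running
+// write offset along concat_dim, presumably updated by the host between
+// dispatches when the inputs are processed in batches.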
+${layout_declare_tensor(B, "r", "t_concat_offset", "int", "buffer")}
 
 ${layout_declare_ubo(B, "int", "concat_dim")}
 
 $in_metadata = ""
 $for i in range(NUM_INPUTS):
-  $in_metadata += "ivec4 in" + str(i + 1) + "_sizes;\n"
+  $in_metadata += "ivec4 inp" + str(i) + "_sizes;\n"
 
 layout(push_constant) uniform restrict Block {
   ivec4 out_sizes;
@@ -40,90 +42,135 @@ const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
 const lowp int out_packed_dim = unhash_packed_dim(out_layout);
 
 $for i in range(NUM_INPUTS):
-  ${layout_declare_spec_const(C, "int", "in" + str(i + 1) + "_layout", "DEFAULT_LAYOUT")}
-  const lowp ivec4 in${i+1}_axis_map = unhash_axis_map(in${i+1}_layout);
-  const lowp int in${i+1}_packed_dim = unhash_packed_dim(in${i+1}_layout);
+  ${layout_declare_spec_const(C, "int", "inp" + str(i) + "_layout", "DEFAULT_LAYOUT")}
+  const lowp ivec4 inp${i}_axis_map = unhash_axis_map(inp${i}_layout);
+  const lowp int inp${i}_packed_dim = unhash_packed_dim(inp${i}_layout);
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-// Check if we can use the fast path (no texel merging required)
-bool can_use_fast_path() {
-  // Fast path is possible when:
-  // 1. The concat dimension is not the packed dimension, or
-  // 2. The concat dimension is the packed dimension but both input tensors have dimensions
-  //    that are multiples of 4 along the packed dimension
-  if (concat_dim != out_packed_dim) {
-    return true;
-  }
-
-  // Check if all input tensors have dimensions that are multiples of 4 along the packed dimension
-  bool all_concat_dim_size_multiple_of_4 = true;
-  $for i in range(NUM_INPUTS):
-    all_concat_dim_size_multiple_of_4 =
-      all_concat_dim_size_multiple_of_4 &&
-      (in${i+1}_sizes[concat_dim] % 4 == 0);
+#define NUM_INPUTS ${NUM_INPUTS}
 
-  return all_concat_dim_size_multiple_of_4;
-}
+#include "concat_utils.glslh"
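+// NOTE: total_concat_dim_numel() and the DIV_UP_4 / MUL_4 / ALIGN_DOWN_4
+// helpers used below are assumed to be provided by concat_utils.glslh /
+// indexing_utils.h.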
 
+/*
+ * This shader template concatenates up to NUM_INPUTS input tensors to the
+ * output tensor along the concat_dim. Elements from the input tensors will
+ * be inserted along the output's concat_dim starting at concat_offset.
+ *
+ * Each thread is responsible for writing out one output texel. The data
+ * required for the output texel may be read from multiple input texels of one
+ * input tensor.
+ */
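+/*
+ * Hypothetical walk-through (sizes chosen for illustration only): suppose
+ * concat_dim is the packed (width) dim, input 0 has width 3, input 1 has
+ * width 5, and concat_offset is 0. The thread that owns the output texel
+ * covering width indices 4..7 maps them to input-volume widths 4..7; since
+ * width 4 lies past input 0 (size 3), all four components are gathered
+ * element-by-element from input 1 at widths 1..4.
+ */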
 void main() {
-  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
-  ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim);
-
-  if (any(greaterThanEqual(out_tidx, out_sizes))) {
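+  // Only the x component of the invocation id is used: the dispatch is
+  // effectively one-dimensional, with each thread covering one texel (4
+  // packed elements) of the concatenated input volume.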
+  const int tid = ivec3(gl_GlobalInvocationID).x;
+
+  // Sum of the sizes of all input tensors along the concat_dim
+  const int concat_numel = total_concat_dim_numel();
+
+  // The 1-3 input tensors are interpreted as one concatenated tensor ("volume")
+  // along the concat_dim for the purposes of tensor indexing. Each thread is
+  // responsible for writing out 4 elements along the packed dim of the output
+  // tensor by reading the source data from the input tensor(s).
+  ivec4 inp_volume_sizes = out_sizes;
+  inp_volume_sizes[concat_dim] = total_concat_dim_numel();
+
+  // Reconstruct inp_volume_texel_sizes from Concat.cpp
+  ivec4 inp_volume_texel_sizes = inp_volume_sizes;
+  inp_volume_texel_sizes[out_packed_dim] = DIV_UP_4(
+    inp_volume_texel_sizes[out_packed_dim]
+  ) + 1;
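+  // (the + 1 presumably provides one extra texel of headroom for when
+  // concat_offset is not a multiple of 4 and a write straddles an extra
+  // output texel along the packed dim)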
+
+  // tensor index of the first element that will be read from the input volume
+  ivec4 inp_volume_start_tidx = nchwi_to_tidx(tid, inp_volume_texel_sizes);
+  inp_volume_start_tidx[out_packed_dim] = MUL_4(
+    inp_volume_start_tidx[out_packed_dim]
+  );
+
+  int concat_offset = t_concat_offset[0];
+
+  // tensor index of the first element that will be written to the output tensor
+  ivec4 out_write_start_tidx = inp_volume_start_tidx;
+  out_write_start_tidx[concat_dim] += concat_offset;
+
+  // To write to the desired output element, we will need to load the texel
+  // to which the element belongs. Calculate the tensor index of the first
+  // element of that texel.
+  ivec4 out_read_start_tidx = out_write_start_tidx;
+  out_read_start_tidx[out_packed_dim] = ALIGN_DOWN_4(
+    out_write_start_tidx[out_packed_dim]);
+
+  // bounds check
+  if (any(greaterThanEqual(out_read_start_tidx, out_sizes))) {
     return;
   }
 
-  if (can_use_fast_path()) {
-    // Fast path: No texel merging required
-    ivec4 in_tidx = out_tidx;
+  ivec3 out_pos = tidx_to_pos(
+    out_read_start_tidx,
+    out_sizes,
+    out_axis_map,
+    out_packed_dim
+  );
 
-    $for i in range(NUM_INPUTS):
-      // For each input tensor, check if the tensor index is within bounds. If
-      // so, read the texel from the input tensor and write it to the output
-      if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
-        const ivec3 in_pos = tidx_to_pos(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
-        const VEC4_T in_texel = load_texel(t_in${i+1}, in_pos);
-        write_texel_lpos(t_out, lpos, in_texel, out_axis_map);
-        return;
-      }
-      // Otherwise, adjust the index along the concat dimension and try the next
-      // input tensor.
-      else {
-        in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
-      }
-  }
-  else {
-    // Slow path: Texel merging required
-    VEC4_T out_texel = VEC4_T(0);
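+  // Read-modify-write: the output texel may already contain elements written
+  // by a previous input batch (when concat_offset is not texel-aligned), so
+  // load it before filling in the new components.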
+  VEC4_T out_texel = imageLoad(t_out, out_pos);
 
-    // Process each element in the output texel individually
-    for (int texel_i = 0; texel_i < 4; ++texel_i) {
-      ivec4 curr_out_tidx = out_tidx;
-      curr_out_tidx[out_packed_dim] += texel_i;
+  VEC4_T test_texel = VEC4_T(-1.0);
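+  // test_texel is written below but never stored or read back; it looks like
+  // a leftover debugging aid.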
 
-      // Skip if we're out of bounds
-      if (curr_out_tidx[out_packed_dim] >= out_sizes[out_packed_dim]) {
-        continue;
-      }
+  for (int comp = 0; comp < 4; ++comp) {
+    ivec4 out_tidx = out_read_start_tidx;
+    out_tidx[out_packed_dim] += comp;
 
-      ivec4 in_tidx = curr_out_tidx;
-      $for i in range(NUM_INPUTS):
-        // For each input tensor, check if the tensor index is within bounds. If
-        // so, read the corresponding texel element from the input tensor and
-        // write it to the output texel.
-        if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
-          const ivec4 in_posi = tidx_to_posi(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
-          out_texel[texel_i] = load_texel(t_in${i+1}, in_posi.xyz)[in_posi.w];
-          continue;
-        }
-        // Otherwise, adjust the index along the concat dimension and try the
-        // next input tensor.
-        else {
-          in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
-        }
+
+    // It's possible that the current texel element has been written to as part
+    // of the previous input batch; if so, then don't overwrite this texel
+    // element
+    if (out_tidx[concat_dim] < concat_offset) {
+      test_texel[comp] = -5.0;
+      continue;
     }
 
-    write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
+    // Calculate the tidx of the input volume that corresponds to this output
+    // element
+    ivec4 inp_volume_tidx = out_tidx;
+    inp_volume_tidx[concat_dim] -= concat_offset;
+
+    // go through the list of input tensors, and figure out which input this
+    // output element should be read from.
+    $for i in range(NUM_INPUTS):
+      if (inp_volume_tidx[concat_dim] < inp${i}_sizes[concat_dim]) {
+        // Special fast path: if, for the first output texel element, the
+        // corresponding input element is at the start of the texel it belongs
+        // to, then the input texel can be written as-is to the output texel.
+        // This also requires that the entire input texel is valid and does
+        // not contain any padding elements.
+        if (comp == 0 &&
+            out_tidx[out_packed_dim] % 4 == 0 &&
+            inp_volume_tidx[inp${i}_packed_dim] % 4 == 0 &&
+            inp_volume_tidx[inp${i}_packed_dim] + 3 < inp${i}_sizes[inp${i}_packed_dim]) {
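+          // The input texel maps 1:1 onto the output texel, so copy it with a
+          // single fetch and skip the remaining components via the break below.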
+          const ivec3 in_pos = tidx_to_pos(
+            inp_volume_tidx,
+            inp${i}_sizes,
+            inp${i}_axis_map,
+            inp${i}_packed_dim);
+
+          out_texel = texelFetch(t_inp${i}, in_pos, 0);
+          break;
+        }
+
+        // Otherwise, locate the specific input element required
+        const ivec4 in_posi = tidx_to_posi(
+          inp_volume_tidx,
+          inp${i}_sizes,
+          inp${i}_axis_map,
+          inp${i}_packed_dim);
+
+        out_texel[comp] = texelFetch(t_inp${i}, in_posi.xyz, 0)[in_posi.w];
+        test_texel[comp] = out_texel[comp];
+        continue;
+      }
+      else {
+        inp_volume_tidx[concat_dim] -= inp${i}_sizes[concat_dim];
+      }
   }
+
+  imageStore(t_out, out_pos, out_texel);
 }