@@ -20,20 +20,19 @@ ${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
2020${layout_declare_tensor(1 , "r", "existing_out", DTYPE, STORAGE)}
2121${layout_declare_tensor(2 , "r", "t_in", DTYPE, STORAGE)}
2222
23- layout (set = 0 , binding = 3 ) uniform PRECISION restrict CopyArgs {
24- ivec4 out_sizes;
25- ivec4 in_sizes;
23+ ${layout_declare_ubo(3 , "ivec4 ", "out_sizes")}
24+ ${layout_declare_ubo(4 , "ivec4 ", "out_axis_map")}
25+ ${layout_declare_ubo(5 , "ivec4 ", "in_sizes")}
26+ ${layout_declare_ubo(6 , "ivec4 ", "in_axis_map")}
27+ layout (set = 0 , binding = 7 ) uniform PRECISION restrict CopyArgs {
28+ // Operates on (x, y, z) logical extents.
29+ ivec3 range;
2630 // Analogous to range variable in copy. It defines the # of channels being
2731 // copied.
2832 int channel_range;
29- int src_channel_offset;
30- int dst_channel_offset;
31- int unused;
32- // Operates on (x, y, z) extents.
33- ivec3 range;
34- int unused1;
3533 ivec3 dst_offset;
36- int unused2;
34+ int dst_channel_offset;
35+ int src_channel_offset;
3736};
3837
3938layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
@@ -43,36 +42,36 @@ layout(constant_id = 3) const int packed_dim = C_DIM;
4342void main() {
4443 // Note: Unlike other shaders, the range is often not equal to the destination
4544 // texture extent.
46- const ivec3 pos = ivec3 (gl_GlobalInvocationID);
47- if (any (greaterThanEqual (pos , range))) {
45+ const ivec3 lpos = ivec3 (gl_GlobalInvocationID);
46+ if (any (greaterThanEqual (lpos , range))) {
4847 return ;
4948 }
5049
51- const ivec3 out_pos = pos + dst_offset;
50+ const ivec3 out_lpos = lpos + dst_offset;
5251
53- const ivec4 out_whcn = to_tensor_idx(out_pos , out_sizes, packed_dim);
52+ const ivec4 out_tidx = lpos_to_tidx(out_lpos , out_sizes, out_axis_map.w , packed_dim);
5453
5554 // First read the existing values to make sure the boundary values stay.
56- VEC4_T v = VEC4_T(texelFetch( existing_out, out_pos, 0 ) );
55+ VEC4_T v = load_texel_lpos( existing_out, out_lpos, out_axis_map );
5756
57+ ivec4 in_tidx = out_tidx;
5858 for (int i= 0 ; i< 4 ; i++ ) {
59- ivec4 in_whcn = out_whcn;
6059
61- in_whcn.z = out_whcn.z - dst_channel_offset + i;
60+ in_tidx[packed_dim] = out_tidx[packed_dim] - dst_channel_offset + i;
6261
6362 // Handle the partial update for beginning of channel in an existing tensor.
6463 // If the source channel index is below zero or exceeds the range, we skip
6564 // updating the element to avoid overwriting existing data.
66- if ((in_whcn.z < 0 ) || (in_whcn.z >= channel_range)) {
65+ if ((in_tidx[packed_dim] < 0 ) || (in_tidx[packed_dim] >= channel_range)) {
6766 continue ;
6867 }
6968
7069 // Readjust for the source offset.
71- in_whcn.z = in_whcn.z + src_channel_offset;
70+ in_tidx[packed_dim] += src_channel_offset;
7271
73- ivec4 in_elem_pos = to_texture_elem_pos(in_whcn , in_sizes, packed_dim);
74- v[i] = VEC4_T(texelFetch( t_in, in_elem_pos .xyz, 0 ))[in_elem_pos .w];
72+ ivec4 in_posi = tidx_to_posi(in_tidx , in_sizes, in_axis_map , packed_dim);
73+ v[i] = load_texel( t_in, in_posi .xyz)[in_posi .w];
7574 }
7675
77- imageStore (t_out, out_pos , v);
76+ write_texel_lpos (t_out, out_lpos , v, out_axis_map );
7877}
0 commit comments