Let's say I have the following mesh structure:
flags.DEFINE_string('mesh_shape', 'rows:3, columns:4', 'mesh shape')
flags.DEFINE_integer('image_nx_block', 3, 'The number of x blocks.')
flags.DEFINE_integer('image_ny_block', 4, 'The number of y blocks.')
flags.DEFINE_string('layout', 'image_nx_block:rows, image_ny_block:columns', 'layout rules')
Basically, I have split the rows and columns of a (1200, 1200) image into 3 and 4 blocks, respectively. These blocks reside on a 3x4 GPU mesh, so each block of size (400, 300) is assigned to an individual GPU. The mesh tensor (t) representing the image has shape (nx=3, ny=4, sx=400, sy=300). Now, I want to change values only in a specific slice of specific block(s).
For example,
t(0, :, :, 0) = constant or
t(0, :, :, 0) = t(0, :, :, 1)
How do I perform this operation?
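To make the intended result concrete, here is a minimal NumPy sketch of the same update on an ordinary, non-distributed array with the same shape. It only illustrates the result I am after and is not Mesh TensorFlow code; the helper names set_constant and copy_neighbor and the fill value -1.0 are made up for this illustration.

```python
import numpy as np

# Same logical shape as the mesh tensor: (nx, ny, sx, sy).
nx, ny, sx, sy = 3, 4, 400, 300
t = np.arange(nx * ny * sx * sy, dtype=np.float32).reshape(nx, ny, sx, sy)


def set_constant(t, value):
    """Return a copy of t with t[0, :, :, 0] set to a constant (hypothetical helper)."""
    out = t.copy()
    out[0, :, :, 0] = value
    return out


def copy_neighbor(t):
    """Return a copy of t with t[0, :, :, 0] replaced by t[0, :, :, 1] (hypothetical helper)."""
    out = t.copy()
    out[0, :, :, 0] = t[0, :, :, 1]
    return out


t_const = set_constant(t, -1.0)  # -1.0 is an arbitrary example constant
t_copy = copy_neighbor(t)
```

Only the blocks with nx index 0 are touched, and within them only the first entry along sy; everything else keeps its original values.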
One way is to use t_sy = mtf.slice(t, 0, 1, 'sy'), but this returns the slice from all GPUs. Slicing further with t_synx = mtf.slice(t_sy, 0, 1, 'nx') to reach a specific GPU block throws an error: 'can't slice along split axis'.
Can someone please help here?