@@ -10,10 +10,11 @@ static __global__ void k_get_rows(
1010        /* const size_t nb00,*/   const  size_t  nb01, const  size_t  nb02, const  size_t  nb03,
1111        const  size_t  s10, const  size_t  s11, const  size_t  s12/* , const size_t s13*/  ) {
1212
13-     const  int  i00 = (blockIdx .x *blockDim .x  + threadIdx .x )*2 ;
14-     const  int  i10 =  blockDim .y *blockIdx .y  + threadIdx .y ;
15-     const  int  i11 = (blockIdx .z *blockDim .z  + threadIdx .z )/ne12;
16-     const  int  i12 = (blockIdx .z *blockDim .z  + threadIdx .z )%ne12;
13+     //  The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
14+     const  int  i00 = (blockIdx .y  * blockDim .x  + threadIdx .x )*2 ;
15+     const  int  i10 =  blockIdx .x ;
16+     const  int  i11 =  blockIdx .z  / ne12;
17+     const  int  i12 =  blockIdx .z  % ne12;
1718
1819    if  (i00 >= ne00) {
1920        return ;
@@ -46,10 +47,11 @@ static __global__ void k_get_rows_float(
4647        /* const size_t nb00,*/   const  size_t  nb01, const  size_t  nb02, const  size_t  nb03,
4748        const  size_t  s10, const  size_t  s11, const  size_t  s12/* , const size_t s13*/  ) {
4849
49-     const  int  i00 =  blockIdx .x *blockDim .x  + threadIdx .x ;
50-     const  int  i10 =  blockDim .y *blockIdx .y  + threadIdx .y ;
51-     const  int  i11 = (blockIdx .z *blockDim .z  + threadIdx .z )/ne12;
52-     const  int  i12 = (blockIdx .z *blockDim .z  + threadIdx .z )%ne12;
50+     //  The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
51+     const  int  i00 = blockIdx .y  * blockDim .x  + threadIdx .x ;
52+     const  int  i10 = blockIdx .x ;
53+     const  int  i11 = blockIdx .z  / ne12;
54+     const  int  i12 = blockIdx .z  % ne12;
5355
5456    if  (i00 >= ne00) {
5557        return ;
@@ -94,8 +96,8 @@ static void get_rows_cuda_q(
9496        const  size_t  nb1, const  size_t  nb2, const  size_t  nb3,
9597        cudaStream_t stream) {
9698    const  dim3  block_dims (CUDA_GET_ROWS_BLOCK_SIZE, 1 , 1 );
97-     const  int  block_num_x  = (ne00 + 2 *CUDA_GET_ROWS_BLOCK_SIZE - 1 ) / (2 *CUDA_GET_ROWS_BLOCK_SIZE);
98-     const  dim3  block_nums (block_num_x, ne10 , ne11*ne12);
99+     const  int  block_num_y  = (ne00 + 2 *CUDA_GET_ROWS_BLOCK_SIZE - 1 ) / (2 *CUDA_GET_ROWS_BLOCK_SIZE);
100+     const  dim3  block_nums (ne10, block_num_y , ne11*ne12);
99101
100102    //  strides in elements
101103    //  const size_t s0 = nb0 / sizeof(dst_t);
@@ -127,8 +129,8 @@ static void get_rows_cuda_float(
127129        const  size_t  nb1, const  size_t  nb2, const  size_t  nb3,
128130        cudaStream_t stream) {
129131    const  dim3  block_dims (CUDA_GET_ROWS_BLOCK_SIZE, 1 , 1 );
130-     const  int  block_num_x  = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1 ) / CUDA_GET_ROWS_BLOCK_SIZE;
131-     const  dim3  block_nums (block_num_x, ne10 , ne11*ne12);
132+     const  int  block_num_y  = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1 ) / CUDA_GET_ROWS_BLOCK_SIZE;
133+     const  dim3  block_nums (ne10, block_num_y , ne11*ne12);
132134
133135    //  strides in elements
134136    //  const size_t s0 = nb0 / sizeof(dst_t);
0 commit comments