@@ -57,31 +57,33 @@ static __global__ void mul_mat_f(
5757    T * tile_xy = (T *) compute_base + threadIdx .y *(tile_A::I * tile_k_padded);
5858
5959    if  constexpr  (has_ids) {
60-         __shared__  int  has_any;
61-         if  (threadIdx .y  == 0 ) {
62-             int  local_has_any = 0 ;
63-             for  (int  j = threadIdx .x ; j < cols_per_block; j += warp_size) {
64-                 int  slot = -1 ;
65-                 for  (int  k = 0 ; k < nchannels_dst; ++k) {
66-                     const  int  idv = ids[j*stride_row_id + k*stride_col_id];
67-                     if  (idv == expert_idx) {
68-                         slot = k;
69-                         break ;
70-                     }
71-                 }
72-                 if  (j < cols_per_block) {
73-                     local_has_any |= (slot >= 0 );
74-                     slot_map[j] = slot;
60+         int  found = 0 ;
61+ 
62+         for  (int  j0 = 0 ; j0 < cols_per_block; j0 += nwarps) {
63+             const  int  j = j0 + threadIdx .y ;
64+             const  int32_t  * __restrict__  id_row = ids + j*stride_row_id;
65+ 
66+             if  (threadIdx .x  == 0 ) {
67+                 slot_map[j] = -1 ;
68+             }
69+ 
70+             for  (int  k = threadIdx .x ; k < nchannels_dst; k += warp_size) {
71+                 int  match = id_row[k*stride_col_id] == expert_idx;
72+ 
73+                 if  (match) {
74+                     slot_map[j] = k;
75+                     found = 1 ;
76+                     break ;
7577                }
7678            }
77-             has_any = warp_reduce_any (local_has_any);
7879        }
79-          __syncthreads (); 
80-         if  (has_any ==  0 ) {
80+ 
81+         if  (! __syncthreads_or (found) ) {
8182            return ;
8283        }
8384    }
8485
86+ 
8587    for  (int  col = threadIdx .y *warp_size + threadIdx .x ; col < ncols; col += nwarps*warp_size) {
8688        tile_A A[ntA][warp_size / tile_A::J];
8789#pragma  unroll
@@ -106,14 +108,7 @@ static __global__ void mul_mat_f(
106108                    if  constexpr  (!has_ids) {
107109                        tile_xy[j0*tile_k_padded + threadIdx .x ] = j < cols_per_block ? y[j*stride_col_y + col] : 0 .0f ;
108110                    } else  {
109-                         float  val = 0 .0f ;
110-                         if  (j < cols_per_block) {
111-                             const  int  slot = slot_map[j];
112-                             if  (slot >= 0 ) {
113-                                 val = y[slot*stride_channel_y + j*stride_col_y + col];
114-                             }
115-                         }
116-                         tile_xy[j0*tile_k_padded + threadIdx .x ] = val;
111+                         tile_xy[j0*tile_k_padded + threadIdx .x ] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0 .0f ;
117112                    }
118113                }
119114            } else  if  constexpr  (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
@@ -125,14 +120,7 @@ static __global__ void mul_mat_f(
125120                        const  float2  tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2 (0 .0f , 0 .0f );
126121                        tile_xy[j0*tile_k_padded + threadIdx .x ] = {tmp.x , tmp.y };
127122                    } else  {
128-                         float2  tmp = make_float2 (0 .0f , 0 .0f );
129-                         if  (j < cols_per_block) {
130-                             const  int  slot = slot_map[j];
131-                             if  (slot >= 0 ) {
132-                                 const  float2  * y2_slot = (const  float2  *)(y + slot*stride_channel_y);
133-                                 tmp = y2_slot[j*stride_col_y + col];
134-                             }
135-                         }
123+                         float2  tmp = j < cols_per_block && slot_map[j] >= 0  ? *(const  float2 *) &y[slot_map[j]*stride_channel_y + 2 *(j*stride_col_y + col)] : make_float2 (0 .0f , 0 .0f );
136124                        tile_xy[j0*tile_k_padded + threadIdx .x ] = {tmp.x , tmp.y };
137125                    }
138126                }
@@ -221,7 +209,7 @@ static inline void mul_mat_f_switch_ids(
221209        const  dim3  & block_nums, const  dim3  & block_dims, const  int  nbytes_shared_total, cudaStream_t stream) {
222210    if  (ids) {
223211        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true ><<<block_nums, block_dims, nbytes_shared_total, stream>>> 
224-             (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
212+               (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
225213             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
226214             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
227215    } else  {
0 commit comments