 
 #include <vector>
 
+// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
+struct mmq_ids_helper_store {
+    uint32_t data;
+
+    __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
+        data = (it & 0x003FFFFF) | (iex_used << 22);
+    }
+
+    __device__ uint32_t it() const {
+        return data & 0x003FFFFF;
+    }
+
+    __device__ uint32_t iex_used() const {
+        return data >> 22;
+    }
+};
+static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
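+// Illustrative example (values chosen here, not from the original comment): mmq_ids_helper_store(5, 3)
+// packs to 0x00C00005, so it() == 5 and iex_used() == 3.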
+
+// Helper function for mul_mat_id, converts ids to a more convenient format.
+// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
+// ids_dst describes the same mapping but for the dst tensor.
+// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
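+// Worked example (not part of the original comment, derived from the logic below): with 3 experts and
+// n_expert_used == 2, ids == [[0, 2], [1, 2]] yields expert_bounds == [0, 1, 2, 4], i.e. expert 0 owns
+// compact row 0, expert 1 owns row 1 and expert 2 owns rows 2..3.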
+template <int n_expert_used_template>
+__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mmq_ids_helper(
+        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
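+    // As set up by the host code further down: si1 is the stride between ids rows in int32_t elements,
+    // sis1 the stride between src1 channels measured in src1 columns (nb12/nb11), and nchannels_y is ne11,
+    // used to wrap the per-token expert slot index.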
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
+    const int expert = blockIdx.x;
+
+    extern __shared__ char data_mmq_ids_helper[];
+    mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
+
+    int nex_prev   = 0; // Number of columns for experts with a lower index.
+    int it_compact = 0; // Running index for the compact slice of this expert.
+
+    if constexpr (n_expert_used_template == 0) {
+        // Generic implementation:
+        for (int it = 0; it < n_tokens; ++it) {
+            int iex_used = -1; // The index at which the expert is used, if any.
+            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
+                const int expert_used = ids[it*si1 + iex];
+                nex_prev += expert_used < expert;
+                if (expert_used == expert) {
+                    iex_used = iex;
+                }
+            }
+
+            if (iex_used != -1) {
+                store[it_compact] = mmq_ids_helper_store(it, iex_used);
+            }
+
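+            // warp_reduce_any presumably returns the same value on every lane, so it_compact stays warp-uniform.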
+            if (warp_reduce_any<warp_size>(iex_used != -1)) {
+                it_compact++;
+            }
+        }
+    } else {
+        // Implementation optimized for specific numbers of experts used:
+        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
+        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
+        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
+            const int it = it0 + threadIdx.x / neu_padded;
+
+            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
+            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
+                ids[it*si1 + iex] : INT_MAX;
+            const int iex_used = expert_used == expert ? iex : -1;
+            nex_prev += expert_used < expert;
+
+            // Whether the threads at this token position have used the expert:
+            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
+
+            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
+            int it_compact_add_lower = 0;
+#pragma unroll
+            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
+                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
+                if (threadIdx.x >= offset) {
+                    it_compact_add_lower += tmp;
+                }
+            }
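+            // E.g. with warp_size == 32 and neu_padded == 8, the offsets 8/16/24 add the flags of the up to
+            // three earlier token groups in this warp, i.e. an exclusive prefix sum over the per-token groups.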
+
+            if (iex_used != -1) {
+                store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
+            }
+
+            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
+            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
+        }
+    }
+    nex_prev = warp_reduce_sum<warp_size>(nex_prev);
+
+    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
+        const mmq_ids_helper_store store_it = store[itc];
+        const int it       = store_it.it();
+        const int iex_used = store_it.iex_used();
+        ids_src1[nex_prev + itc] = it*sis1          + iex_used % nchannels_y;
+        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
+    }
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    expert_bounds[expert] = nex_prev;
+
+    if (expert < gridDim.x - 1) {
+        return;
+    }
+
+    expert_bounds[gridDim.x] = nex_prev + it_compact;
+}
+
+template <int n_expert_used_template>
+static void launch_mmq_ids_helper(
+        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+    GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mmq_ids_helper_store");
+    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
+
+    const int    id        = ggml_cuda_get_device();
+    const int    warp_size = ggml_cuda_info().devices[id].warp_size;
+    const size_t smpbo     = ggml_cuda_info().devices[id].smpbo;
+    CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
+
+    const dim3 num_blocks(n_experts, 1, 1);
+    const dim3 block_size(warp_size, 1, 1);
+    const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
+    GGML_ASSERT(nbytes_shared <= smpbo);
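+    // One single-warp block per expert; shared memory holds one 4-byte store entry per token and must fit the opt-in per-block limit.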
+    mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
+        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+}
+
 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
         case GGML_TYPE_Q4_0:
@@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q(
             ne00, ne01, ne1, s01, ne11, s1,
             ne02, ne12, s02, s12, s2,
             ne03, ne13, s03, s13, s3,
-            use_stream_k};
+            use_stream_k, ne1};
         ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
         return;
     }
@@ -148,53 +282,49 @@ void ggml_cuda_mul_mat_q(
 
     const int64_t n_expert_used = ids->ne[0];
     const int64_t ne_get_rows = ne12 * n_expert_used;
+    GGML_ASSERT(ne1 == n_expert_used);
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    std::vector<int32_t> ids_src1_host;
-    ids_src1_host.reserve(ne_get_rows);
-    std::vector<int32_t> ids_dst_host;
-    ids_dst_host.reserve(ne_get_rows);
-    std::vector<int32_t> tokens_per_expert_host(ne02);
-    std::vector<int32_t> expert_bounds_host(ne02 + 1);
-    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
-
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
-
-    for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
-        for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
-            for (int64_t iex = 0; iex < n_expert_used; ++iex) {
-                const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
-                assert(expert_to_use >= 0 && expert_to_use < ne02);
-                if (expert_to_use == i02) {
-                    ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
-                    ids_dst_host.push_back(i12*ne1 + iex);
-                    tokens_per_expert_host[i02]++;
-                    break;
-                }
-            }
-        }
-    }
+    ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
+    ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
+    ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
 
-    int32_t cumsum = 0;
-    for (int64_t i = 0; i < ne02; ++i) {
-        expert_bounds_host[i] = cumsum;
-        cumsum += tokens_per_expert_host[i];
+    {
+        GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
+        const int si1 = ids->nb[1] / ggml_element_size(ids);
+        const int sis1 = nb12 / nb11;
+
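+        // Dispatch to a template specialization for common n_expert_used values; template argument 0 selects the generic kernel path.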
+        switch (n_expert_used) {
+            case  2:
+                launch_mmq_ids_helper< 2>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case  4:
+                launch_mmq_ids_helper< 4>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case  6:
+                launch_mmq_ids_helper< 6>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case  8:
+                launch_mmq_ids_helper< 8>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case 16:
+                launch_mmq_ids_helper<16>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case 32:
+                launch_mmq_ids_helper<32>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            default:
+                launch_mmq_ids_helper< 0>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+        }
+        CUDA_CHECK(cudaGetLastError());
     }
-    expert_bounds_host[ne02] = cumsum;
-
-    std::vector<int32_t> ids_buf_host;
-    ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
-    ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
-    ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
-    ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
-    ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
-    CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
-
-    const int32_t * ids_src1_dev = ids_buf_dev.ptr;
-    const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size();
-    const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
 
     const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
         get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q(
         const int64_t s11 = src1->nb[1] / ts_src1;
         const int64_t s12 = src1->nb[2] / ts_src1;
         const int64_t s13 = src1->nb[2] / ts_src1;
-        quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
+        quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
             ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
         CUDA_CHECK(cudaGetLastError());
     }
@@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q(
 
     // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
     const mmq_args args = {
-        src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
+        src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
         ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
         ne02, ne02, s02, s12, s2,
         ne03, ne13, s03, s13, s3,
-        use_stream_k};
+        use_stream_k, ne12};
 
     ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 }
@@ -262,7 +392,7 @@ void ggml_cuda_op_mul_mat_q(
         ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
         1, 1, 0, 0, 0,
         1, 1, 0, 0, 0,
-        use_stream_k};
+        use_stream_k, src1_ncols};
 
     ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 