@@ -44,6 +44,245 @@ __global__ void apply_repetition_penalties_kernel(
 }
 }
 
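+// Maps a float logit to one of 512 histogram bins such that bin indices
+// decrease as logits increase (bin 0 receives the largest values). The
+// logit is rounded to half precision; flipping all bits of negative values
+// and setting the sign bit of non-negative ones makes the uint16 bit
+// pattern order-preserving, and the top 9 bits then select the bin.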
+static inline __device__ uint16_t extractBinIdx(float x) {
+  union {
+    __half h;
+    uint16_t u16;
+  } tmp;
+  tmp.h = __float2half_rn(x);
+  tmp.u16 = (x < 0.f) ? (~tmp.u16 & 0xffff) : (tmp.u16 | 0x8000);
+  return 511 - (tmp.u16 >> 7);
+}
+
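+// One thread block per row: builds a 512-bin histogram of bin indices,
+// prefix-sums it to locate the bin containing the kTopK-th largest logit,
+// then ranks only that bin's candidates in a final sort and emits the
+// row's top-k logits together with their row-local indices.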
+template <int kNumThreadsPerBlock = 512>
+static __global__ void topKPerRow(const float* logits, const int* rowStarts,
+                                  const int* rowEnds, int* outIndices,
+                                  float* outLogits, int stride0, int stride1) {
+  // The number of bins in the histogram.
+  static constexpr int kNumBins = 512;
+
+  // The top-k width.
+  static constexpr int kTopK = 2048;
+  // The number of elements per thread for the final top-k sort.
+  static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
+  // The class to sort the elements during the final top-k sort.
+  using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
+                                       kNumTopKItemsPerThread, int>;
+
+  // The number of slots for the final pass.
+  static constexpr int kNumFinalItems = 3072;
+  // The number of elements per thread for the final sort.
+  static constexpr int kNumFinalItemsPerThread =
+      kNumFinalItems / kNumThreadsPerBlock;
+  // The class to sort the elements during the final pass.
+  using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
+                                        kNumFinalItemsPerThread, int>;
+
+  // The class to compute the exclusive prefix-sum over the histogram.
+  using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;
+
+  // Shared memory to compute the block scan.
+  __shared__ typename Scan::TempStorage smemScan;
+
+  // The structure to store the final items (for the final pass).
+  struct FinalItems {
+    // Shared memory to store the indices for the final pass.
+    int indices[kNumFinalItems];
+    // Shared memory to store the logits for the final pass.
+    float logits[kNumFinalItems];
+  };
+
+  // Shared memory to compute the block sort.
+  __shared__ union {
+    FinalItems items;
+    typename FinalSort::TempStorage finalSort;
+    typename TopKSort::TempStorage topKSort;
+  } smemFinal;
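+  // Note: the final items and the two sort storages are live in disjoint
+  // phases, each separated by __syncthreads, so overlapping them in a
+  // union saves shared memory.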
+
+  // Shared memory to store the histogram.
+  __shared__ int smemHistogram[kNumBins];
+  // Shared memory to store the selected indices.
+  __shared__ int smemIndices[kTopK];
+  // Shared memory to store the selected logits.
+  __shared__ float smemLogits[kTopK];
+  // Shared memory to store the threshold bin.
+  __shared__ int smemThresholdBinIdx[1];
+  // Shared memory counter to register the candidates for the final phase.
+  __shared__ int smemFinalDstIdx[1];
+
+  // The row computed by this block.
+  int rowIdx = blockIdx.x;
+  // The range of logits within the row.
+  int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx];
+  // The length of the row.
+  int rowLen = rowEnd - rowStart;
+
+  // Shortcut if the length of the row does not exceed the top-k width.
+  // Indices are not sorted by their corresponding logit.
+  if (rowLen <= kTopK) {
+    for (int rowIt = threadIdx.x; rowIt < rowLen;
+         rowIt += kNumThreadsPerBlock) {
+      int idx = rowStart + rowIt;
+      outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
+      outLogits[rowIdx * kTopK + rowIt] =
+          logits[rowIdx * stride0 + idx * stride1];
+    }
+    for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
+         rowIt += kNumThreadsPerBlock) {
+      outIndices[rowIdx * kTopK + rowIt] = -1;
+      outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX;
+    }
+    return;
+  }
+
+  // Clear the histogram.
+  if (threadIdx.x < kNumBins) {
+    smemHistogram[threadIdx.x] = 0;
+  }
+
+  // Make sure the histogram is cleared.
+  __syncthreads();
+
+  // Fetch elements one-by-one.
+  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
+       rowIt += kNumThreadsPerBlock) {
+    uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]);
+    atomicAdd(&smemHistogram[idx], 1);
+  }
+
+  // Make sure the histogram is ready.
+  __syncthreads();
+
+  // Read the values from SMEM.
+  int binCount{0};
+  if (threadIdx.x < kNumBins) {
+    binCount = smemHistogram[threadIdx.x];
+  }
+
+  // Make sure each thread has read its value.
+  __syncthreads();
+
+  // Compute the prefix sum.
+  int prefixSum{0}, totalSum{0};
+  Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);
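+  // prefixSum now holds this bin's exclusive prefix (the number of logits
+  // in strictly higher-valued bins); totalSum is the block-wide aggregate,
+  // i.e. the row length.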
+
+  // Update the histogram with the prefix sums.
+  if (threadIdx.x < kNumBins) {
+    smemHistogram[threadIdx.x] = prefixSum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // Find the threshold bin, i.e. the bin where the cumulative count
+  // crosses kTopK.
+  if (threadIdx.x < kNumBins) {
+    int nextPrefixSum =
+        threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
+    if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
+      smemThresholdBinIdx[0] = threadIdx.x;
+    }
+  }
+
+  // Clear the counter to store the items for the final phase.
+  if (threadIdx.x == 0) {
+    smemFinalDstIdx[0] = 0;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The threshold bin.
+  int thresholdBinIdx = smemThresholdBinIdx[0];
+
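+  // Elements in bins strictly below the threshold bin are guaranteed to be
+  // in the top-k and go straight to their final slots; elements that land
+  // exactly in the threshold bin compete for the remaining slots in the
+  // final pass below. Threshold-bin candidates beyond kNumFinalItems are
+  // dropped.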
+  // Fetch elements one-by-one and populate the shared memory buffers.
+  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
+       rowIt += kNumThreadsPerBlock) {
+    float logit = logits[rowIdx * stride0 + rowIt * stride1];
+    uint16_t idx = extractBinIdx(logit);
+    if (idx < thresholdBinIdx) {
+      int dstIdx = atomicAdd(&smemHistogram[idx], 1);
+      smemLogits[dstIdx] = logit;
+      smemIndices[dstIdx] = rowIt;
+    } else if (idx == thresholdBinIdx) {
+      int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
+      if (dstIdx < kNumFinalItems) {
+        smemFinal.items.logits[dstIdx] = logit;
+        smemFinal.items.indices[dstIdx] = rowIt;
+      }
+    }
+  }
+
+  // Make sure the elements are in shared memory.
+  __syncthreads();
+
+  // The logits of the elements to be sorted in the final pass.
+  float finalLogits[kNumFinalItemsPerThread];
+  // The indices of the elements to be sorted in the final pass.
+  int finalIndices[kNumFinalItemsPerThread];
+
+  // Init.
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    finalLogits[ii] = -FLT_MAX;
+  }
+
+  // Read the elements from SMEM.
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
+    if (srcIdx < smemFinalDstIdx[0]) {
+      finalLogits[ii] = smemFinal.items.logits[srcIdx];
+      finalIndices[ii] = smemFinal.items.indices[srcIdx];
+    }
+  }
+
+  // Make sure the shared memory has been read.
+  __syncthreads();
+
+  // Sort the elements.
+  FinalSort(smemFinal.finalSort)
+      .SortDescendingBlockedToStriped(finalLogits, finalIndices);
+
+  // Copy the data back to the shared memory storage.
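+  // After the distribution pass, smemHistogram[i] holds the end offset of
+  // bin i's slots, so the entry for the bin preceding the threshold bin is
+  // where the threshold-bin winners start.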
+  int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
+#pragma unroll
+  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
+    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
+    int dstIdx = baseIdx + srcIdx;
+    if (dstIdx < kTopK) {
+      smemLogits[dstIdx] = finalLogits[ii];
+      smemIndices[dstIdx] = finalIndices[ii];
+    }
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The topK logits.
+  float topKLogits[kNumTopKItemsPerThread];
+  // The topK indices.
+  int topKIndices[kNumTopKItemsPerThread];
+
+  // Load from shared memory.
+#pragma unroll
+  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
+    topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x];
+    topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x];
+  }
+
+  // Sort the elements.
+  TopKSort(smemFinal.topKSort)
+      .SortDescendingBlockedToStriped(topKLogits, topKIndices);
+
+  // Store to global memory.
+#pragma unroll
+  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
+    int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
+    outIndices[offset] = topKIndices[ii] - rowStart;
+    outLogits[offset] = topKLogits[ii];
+  }
+}
+
 } // namespace vllm
 
 void apply_repetition_penalties_(
@@ -85,4 +324,20 @@ void apply_repetition_penalties_(
         repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
         tile_size);
   });
-}
+}
+
+void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
+                   const torch::Tensor& rowEnds, torch::Tensor& indices,
+                   torch::Tensor& values, int64_t numRows, int64_t stride0,
+                   int64_t stride1) {
+  // Compute the results on the device.
+  constexpr int kNumThreadsPerBlock = 512;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
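+  // Launch one block per row; each block writes exactly 2048 (the kernel's
+  // kTopK) index/logit pairs for its row.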
+  vllm::topKPerRow<kNumThreadsPerBlock>
+      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
+          logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
+          rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
+          values.data_ptr<float>(), static_cast<int>(stride0),
+          static_cast<int>(stride1));
+}