Skip to content

Commit e1098ce

Browse files
dcampora and youkaichao authored
Add topk logits torch op for DS3.2. (#25945)
Signed-off-by: Daniel Campora <[email protected]>
Signed-off-by: Daniel Cámpora <[email protected]>
Co-authored-by: youkaichao <[email protected]>
1 parent d100d78 commit e1098ce

File tree

5 files changed

+446
-25
lines changed

5 files changed

+446
-25
lines changed

csrc/ops.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ void apply_repetition_penalties_(torch::Tensor& logits,
100100
const torch::Tensor& output_mask,
101101
const torch::Tensor& repetition_penalties);
102102

103+
// Writes, for each of the numRows rows of `logits` (element strides
// stride0/stride1), the values and row-relative indices of its top-2048
// elements into `values` and `indices`; rows shorter than 2048 are padded
// with -FLT_MAX / -1. Implemented in csrc/sampler.cu.
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                   const torch::Tensor& rowEnds, torch::Tensor& indices,
                   torch::Tensor& values, int64_t numRows, int64_t stride0,
                   int64_t stride1);
107+
103108
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
104109
torch::Tensor& weight, torch::Tensor& scale,
105110
double epsilon);

csrc/sampler.cu

Lines changed: 256 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,245 @@ __global__ void apply_repetition_penalties_kernel(
4444
}
4545
}
4646

47+
// Maps a float to one of 512 histogram bins such that larger values land in
// smaller bin indices (bin 0 receives the largest logits). The value is first
// rounded to fp16; the bin is derived from its top 9 bits.
static inline __device__ uint16_t extractBinIdx(float x) {
  union {
    __half h;
    uint16_t u16;
  } caster;
  caster.h = __float2half_rn(x);
  // Make the fp16 bit pattern monotone under unsigned comparison: negative
  // values are bitwise inverted, non-negative ones get the sign bit set.
  if (x < 0.f) {
    caster.u16 = static_cast<uint16_t>(~caster.u16 & 0xffff);
  } else {
    caster.u16 = static_cast<uint16_t>(caster.u16 | 0x8000);
  }
  // Keep the top 9 bits and reverse so bigger logits map to lower bins.
  return 511 - (caster.u16 >> 7);
}
56+
57+
// Computes, for each row of `logits`, the indices and values of its top
// kTopK (=2048) elements. A 512-bin histogram over the fp16-quantized
// logits selects everything above a threshold bin directly; only the
// threshold bin itself is sorted precisely (the "final pass").
//
// Launch: one block per row, kNumThreadsPerBlock threads per block.
// Outputs are laid out as [numRows x kTopK]; indices are relative to
// rowStarts[row]. Rows shorter than kTopK are padded with -1 / -FLT_MAX,
// and their indices are NOT sorted by logit.
template <int kNumThreadsPerBlock = 512>
static __global__ void topKPerRow(const float* logits, const int* rowStarts,
                                  const int* rowEnds, int* outIndices,
                                  float* outLogits, int stride0, int stride1) {
  // The number of bins in the histogram.
  static constexpr int kNumBins = 512;

  // The top-k width.
  static constexpr int kTopK = 2048;
  // The number of elements per thread for the final top-k sort.
  static constexpr int kNumTopKItemsPerThread = kTopK / kNumThreadsPerBlock;
  // The class to sort the elements during the final top-k sort.
  using TopKSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
                                       kNumTopKItemsPerThread, int>;

  // The number of slots for the final pass.
  static constexpr int kNumFinalItems = 3072;
  // The number of elements per thread for the final sort.
  static constexpr int kNumFinalItemsPerThread =
      kNumFinalItems / kNumThreadsPerBlock;
  // The class to sort the elements during the final pass.
  using FinalSort = cub::BlockRadixSort<float, kNumThreadsPerBlock,
                                        kNumFinalItemsPerThread, int>;

  // The indexing below assumes an exact item-per-thread fit and at least
  // one thread per histogram bin; fail at compile time otherwise.
  static_assert(kTopK % kNumThreadsPerBlock == 0,
                "kTopK must be a multiple of the block size");
  static_assert(kNumFinalItems % kNumThreadsPerBlock == 0,
                "kNumFinalItems must be a multiple of the block size");
  static_assert(kNumThreadsPerBlock >= kNumBins,
                "the threshold-bin search needs one thread per bin");

  // The class to compute the inclusive prefix-sum over the histogram.
  using Scan = cub::BlockScan<int, kNumThreadsPerBlock>;

  // Shared memory to compute the block scan.
  __shared__ typename Scan::TempStorage smemScan;

  // The structure to store the final items (for the final pass).
  struct FinalItems {
    // Shared memory to store the indices for the final pass.
    int indices[kNumFinalItems];
    // Shared memory to store the logits for the final pass.
    float logits[kNumFinalItems];
  };

  // Shared memory to compute the block sort. The union is safe because the
  // final items are consumed (copied to registers) before either sort runs.
  __shared__ union {
    FinalItems items;
    typename FinalSort::TempStorage finalSort;
    typename TopKSort::TempStorage topKSort;
  } smemFinal;

  // Shared memory to store the histogram.
  __shared__ int smemHistogram[kNumBins];
  // Shared memory to store the selected indices.
  __shared__ int smemIndices[kTopK];
  // Shared memory to store the selected logits.
  __shared__ float smemLogits[kTopK];
  // Shared memory to store the threshold bin.
  __shared__ int smemThresholdBinIdx[1];
  // Shared memory counter to register the candidates for the final phase.
  __shared__ int smemFinalDstIdx[1];

  // The row computed by this block.
  int rowIdx = blockIdx.x;
  // The range of logits within the row.
  int rowStart = rowStarts[rowIdx], rowEnd = rowEnds[rowIdx];
  // The length of the row.
  int rowLen = rowEnd - rowStart;

  // Shortcut if the length of the row is smaller than Top-K. Indices are not
  // sorted by their corresponding logit.
  if (rowLen <= kTopK) {
    for (int rowIt = threadIdx.x; rowIt < rowLen;
         rowIt += kNumThreadsPerBlock) {
      int idx = rowStart + rowIt;
      outIndices[rowIdx * kTopK + rowIt] = idx - rowStart;
      outLogits[rowIdx * kTopK + rowIt] =
          logits[rowIdx * stride0 + idx * stride1];
    }
    // Pad the remainder of the row with sentinel values.
    for (int rowIt = rowLen + threadIdx.x; rowIt < kTopK;
         rowIt += kNumThreadsPerBlock) {
      outIndices[rowIdx * kTopK + rowIt] = -1;
      outLogits[rowIdx * kTopK + rowIt] = -FLT_MAX;
    }
    return;
  }

  // Clear the histogram.
  if (threadIdx.x < kNumBins) {
    smemHistogram[threadIdx.x] = 0;
  }

  // Make sure the histogram is ready.
  __syncthreads();

  // Fetch elements one-by-one and build the per-bin counts.
  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
       rowIt += kNumThreadsPerBlock) {
    uint16_t idx = extractBinIdx(logits[rowIdx * stride0 + rowIt * stride1]);
    atomicAdd(&smemHistogram[idx], 1);
  }

  // Make sure the histogram is ready.
  __syncthreads();

  // Read the values from SMEM.
  int binCount{0};
  if (threadIdx.x < kNumBins) {
    binCount = smemHistogram[threadIdx.x];
  }

  // Make sure each thread has read its value.
  __syncthreads();

  // Compute the prefix sum.
  int prefixSum{0}, totalSum{0};
  Scan(smemScan).ExclusiveSum(binCount, prefixSum, totalSum);

  // Update the histogram with the prefix sums.
  if (threadIdx.x < kNumBins) {
    smemHistogram[threadIdx.x] = prefixSum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // Find the threshold bin: the bin whose prefix range straddles kTopK.
  // Since rowLen > kTopK, exactly one thread takes this branch.
  if (threadIdx.x < kNumBins) {
    int nextPrefixSum =
        threadIdx.x == kNumBins - 1 ? totalSum : smemHistogram[threadIdx.x + 1];
    if (prefixSum < kTopK && nextPrefixSum >= kTopK) {
      smemThresholdBinIdx[0] = threadIdx.x;
    }
  }

  // Clear the counter to store the items for the final phase.
  if (threadIdx.x == 0) {
    smemFinalDstIdx[0] = 0;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The threshold bin.
  int thresholdBinIdx = smemThresholdBinIdx[0];

  // Fetch elements one-by-one and populate the shared memory buffers:
  // elements strictly above the threshold bin are guaranteed top-K; elements
  // in the threshold bin are candidates for the final sort.
  for (int rowIt = rowStart + threadIdx.x; rowIt < rowEnd;
       rowIt += kNumThreadsPerBlock) {
    float logit = logits[rowIdx * stride0 + rowIt * stride1];
    uint16_t idx = extractBinIdx(logit);
    if (idx < thresholdBinIdx) {
      int dstIdx = atomicAdd(&smemHistogram[idx], 1);
      smemLogits[dstIdx] = logit;
      smemIndices[dstIdx] = rowIt;
    } else if (idx == thresholdBinIdx) {
      // Candidates beyond kNumFinalItems are dropped (approximation when the
      // threshold bin is very crowded).
      int dstIdx = atomicAdd(&smemFinalDstIdx[0], 1);
      if (dstIdx < kNumFinalItems) {
        smemFinal.items.logits[dstIdx] = logit;
        smemFinal.items.indices[dstIdx] = rowIt;
      }
    }
  }

  // Make sure the elements are in shared memory.
  __syncthreads();

  // The logits of the elements to be sorted in the final pass.
  float finalLogits[kNumFinalItemsPerThread];
  // The indices of the elements to be sorted in the final pass.
  int finalIndices[kNumFinalItemsPerThread];

  // Init. Pad logits with -FLT_MAX and indices with -1 so that padding slots
  // sort to the back and never leak uninitialized register contents into
  // smemIndices/outIndices when the threshold bin overflows kNumFinalItems.
#pragma unroll
  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
    finalLogits[ii] = -FLT_MAX;
    finalIndices[ii] = -1;
  }

  // Read the elements from SMEM.
#pragma unroll
  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
    if (srcIdx < smemFinalDstIdx[0]) {
      finalLogits[ii] = smemFinal.items.logits[srcIdx];
      finalIndices[ii] = smemFinal.items.indices[srcIdx];
    }
  }

  // Make sure the shared memory has been read (the sort reuses the union).
  __syncthreads();

  // Sort the threshold-bin candidates.
  FinalSort(smemFinal.finalSort)
      .SortDescendingBlockedToStriped(finalLogits, finalIndices);

  // Copy the data back to the shared memory storage, after the slots already
  // claimed by the bins above the threshold bin.
  int baseIdx = thresholdBinIdx > 0 ? smemHistogram[thresholdBinIdx - 1] : 0;
#pragma unroll
  for (int ii = 0; ii < kNumFinalItemsPerThread; ++ii) {
    int srcIdx = ii * kNumThreadsPerBlock + threadIdx.x;
    int dstIdx = baseIdx + srcIdx;
    if (dstIdx < kTopK) {
      smemLogits[dstIdx] = finalLogits[ii];
      smemIndices[dstIdx] = finalIndices[ii];
    }
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The topK logits.
  float topKLogits[kNumTopKItemsPerThread];
  // The topK indices.
  int topKIndices[kNumTopKItemsPerThread];

  // Load from shared memory.
#pragma unroll
  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
    topKLogits[ii] = smemLogits[ii * kNumThreadsPerBlock + threadIdx.x];
    topKIndices[ii] = smemIndices[ii * kNumThreadsPerBlock + threadIdx.x];
  }

  // Sort the final top-K elements by logit, descending.
  TopKSort(smemFinal.topKSort)
      .SortDescendingBlockedToStriped(topKLogits, topKIndices);

  // Store to global memory; indices are made row-relative here.
#pragma unroll
  for (int ii = 0; ii < kNumTopKItemsPerThread; ++ii) {
    int offset = rowIdx * kTopK + ii * kNumThreadsPerBlock + threadIdx.x;
    outIndices[offset] = topKIndices[ii] - rowStart;
    outLogits[offset] = topKLogits[ii];
  }
}
285+
47286
} // namespace vllm
48287

49288
void apply_repetition_penalties_(
@@ -85,4 +324,20 @@ void apply_repetition_penalties_(
85324
repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
86325
tile_size);
87326
});
88-
}
327+
}
328+
329+
// Launches topKPerRow: for each of the numRows rows of `logits` (element
// strides stride0/stride1), writes the top-2048 logits and their
// row-relative indices into `values` and `indices` ([numRows x 2048]).
// `logits`/`values` must be float32 and `rowStarts`/`rowEnds`/`indices`
// int32 (enforced by the data_ptr casts below); all tensors must live on
// the current CUDA device. Runs on the current CUDA stream.
void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
                   const torch::Tensor& rowEnds, torch::Tensor& indices,
                   torch::Tensor& values, int64_t numRows, int64_t stride0,
                   int64_t stride1) {
  // A zero-sized grid is an invalid CUDA launch configuration and its error
  // would be silently dropped here; there is nothing to compute, so bail out.
  if (numRows <= 0) {
    return;
  }

  // One block per row. The block size must match the kernel's template
  // argument, which sizes its per-block shared-memory buffers.
  constexpr int kNumThreadsPerBlock = 512;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  vllm::topKPerRow<kNumThreadsPerBlock>
      <<<numRows, kNumThreadsPerBlock, 0, stream>>>(
          logits.data_ptr<float>(), rowStarts.data_ptr<int>(),
          rowEnds.data_ptr<int>(), indices.data_ptr<int>(),
          values.data_ptr<float>(), static_cast<int>(stride0),
          static_cast<int>(stride1));
}

csrc/torch_bindings.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
188188
ops.impl("apply_repetition_penalties_", torch::kCUDA,
189189
&apply_repetition_penalties_);
190190

191+
// Optimized top-k per row operation
192+
ops.def(
193+
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
194+
"Tensor! indices, Tensor! values, int numRows, int stride0, "
195+
"int stride1) -> ()");
196+
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
197+
191198
// Layernorm-quant
192199
// Apply Root Mean Square (RMS) Normalization to the input tensor.
193200
ops.def(

0 commit comments

Comments
 (0)