Skip to content

Commit c3705e4

Browse files
authored
Add lightseq hierarchical beam_search (PaddlePaddle#978)
* Add offset mapping doc * fix eval hang because of unique endpoint * generate api support encoder-decoder * Add lightseq beam_search * optimize performance * add blockRoughTopK kernel * optimize * minor fix
1 parent 74dc997 commit c3705e4

File tree

3 files changed

+159
-0
lines changed

3 files changed

+159
-0
lines changed

paddlenlp/ops/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}
140140
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/cuda/topk_kernels.cu topk_kernels_src)
141141
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/topk_kernels.cu topk_kernels_dst)
142142

143+
file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/cuda/lightseq_kernels.cu lightseq_kernels_cu_src)
144+
143145
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/cuda/open_decoder.cu open_decoder_cu_dst)
144146
file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/open_decoder.h open_decoder_h_dst)
145147

@@ -190,6 +192,7 @@ set(FT_PATCH_COMMAND
190192
&& cp ${arguments_h_src} ${trans_dst}
191193
&& cp ${bert_encoder_transformer_h_src} ${bert_encoder_transformer_h_dst}
192194
&& cat ${cuda_kernels_h_src} >> ${cuda_kernels_h_dst}
195+
&& cat ${lightseq_kernels_cu_src} >> ${topk_kernels_dst}
193196
&& cat ${cuda_kernels_cu_src} >> ${cuda_kernels_cu_dst}
194197
&& cat ${decoding_kernels_cu_src} >> ${decoding_kernels_cu_dst}
195198
&& cat ${topk_kernels_cuh_src} >> ${topk_kernels_cuh_dst}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
2+
namespace fastertransformer {

// Full-warp participation mask for the *_sync shuffle intrinsics.
const unsigned int WARP_REDUCE_MASK = 0xffffffff;
// Large negative sentinel used as "-infinity" padding in max-reductions.
const float CUDA_FLOAT_INF_NEG = -100000000.f;
const unsigned int WARP_SIZE = 32;

/* Butterfly max-reduction across one warp: after log2(32) = 5 shuffle
 * rounds every lane holds the warp-wide maximum of the 32 input values. */
template <typename T>
__forceinline__ __device__ T warpReduceMax(T val) {
  for (int mask = (WARP_SIZE >> 1); mask > 0; mask >>= 1)
    val = max(val, __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_SIZE));
  return val;
}

/* Calculate the maximum of all elements in a block.
 * Each warp reduces its own values, lane 0 of each warp publishes the
 * partial into shared[], and warp 0's lanes reduce the (ceil-divided)
 * per-warp partials. Requires blockDim.x <= 1024 so the 32-slot shared
 * buffer covers every warp; every thread returns the block maximum. */
template <typename T>
__forceinline__ __device__ T blockReduceMax(T val) {
  static __shared__ T shared[32];  // one partial max per warp
  int lane = threadIdx.x & 0x1f;   // lane index within the warp
  int wid = threadIdx.x >> 5;      // warp index within the block

  val = warpReduceMax<T>(val);

  if (lane == 0) shared[wid] = val;
  __syncthreads();

  // Lanes mapping past the last warp are padded with -inf so they cannot
  // influence the final reduction. Note the ceil-division: partial warps
  // are included here.
  val = (threadIdx.x < ((blockDim.x + 31) >> 5)) ? shared[lane]
                                                 : CUDA_FLOAT_INF_NEG;
  val = warpReduceMax<T>(val);
  return val;
}

/* Calculate the rough topk-th value in a block, rough but safe: the
 * returned threshold is <= the true K-th largest value, so at least K
 * elements in the block compare >= the result. Strategy: take per-warp
 * maxima, max-reduce them into K groups of 32/K lanes, then min over the
 * K group maxima. */
template <typename T, int K>
__forceinline__ __device__ T blockRoughTopK(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;
  val = warpReduceMax(val);

  if (lane == 0) shared[wid] = val;
  __syncthreads();

  // we do not care about result of threadIdx.x bigger than (blockDim.x >> 5)
  // NOTE(review): floor division, unlike blockReduceMax's ceil — a trailing
  // partial warp is ignored, so callers should launch with blockDim.x a
  // multiple of 32 (TODO confirm all launch sites do).
  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : CUDA_FLOAT_INF_NEG;

  // K must be a power of two no larger than 32 (1, 2, 4, 8, 16 or 32):
  // the descending mask schedule below only partitions the warp correctly
  // for power-of-two K. (The original comment's "6" is not supported.)
  for (int mask = 16; mask >= K; mask >>= 1)
    val = max(val, __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, 32));
  for (int mask = (K >> 1); mask > 0; mask >>= 1)
    val = min(val, __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, 32));

  return val;
}
}  // namespace fastertransformer

paddlenlp/ops/patches/FasterTransformer/cuda/topk_kernels.cu

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,98 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
6565
}
6666
}
6767

68+
// Forward declaration; the definition lives in lightseq_kernels.cu, which
// the CMake patch step appends to this translation unit.
template <typename T, int K>
__forceinline__ __device__ T blockRoughTopK(T val);

/* Hierarchical (two-phase) per-beam top-K selection.
 *
 * Launch layout: one block per beam slot (blockIdx.x indexes a
 * [batch * beam] row of log_probs), THREADBLOCK_SIZE threads per block.
 *
 * Phase 1: compute a rough-but-safe threshold over this row's vocab_size
 * logits (at least beam_size logits are >= it), then stream-compact every
 * qualifying logit into can_score_buf / can_idx_buf.
 * Phase 2: run an exact top-beam_size cub::BlockReduce over the small
 * candidate set instead of the full vocabulary.
 *
 * log_probs         input logits, row-major [gridDim.x, vocab_size].
 * can_score_buf     per-row candidate scores (vocab_size slots per row).
 * can_idx_buf       flat log_probs indices matching can_score_buf.
 * topk_tmp_id_buf   output ids, [gridDim.x, beam_size] (flat indices).
 * topk_tmp_val_buf  output scores, with diversity_rate * rank added.
 *
 * NOTE(review): assumes at least beam_size candidates pass the rough
 * threshold; otherwise total.p[i] stays -1 and can_idx_buf[-1] is read —
 * blockRoughTopK's "rough but safe" guarantee is relied on here.
 */
template <typename T, int beam_size, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__
    void beam_topK_kernel_hierarchical(const T* log_probs,
                                       T* can_score_buf,
                                       int* can_idx_buf,
                                       int* topk_tmp_id_buf,
                                       T* topk_tmp_val_buf,
                                       const int vocab_size,
                                       T diversity_rate) {
  __shared__ T s_topk;            // rough top-beam_size-th threshold
  __shared__ int num_cur_beam_can;  // running count of compacted candidates
  typedef cub::BlockReduce<TopK<T, beam_size>, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int thread_id = threadIdx.x;
  int block_id = blockIdx.x;
  const bool IS_FP16 = std::is_same<T, half>::value;
  const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
  T rough_top_kth_logit = -MAX_T_VAL;

// Phase 1a: per-thread max over a strided slice of this row.
// (#pragma unroll is a hint only — the trip count is runtime-dependent.)
#pragma unroll
  for (int elem_id = thread_id; elem_id < vocab_size;
       elem_id += THREADBLOCK_SIZE) {
    int index = elem_id + block_id * vocab_size;
    rough_top_kth_logit = fmaxf(rough_top_kth_logit, log_probs[index]);
  }
  // NOTE(review): the helper is instantiated with float even when T is
  // half, so half values round-trip through float here — confirm intended.
  rough_top_kth_logit = blockRoughTopK<float, beam_size>(rough_top_kth_logit);
  if (thread_id == 0) {
    s_topk = rough_top_kth_logit;
    num_cur_beam_can = 0;
  }
  // No barrier needed here: the first __syncthreads() inside the loop below
  // executes before any thread reads s_topk.

  int idx = block_id * vocab_size + thread_id;

  // Phase 1b: stream-compact logits >= s_topk. Two-level offsets: pos is
  // the thread's slot within this iteration (atomicAdd on l_n), l_n is then
  // rebased to the running total via atomicAdd on num_cur_beam_can.
  __shared__ int l_n;  // current iteration candidate number
  for (int iter = 0;
       iter < (vocab_size + THREADBLOCK_SIZE - 1) / THREADBLOCK_SIZE;
       iter++) {
    // zero the counter
    if (threadIdx.x == 0) l_n = 0;
    __syncthreads();
    T lgt = -MAX_T_VAL;  // min s_topk is CUDA_FLOAT_INF_NEG
    int pos;  // only assigned/used under the same (lgt >= s_topk) predicate
    int vocab_id = idx - block_id * vocab_size;

    if (vocab_id < vocab_size) {
      lgt = log_probs[idx];
      if (lgt >= s_topk) pos = atomicAdd(&l_n, 1);
    }
    __syncthreads();
    // Thread 0 reserves this iteration's range; atomicAdd returns the old
    // total, which becomes the base offset stored back into l_n.
    if (threadIdx.x == 0) {
      l_n = atomicAdd(&num_cur_beam_can, l_n);
    }
    __syncthreads();

    if (lgt >= s_topk) {
      pos += l_n;  // local slot + global base
      can_score_buf[pos + block_id * vocab_size] = lgt;
      can_idx_buf[pos + block_id * vocab_size] = idx;
    }
    __syncthreads();
    idx += THREADBLOCK_SIZE;
  }

  // Phase 2: exact top-beam_size over the compacted candidates only.
  TopK<T, beam_size> partial;
#pragma unroll
  for (int i = 0; i < beam_size; ++i) {
    partial.p[i] = -1;  // sentinel: slot not filled
    partial.u[i] = -MAX_T_VAL;
  }
  for (int elem_id = thread_id; elem_id < num_cur_beam_can;
       elem_id += THREADBLOCK_SIZE) {
    int index = elem_id + block_id * vocab_size;
    // insert() records the candidate-buffer position; the original vocab
    // index is recovered from can_idx_buf at output time.
    partial.insert(can_score_buf[index], index);
  }
  TopK<T, beam_size> total =
      BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, beam_size>);

  // Thread 0 writes the beam_size winners, applying the rank-proportional
  // diversity penalty to the scores.
  if (thread_id == 0) {
    int index = block_id * beam_size;

#pragma unroll
    for (int i = 0; i < beam_size; ++i) {
      topk_tmp_id_buf[index + i] = can_idx_buf[total.p[i]];
      topk_tmp_val_buf[index + i] = total.u[i] + diversity_rate * (T)i;
    }
  }
}
159+
68160
template <typename T, int THREADBLOCK_SIZE>
69161
__global__ void beam_topK_kernel_general(const T* log_probs,
70162
T* tmp_log_probs,
@@ -453,21 +545,29 @@ void topK_kernelLauncher(void* workspace,
453545
batch_size * beam_width * beam_width * max_block_per_beam; // type int
454546
int topk_tmp_val_buf_size =
455547
batch_size * beam_width * beam_width * max_block_per_beam; // type float
548+
// int can_score_buf_size = batch_size * beam_width * vocab_size;
549+
// int can_idx_buf_size = batch_size * beam_width * vocab_size;
456550

457551
// prevent memory misalinged address
458552
temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4;
553+
// can_score_buf_size = (int)(ceil(can_score_buf_size / 4.)) * 4;
554+
// can_idx_buf_size = (int)(ceil(can_idx_buf_size / 4.)) * 4;
459555
topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4;
460556
topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4;
461557

462558
if (workspace == nullptr) {
463559
workspace_size = sizeof(float) * temp_log_probs_buf_size +
464560
sizeof(int) * topk_tmp_ids_buf_size +
465561
sizeof(float) * topk_tmp_val_buf_size;
562+
// sizeof(float) * can_score_buf_size +
563+
// sizeof(int) * can_idx_buf_size;
466564
return;
467565
} else {
468566
T* temp_log_probs = (T*)workspace;
469567
int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size);
470568
T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size);
569+
// T* can_score_buf = (T*)(topk_tmp_val_buf + topk_tmp_val_buf_size);
570+
// int* can_idx_buf = (int*)(can_score_buf + can_score_buf_size);
471571
if (diversity_rate == 0.0f) {
472572
switch (beam_width) {
473573
CASE_K(1, 128, 128, 8);

0 commit comments

Comments
 (0)