@@ -65,6 +65,98 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
65
65
}
66
66
}
67
67
68
// Forward declaration: block-wide *approximate* K-th-largest reduction.
// Returns a rough lower bound on the K-th largest value contributed by the
// block; the definition lives elsewhere in this file.
template <typename T, int K>
__forceinline__ __device__ T blockRoughTopK(T val);

// Hierarchical per-beam top-K over the vocabulary.
//
// Launch layout: one block per (batch, beam) row of log_probs;
// blockDim.x == THREADBLOCK_SIZE.
//
// Phase 1: estimate a rough top-K threshold for this row by folding each
//          thread's strided max through blockRoughTopK.
// Phase 2: compact every logit >= threshold into can_score_buf/can_idx_buf
//          with a two-level counter: a per-iteration shared atomic (l_n)
//          ranks lanes within the iteration, then thread 0 folds that count
//          into the running total (num_cur_beam_can), leaving the previous
//          total in l_n as this iteration's base offset.
// Phase 3: exact top-beam_size CUB block reduce over the (small) candidate
//          set; thread 0 writes the ids/values, adding diversity_rate * rank.
//
// can_score_buf / can_idx_buf: scratch, at least gridDim.x * vocab_size wide.
// topk_tmp_id_buf / topk_tmp_val_buf: outputs, beam_size entries per block.
template <typename T, int beam_size, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__
    void beam_topK_kernel_hierarchical(const T* log_probs,
                                       T* can_score_buf,
                                       int* can_idx_buf,
                                       int* topk_tmp_id_buf,
                                       T* topk_tmp_val_buf,
                                       const int vocab_size,
                                       T diversity_rate) {
  __shared__ T s_topk;              // rough top-K threshold for this row
  __shared__ int num_cur_beam_can;  // total candidates compacted so far
  typedef cub::BlockReduce<TopK<T, beam_size>, THREADBLOCK_SIZE> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  int thread_id = threadIdx.x;
  int block_id = blockIdx.x;
  const bool IS_FP16 = std::is_same<T, half>::value;
  const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
  T rough_top_kth_logit = -MAX_T_VAL;

  // Phase 1: per-thread max over a strided slice of this row's logits.
  // (No unroll pragma: the trip count depends on runtime vocab_size.)
  for (int elem_id = thread_id; elem_id < vocab_size;
       elem_id += THREADBLOCK_SIZE) {
    int index = elem_id + block_id * vocab_size;
    rough_top_kth_logit = fmaxf(rough_top_kth_logit, log_probs[index]);
  }
  // Fix: dispatch on T rather than a hard-coded float so the half path is
  // consistent with the IS_FP16 / HALF_FLT_MAX handling above.
  rough_top_kth_logit = blockRoughTopK<T, beam_size>(rough_top_kth_logit);
  if (thread_id == 0) {
    s_topk = rough_top_kth_logit;
    num_cur_beam_can = 0;
  }
  // s_topk / num_cur_beam_can are published to the whole block by the
  // __syncthreads() at the top of the first loop iteration below.

  int idx = block_id * vocab_size + thread_id;

  __shared__ int l_n;  // candidates found in the current iteration
  for (int iter = 0;
       iter < (vocab_size + THREADBLOCK_SIZE - 1) / THREADBLOCK_SIZE;
       iter++) {
    // Zero the per-iteration counter.
    if (threadIdx.x == 0) l_n = 0;
    __syncthreads();
    T lgt = -MAX_T_VAL;  // minimal s_topk acts as -inf for tail lanes
    int pos = 0;         // fix: initialized so the guarded write is never UB
    int vocab_id = idx - block_id * vocab_size;

    if (vocab_id < vocab_size) {
      lgt = log_probs[idx];
      // pos = this lane's rank among the iteration's passing candidates.
      if (lgt >= s_topk) pos = atomicAdd(&l_n, 1);
    }
    __syncthreads();
    // Thread 0 folds this iteration's count into the running total and
    // leaves the *previous* total in l_n as the base offset for scattering.
    if (threadIdx.x == 0) {
      l_n = atomicAdd(&num_cur_beam_can, l_n);
    }
    __syncthreads();

    // Fix: also require vocab_id < vocab_size. A tail lane keeps
    // lgt == -MAX_T_VAL, which can still satisfy lgt >= s_topk when the
    // entire row is -MAX_T_VAL, and would then scatter with a stale pos.
    if (vocab_id < vocab_size && lgt >= s_topk) {
      pos += l_n;
      can_score_buf[pos + block_id * vocab_size] = lgt;
      can_idx_buf[pos + block_id * vocab_size] = idx;
    }
    __syncthreads();
    idx += THREADBLOCK_SIZE;
  }

  // Phase 3: exact top-beam_size over the compacted candidate set.
  TopK<T, beam_size> partial;
#pragma unroll
  for (int i = 0; i < beam_size; ++i) {
    partial.p[i] = -1;
    partial.u[i] = -MAX_T_VAL;
  }
  for (int elem_id = thread_id; elem_id < num_cur_beam_can;
       elem_id += THREADBLOCK_SIZE) {
    int index = elem_id + block_id * vocab_size;
    partial.insert(can_score_buf[index], index);
  }
  TopK<T, beam_size> total =
      BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, beam_size>);

  if (thread_id == 0) {
    int index = block_id * beam_size;

#pragma unroll
    for (int i = 0; i < beam_size; ++i) {
      // Fix: if fewer than beam_size candidates passed the rough threshold,
      // total.p[i] stays -1; guard so we never read can_idx_buf[-1].
      int can_pos = total.p[i];
      topk_tmp_id_buf[index + i] = (can_pos >= 0) ? can_idx_buf[can_pos] : 0;
      topk_tmp_val_buf[index + i] = total.u[i] + diversity_rate * (T)i;
    }
  }
}
159
+
68
160
template <typename T, int THREADBLOCK_SIZE>
69
161
__global__ void beam_topK_kernel_general (const T* log_probs,
70
162
T* tmp_log_probs,
@@ -453,21 +545,29 @@ void topK_kernelLauncher(void* workspace,
453
545
batch_size * beam_width * beam_width * max_block_per_beam; // type int
454
546
int topk_tmp_val_buf_size =
455
547
batch_size * beam_width * beam_width * max_block_per_beam; // type float
548
+ // int can_score_buf_size = batch_size * beam_width * vocab_size;
549
+ // int can_idx_buf_size = batch_size * beam_width * vocab_size;
456
550
457
551
// prevent memory misalinged address
458
552
temp_log_probs_buf_size = (int )(ceil (temp_log_probs_buf_size / 4 .)) * 4 ;
553
+ // can_score_buf_size = (int)(ceil(can_score_buf_size / 4.)) * 4;
554
+ // can_idx_buf_size = (int)(ceil(can_idx_buf_size / 4.)) * 4;
459
555
topk_tmp_ids_buf_size = (int )(ceil (topk_tmp_ids_buf_size / 4 .)) * 4 ;
460
556
topk_tmp_val_buf_size = (int )(ceil (topk_tmp_val_buf_size / 4 .)) * 4 ;
461
557
462
558
if (workspace == nullptr ) {
463
559
workspace_size = sizeof (float ) * temp_log_probs_buf_size +
464
560
sizeof (int ) * topk_tmp_ids_buf_size +
465
561
sizeof (float ) * topk_tmp_val_buf_size;
562
+ // sizeof(float) * can_score_buf_size +
563
+ // sizeof(int) * can_idx_buf_size;
466
564
return ;
467
565
} else {
468
566
T* temp_log_probs = (T*)workspace;
469
567
int * topk_tmp_id_buf = (int *)(temp_log_probs + temp_log_probs_buf_size);
470
568
T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size);
569
+ // T* can_score_buf = (T*)(topk_tmp_val_buf + topk_tmp_val_buf_size);
570
+ // int* can_idx_buf = (int*)(can_score_buf + can_score_buf_size);
471
571
if (diversity_rate == 0 .0f ) {
472
572
switch (beam_width) {
473
573
CASE_K (1 , 128 , 128 , 8 );
0 commit comments