
Commit 97b20d4

Authored by tianxin and LiuChiachi
Support pre-normalization for FT Encoder fp32 (PaddlePaddle#974)
* add encoder op * support fp16 * support pre-normalization for FT Encoder fp32 * update CMakelists.txt * move code to cuda_kernel_h and cuda_kernel_cu * uncomment FT sample * add jiaqi code * self-attn output diff * finish pre-normalization v1 * implement generalized kernel * add v1 encoder * delete assert for generalize version kernel * support post-normalization fp32 for FT Encoder * delete unused CMake commands Co-authored-by: LiuChiaChi <[email protected]> Co-authored-by: LiuChiachi <[email protected]> Co-authored-by: Jiaqi Liu <[email protected]>
1 parent b659f1f commit 97b20d4
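
For context, the new `normalize_before` flag selects where LayerNorm sits in each encoder block: post-norm applies it after the residual add (the original BERT/Transformer layout), while pre-norm applies it to the sublayer input. A minimal Python schematic of the difference (illustration only, not the FT CUDA kernels):

    # Schematic: the block structure selected by `normalize_before`.
    def post_norm_block(x, sublayer, norm):
        # normalize_before=False: LayerNorm after the residual connection
        return norm(x + sublayer(x))

    def pre_norm_block(x, sublayer, norm):
        # normalize_before=True: LayerNorm on the sublayer input
        return x + sublayer(norm(x))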

File tree: 9 files changed, +618 −250 lines


paddlenlp/ops/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
@@ -178,7 +178,6 @@ file(TO_NATIVE_PATH ${OPS_SOURCE_DIR}/patches/FasterTransformer/bert_encoder_tra
 file(TO_NATIVE_PATH ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/fastertransformer/bert_encoder_transformer.h bert_encoder_transformer_h_dst)
 
 set(OPT_OPEN_ATTN_COMMAND sed -i -e "370,392d" -e "410,454d" -e "229d" ${open_attention_h_dst})
-#set(OPT_BERT_ENCODER_COMMAND sed -i -e "552,592d" -e "118a bool is_gelu_=true;" ${bert_encoder_transformer_h_dst})
 
 # TODO(guosheng): `find` seems meeting errors missing argument to `-exec', fix it
 set(MUTE_COMMAND grep -rl "printf(\"\\[WARNING\\]" ${CMAKE_BINARY_DIR}/${THIRD_PATH}/source/${THIRD_PARTY_NAME}/ | xargs -i{} sed -i "s/printf(\"\\WWARNING\\W decoding[^)]\\{1,\\})/ /" {})
@@ -305,5 +304,3 @@ if(ON_INFER AND WITH_GPT AND WITH_SP)
 endif()
 
 add_subdirectory(faster_transformer)
-
-

paddlenlp/ops/faster_transformer/src/fusion_encoder_op.cc

Lines changed: 8 additions & 4 deletions
@@ -46,7 +46,8 @@ std::vector<paddle::Tensor> EncoderForward(
     const int64_t& num_layer,
     const int64_t& layer_idx,
     const bool& allow_gemm_test,
-    const bool& use_trt_kernel) {
+    const bool& use_trt_kernel,
+    const bool& normalize_before) {
   if (input.place() == paddle::PlaceType::kGPU) {
     auto shape = input.shape();
     auto encoder_out = paddle::Tensor(paddle::PlaceType::kGPU, shape);
@@ -80,7 +81,8 @@ std::vector<paddle::Tensor> EncoderForward(
         num_layer,
         layer_idx,
         allow_gemm_test,
-        use_trt_kernel);
+        use_trt_kernel,
+        normalize_before);
   } else {
     PD_THROW("Not implemented place. Only GPU is supported. ");
   }
@@ -116,7 +118,8 @@ std::vector<std::vector<int64_t>> EncoderInferShape(
     const int64_t& num_layer,
     const int64_t& layer_idx,
     const bool& allow_gemm_test,
-    const bool& use_trt_kernel) {
+    const bool& use_trt_kernel,
+    const bool& normalize_before) {
   return {input_shape};
 }
 
@@ -179,7 +182,8 @@ PD_BUILD_OP(fusion_encoder)
             "num_layer: int64_t",
             "layer_idx: int64_t",
             "allow_gemm_test: bool",
-            "use_trt_kernel: bool"})
+            "use_trt_kernel: bool",
+            "normalize_before: bool"})
     .SetKernelFn(PD_KERNEL(EncoderForward))
     .SetInferShapeFn(PD_INFER_SHAPE(EncoderInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(EncoderInferDtype));
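
On the Python side, the new attribute must be supplied in the op's attrs map whenever `fusion_encoder` is invoked. A rough sketch of the call site (mirroring the `encoder.py` change below; `helper`, `inputs`, and `encoder_out` are assumed from that file):

    # Hypothetical sketch: feeding the new attribute to the registered op.
    outputs = {"EncoderOut": encoder_out}
    helper.append_op(
        type='fusion_encoder',
        inputs=inputs,
        outputs=outputs,
        attrs={
            # ...existing attrs elided...
            'use_trt_kernel': use_trt_kernel,
            'normalize_before': normalize_before,  # new bool attribute
        })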

paddlenlp/ops/faster_transformer/src/fusion_encoder_op.cu

Lines changed: 7 additions & 2 deletions
@@ -61,6 +61,7 @@ std::vector<paddle::Tensor> encoder_kernel(
     int64_t layer_idx_,
     bool allow_gemm_test,
     bool use_trt_kernel_,
+    bool normalize_before,
     cublasHandle_t cublas_handle_,
     cudaStream_t stream) {
   int batch_size_ = input.shape()[0];
@@ -148,7 +149,8 @@ std::vector<paddle::Tensor> encoder_kernel(
   // }
 
   BertEncoderTransformer<EncoderTraits_>* encoder =
-      new BertEncoderTransformer<EncoderTraits_>(int8_mode, allow_gemm_test);
+      new BertEncoderTransformer<EncoderTraits_>(
+          int8_mode, allow_gemm_test, normalize_before);
 
   encoder->allocateBuffer(allocator_,
                           batch_size_,
@@ -199,7 +201,8 @@ std::vector<paddle::Tensor> EncoderCUDAForward(
     int64_t num_layer,
     int64_t layer_idx,
     bool allow_gemm_test,
-    bool use_trt_kernel) {
+    bool use_trt_kernel,
+    bool normalize_before) {
   auto stream = input.stream();
   cublasHandle_t cublas_handle_;
   cublasCreate(&cublas_handle_);
@@ -241,6 +244,7 @@ std::vector<paddle::Tensor> EncoderCUDAForward(
           layer_idx,
           allow_gemm_test,
           use_trt_kernel,
+          normalize_before,
           cublas_handle_,
           stream);
 
@@ -279,6 +283,7 @@ std::vector<paddle::Tensor> EncoderCUDAForward(
           layer_idx,
           allow_gemm_test,
           use_trt_kernel,
+          normalize_before,
           cublas_handle_,
           stream);
       break;

paddlenlp/ops/faster_transformer/src/fusion_encoder_op.h

Lines changed: 2 additions & 1 deletion
@@ -57,4 +57,5 @@ std::vector<paddle::Tensor> EncoderCUDAForward(
     int64_t num_layer_,
     int64_t layer_idx_,
     bool allow_gemm_test,
-    bool use_trt_kernel_);
+    bool use_trt_kernel_,
+    bool normalize_before);

paddlenlp/ops/faster_transformer/transformer/encoder.py

Lines changed: 5 additions & 2 deletions
@@ -52,7 +52,8 @@ def infer_transformer_encoder(
         int8_mode=0,
         layer_idx=0,
         allow_gemm_test=False,
-        use_trt_kernel=False):
+        use_trt_kernel=False,
+        normalize_before=False):
     """
     Fusion Encoder API intergrating Encoder inference in FasterTransformer. It
     accepts the weight and bias of TransformerEncoder and some other parameters
@@ -92,6 +93,7 @@ def infer_transformer_encoder(
         'layer_idx': layer_idx,
         'allow_gemm_test': allow_gemm_test,
         'use_trt_kernel': use_trt_kernel,
+        'normalize_before': normalize_before
     }
     encoder_out = helper.create_variable(dtype=input.dtype)
     outputs = {"EncoderOut": encoder_out}
@@ -173,7 +175,8 @@ def encoder_layer_forward(self,
         # amax_list=paddle.to_tensor([]), # int8 mode is not supported.
         n_head=self._config['nhead'],
         size_per_head=self._config['d_model'] // self._config['nhead'],
-        is_gelu=self._config['activation'] == 'gelu')
+        is_gelu=self._config['activation'] == 'gelu',
+        normalize_before=self._config['normalize_before'] == True)
     return src
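
Since `normalize_before` is read from the layer's construction-time config, a pre-norm model only needs the standard Paddle flag at layer creation. A hypothetical end-to-end sketch (assuming the custom op is built and `TransformerEncoderLayer.forward` has been swapped to the patched `encoder_layer_forward` above):

    import paddle
    from paddle.nn import TransformerEncoder, TransformerEncoderLayer

    # normalize_before=True is recorded in the layer config; the patched
    # forward reads it back via self._config['normalize_before'].
    layer = TransformerEncoderLayer(
        d_model=512, nhead=8, dim_feedforward=2048,
        activation='gelu', normalize_before=True)
    encoder = TransformerEncoder(layer, num_layers=6)

    src = paddle.randn([2, 33, 512])  # [batch_size, seq_len, d_model]
    out = encoder(src)                # runs the pre-norm encoder stack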

paddlenlp/ops/patches/FasterTransformer/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -223,4 +223,4 @@ link_directories(
 
 add_subdirectory(fastertransformer)
 add_subdirectory(tools)
-add_subdirectory(sample)
+#add_subdirectory(sample)
