@@ -120,7 +120,7 @@ def encoder_layer_forward(self,
 
     After inference, `disable_faster_encoder` could be called to restore the
     `forward` function of `paddle.nn.TransformerEncoder` and
-    `paddle.nn.TransformerEncoder`.
+    `paddle.nn.TransformerEncoderLayer`.
 
     Args:
         src (Tensor):
@@ -130,14 +130,13 @@ def encoder_layer_forward(self,
         src_mask (Tensor, optional):
             A tensor used in multi-head attention to prevents attention to some
             unwanted positions, usually the paddings or the subsequent
-            positions. It is a tensor with shape broadcasted to
-            `[batch_size, n_head, sequence_length, sequence_length]`. When the
-            data type is bool, the unwanted positions have `False` values and
-            the others have `True` values. When the data type is int, the
-            unwanted positions have 0 values and the others have 1 values. When
-            the data type is float, the unwanted positions have `-INF` values
-            and the others have 0 values. It can be None when nothing wanted or
-            needed to be prevented attention to. Defaults to None.
+            positions. It is a tensor with shape `[batch_size, 1, 1, sequence_length]`.
+            When the data type is bool, the unwanted positions have `False`
+            values and the others have `True` values. When the data type is int,
+            the unwanted positions have 0 values and the others have 1 values.
+            When the data type is float, the unwanted positions have `-INF`
+            values and the others have 0 values. It can be None when nothing
+            wanted or needed to be prevented attention to. Defaults to None.
 
     Returns:
         src(Tensor|tuple):
@@ -192,7 +191,7 @@ def encoder_forward(self, src, src_mask=None, cache=None):
 
     After inference, `disable_faster_encoder` could be called to restore the
     `forward` function of `paddle.nn.TransformerEncoder` and
-    `paddle.nn.TransformerEncoder`.
+    `paddle.nn.TransformerEncoderLayer`.
 
     Args:
         src (Tensor):
@@ -202,14 +201,14 @@ def encoder_forward(self, src, src_mask=None, cache=None):
         src_mask (Tensor, optional):
             A tensor used in multi-head attention to prevents attention to
             some unwanted positions, usually the paddings or the subsequent
-            positions. It is a tensor with shape broadcasted to
-            `[batch_size, n_head, sequence_length, sequence_length]`. When the
-            data type is bool, the unwanted positions have `False` values and
-            the others have `True` values. When the data type is int, the
-            unwanted positions have 0 values and the others have 1 values.
-            When the data type is float, the unwanted positions have `-INF`
-            values and the others have 0 values. It can be None when nothing
-            wanted or needed to be prevented attention to. Default None.
+            positions. It is a tensor with shape `[batch_size, 1, 1, sequence_length]`.
+            When the data type is bool, the unwanted positions have `False`
+            values and the others have `True` values. When the data type is
+            int, the unwanted positions have 0 values and the others have 1
+            values. When the data type is float, the unwanted positions have
+            `-INF` values and the others have 0 values. It can be None when
+            nothing wanted or needed to be prevented attention to. Defaults
+            to None.
 
     Returns:
         output (Tensor|tuple):
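
For reference, the `src_mask` layout described in the docstring above could be produced roughly as in the following sketch (not part of this diff). It builds a float mask of shape `[batch_size, 1, 1, sequence_length]` from padded token ids; `pad_token_id`, the example ids, and `-1e9` as a stand-in for `-INF` are illustrative assumptions.

import paddle

# Illustrative inputs: a batch of one sequence padded to length 5.
pad_token_id = 0  # assumed padding id
token_ids = paddle.to_tensor([[5, 8, 3, pad_token_id, pad_token_id]])

# 1.0 for positions to attend to, 0.0 for padding positions.
keep = (token_ids != pad_token_id).astype("float32")

# Reshape to [batch_size, 1, 1, sequence_length] and map padding positions
# to a large negative value (approximating -INF) and kept positions to 0,
# matching the float-mask convention documented above.
src_mask = paddle.unsqueeze(keep, axis=[1, 2])
src_mask = (src_mask - 1.0) * 1e9
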
@@ -252,35 +251,34 @@ def enable_faster_encoder(self):
             model = disable_faster_encoder(model)
     """
 
-    def check_if_usable(layer):
-        for sub_layer in layer.children():
-            if isinstance(sub_layer,
-                          TransformerEncoderLayer) and sub_layer._config[
-                              'bias_attr'] == False:
+    def init_func(layer):
+        if isinstance(layer, TransformerEncoderLayer):
+            is_usable = True
+            if layer._config['bias_attr'] == False:
                 logger.warning("`False` for paddle.nn.TransformerEncoder's" \
                                " parameter `bias_attr` is not supported in " \
-                               "FasterTransformer by now. Original Paddle API " \
-                               "would be called.")
-                return False
-            elif not check_if_usable(sub_layer):
-                return False
-        return True
-
-    def init_func(layer):
-        if isinstance(layer, (TransformerEncoderLayer, TransformerEncoder)):
+                               "FasterTransformer by now. The original forward" \
+                               " will be involved.")
+                is_usable = False
+            if layer._config['activation'] not in ('relu', 'gelu'):
+                logger.warning("Only 'relu' or 'gelu' is supported by now. " \
+                               "The original forward will be involved.")
+                is_usable = False
+            if is_usable:
+                layer.forward = layer._ft_forward
+        elif isinstance(layer, TransformerEncoder):
             layer.forward = layer._ft_forward
 
     if not self.training:
-        if not check_if_usable(self):
-            return self
         try:
             load("FasterTransformer", verbose=True)
-            for layer in self.children():
-                layer.apply(init_func)
         except Exception:
             logger.warning(
                 "Exception occurs when using Faster Transformer. " \
                 "The original forward will be involved. ")
+            return self
+        for layer in self.children():
+            layer.apply(init_func)
     return self
 
 
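
Taken together with the docstring example above (`model = disable_faster_encoder(model)`), the intended round trip looks roughly like the sketch below. The toy model, its hyperparameters, and the `paddlenlp.ops` import path are assumptions for illustration; the helpers are meant to be applied to a network containing `paddle.nn.TransformerEncoder` / `paddle.nn.TransformerEncoderLayer` sublayers.

import paddle
from paddlenlp.ops import enable_faster_encoder, disable_faster_encoder

class ToyModel(paddle.nn.Layer):
    # Minimal wrapper holding a TransformerEncoder, used only for illustration.
    def __init__(self):
        super().__init__()
        encoder_layer = paddle.nn.TransformerEncoderLayer(
            d_model=64, nhead=4, dim_feedforward=256, activation="gelu")
        self.encoder = paddle.nn.TransformerEncoder(encoder_layer, num_layers=2)

    def forward(self, src, src_mask=None):
        return self.encoder(src, src_mask)

model = ToyModel()
model.eval()                           # patching only happens when not training
model = enable_faster_encoder(model)   # swaps usable layers to the FT forward
enc_out = model(paddle.randn([1, 8, 64]))
model = disable_faster_encoder(model)  # restores the original forward functions
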