@@ -126,13 +126,37 @@ def do_predict(args):
                start = time.time()
            out_seq = gpt(input_ids)
        paddle.fluid.core._cuda_synchronize(place)
-        logger.info("Average test time for decoding is %f ms" % (
+        logger.info("Average test time for fast decoding is %f ms" % (
            (time.time() - start) / 50 * 1000))
        output_sequence = out_seq.numpy().transpose()
        for i in range(args.batch_size):
            print("========== Sample-%d ==========" % i)
            print(tokenizer.convert_ids_to_string(output_sequence[i][1:]))

+    input_ids = paddle.cast(input_ids, "int64")
+    with paddle.no_grad():
+        for i in range(100):
+            # For warmup.
+            if 50 == i:
+                paddle.fluid.core._cuda_synchronize(place)
+                start = time.time()
+            out_seq, _ = model.generate(
+                input_ids=input_ids,
+                max_length=args.max_out_len,
+                decode_strategy="sampling",
+                temperature=args.temperature,
+                top_k=args.topk,
+                top_p=1.0,
+                num_return_sequences=1)
+        paddle.fluid.core._cuda_synchronize(place)
+        logger.info(
+            "Average test time for origin generate api decoding is %f ms" % (
+                (time.time() - start) / 50 * 1000))
+        output_sequence = out_seq.numpy()
+        for i in range(args.batch_size):
+            print("========== Sample-%d ==========" % i)
+            print(tokenizer.convert_ids_to_string(output_sequence[i][1:]))
+

if __name__ == "__main__":
    args = parse_args()
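For reference, the timing pattern used on both sides of this diff (warmup iterations, CUDA synchronization before starting and stopping the clock, averaging over the timed runs) can be factored into a small helper. This is a minimal sketch, assuming a hypothetical zero-argument `run_decode` callable (e.g. `lambda: gpt(input_ids)`, or a wrapper around the `model.generate(...)` call above); the helper itself is not part of this PR:

```python
import time

import paddle


def benchmark_decode(run_decode, place, iters=100, warmup=50):
    """Hypothetical helper mirroring the benchmark loops in this diff."""
    with paddle.no_grad():
        for i in range(iters):
            if i == warmup:
                # Discard the first `warmup` iterations, then start the
                # clock only after all queued GPU work has drained.
                paddle.fluid.core._cuda_synchronize(place)
                start = time.time()
            out = run_decode()
        # Synchronize again so the timer measures actual GPU execution,
        # not just asynchronous kernel launches.
        paddle.fluid.core._cuda_synchronize(place)
        avg_ms = (time.time() - start) / (iters - warmup) * 1000
    return out, avg_ms
```

With a helper like this, the two measurements reduce to `benchmark_decode(lambda: gpt(input_ids), place)` versus a `model.generate(...)` wrapper, which keeps the warmup count and synchronization behavior identical across the fast-decoding and original-API paths being compared.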