Skip to content

Commit 2cf80b9

Browse files
smallv0221 and FrostML authored
Add origin generate api speed comparison (PaddlePaddle#1047)
* add origin generate api speed comparison * add speed example * minor fix Co-authored-by: liu zhengxi <[email protected]>
1 parent 70debdb commit 2cf80b9

File tree

1 file changed

+25
-1
lines changed
  • examples/language_model/gpt/faster_gpt

1 file changed

+25
-1
lines changed

examples/language_model/gpt/faster_gpt/infer.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,37 @@ def do_predict(args):
126126
start = time.time()
127127
out_seq = gpt(input_ids)
128128
paddle.fluid.core._cuda_synchronize(place)
129-
logger.info("Average test time for decoding is %f ms" % (
129+
logger.info("Average test time for fast decoding is %f ms" % (
130130
(time.time() - start) / 50 * 1000))
131131
output_sequence = out_seq.numpy().transpose()
132132
for i in range(args.batch_size):
133133
print("========== Sample-%d ==========" % i)
134134
print(tokenizer.convert_ids_to_string(output_sequence[i][1:]))
135135

136+
input_ids = paddle.cast(input_ids, "int64")
137+
with paddle.no_grad():
138+
for i in range(100):
139+
# For warmup.
140+
if 50 == i:
141+
paddle.fluid.core._cuda_synchronize(place)
142+
start = time.time()
143+
out_seq, _ = model.generate(
144+
input_ids=input_ids,
145+
max_length=args.max_out_len,
146+
decode_strategy="sampling",
147+
temperature=args.temperature,
148+
top_k=args.topk,
149+
top_p=1.0,
150+
num_return_sequences=1)
151+
paddle.fluid.core._cuda_synchronize(place)
152+
logger.info(
153+
"Average test time for origin generate api decoding is %f ms" % (
154+
(time.time() - start) / 50 * 1000))
155+
output_sequence = out_seq.numpy()
156+
for i in range(args.batch_size):
157+
print("========== Sample-%d ==========" % i)
158+
print(tokenizer.convert_ids_to_string(output_sequence[i][1:]))
159+
136160

137161
if __name__ == "__main__":
138162
args = parse_args()

0 commit comments

Comments (0)