Skip to content

Commit 45967dc

Browse files
authored
Fix the export and inference models of GPT. (PaddlePaddle#1019)
1 parent 71aff69 commit 45967dc

File tree

2 files changed

+1
-5
lines changed

2 files changed

+1
-5
lines changed

examples/language_model/gpt/deploy/python/inference.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,9 @@ def main():
113113
"问题:中国的首都是哪里?答案:",
114114
"问题:世界上最高的山峰是? 答案:",
115115
]
116-
end_id = tokenizer.eol_token_id
117116

118117
dataset = [[
119-
np.array(tokenizer(text)["input_ids"]).astype("int64").reshape([1, -1]),
120-
np.array(end_id).astype("int32").reshape([1])
118+
np.array(tokenizer(text)["input_ids"]).astype("int64").reshape([1, -1])
121119
] for text in ds]
122120
outs = predictor.predict(dataset)
123121
for res in outs:

examples/language_model/gpt/export_model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ def main():
7171
input_spec=[
7272
paddle.static.InputSpec(
7373
shape=[None, None], dtype="int64"), # input_ids
74-
paddle.static.InputSpec(
75-
shape=[1], dtype="int32"), # end_id
7674
])
7775

7876
# Save converted static graph model

0 commit comments

Comments (0)