Skip to content

Commit 955c10b

Browse files
authored
set the device id for the taskflow (PaddlePaddle#1011)
1 parent b9a4cb1 commit 955c10b

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

paddlenlp/taskflow/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _prepare_static_mode(self):
9191
if place == 'cpu':
9292
self._config.disable_gpu()
9393
else:
94-
self._config.enable_use_gpu(100, 0)
94+
self._config.enable_use_gpu(100, self.kwargs['device_id'])
9595
self._config.switch_use_feed_fetch_ops(False)
9696
self._config.disable_glog_info()
9797
self.predictor = paddle.inference.create_predictor(self._config)

paddlenlp/taskflow/taskflow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def __init__(self, task, model=None, device_id=0, **kwargs):
135135
self.model = model
136136
# Update the task config to kwargs
137137
config_kwargs = TASKS[self.task]['models'][self.model]
138+
kwargs['device_id'] = device_id
138139
kwargs.update(config_kwargs)
139140
self.kwargs = kwargs
140141
task_class = TASKS[self.task]['models'][self.model]['task_class']

0 commit comments

Comments
 (0)