Skip to content

Commit ec5ae94

Browse files
committed
Automatically cast multi-modal input dtype
Signed-off-by: shen-shanshan <[email protected]>
1 parent b431db1 commit ec5ae94

File tree

4 files changed

+19
-7
lines changed

4 files changed

+19
-7
lines changed

vllm_ascend/worker/draft_model_runner.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -274,8 +274,11 @@ def execute_model(
274274
input_ids=model_input.input_tokens,
275275
positions=model_input.input_positions,
276276
intermediate_tensors=intermediate_tensors,
277-
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
278-
device=self.device),
277+
**MultiModalKwargs.as_kwargs(
278+
multi_modal_kwargs,
279+
dtype=self.model_runner.model_config.dtype,
280+
device=self.device,
281+
),
279282
**model_execute_kwargs,
280283
)
281284

vllm_ascend/worker/model_runner.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1384,8 +1384,11 @@ def execute_model(
13841384
input_ids=model_input.input_tokens,
13851385
positions=model_input.input_positions,
13861386
intermediate_tensors=intermediate_tensors,
1387-
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
1388-
device=self.device),
1387+
**MultiModalKwargs.as_kwargs(
1388+
multi_modal_kwargs,
1389+
dtype=self.model_runner.model_config.dtype,
1390+
device=self.device,
1391+
),
13891392
**seqlen_agnostic_kwargs,
13901393
**model_kwargs)
13911394

vllm_ascend/worker/model_runner_v1.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1010,7 +1010,10 @@ def _profile_multimodal(self) -> None:
10101010
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
10111011
max_num_mm_items)
10121012
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
1013-
batched_dummy_mm_inputs, device=self.device)
1013+
batched_dummy_mm_inputs,
1014+
dtype=self.model_config.dtype,
1015+
device=self.device,
1016+
)
10141017

10151018
# Run multimodal encoder.
10161019
dummy_encoder_outputs = self.model.get_multimodal_embeddings(

vllm_ascend/worker/pooling_model_runner.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -148,8 +148,11 @@ def execute_model(
148148
input_ids=model_input.input_tokens,
149149
positions=model_input.input_positions,
150150
intermediate_tensors=intermediate_tensors,
151-
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
152-
device=self.device),
151+
**MultiModalKwargs.as_kwargs(
152+
multi_modal_kwargs,
153+
dtype=self.model_runner.model_config.dtype,
154+
device=self.device,
155+
),
153156
**cross_enc_kwargs,
154157
**seqlen_agnostic_kwargs)
155158

0 commit comments

Comments
 (0)