Skip to content

Commit 65ea2dd

Browse files
authored
feat(config): support parsing torch.dtype (#1641)
Signed-off-by: Aaron <[email protected]>
1 parent b514d3c commit 65ea2dd

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

vllm/config.py

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Union
22

33
import torch
44
from transformers import PretrainedConfig
@@ -58,7 +58,7 @@ def __init__(
5858
trust_remote_code: bool,
5959
download_dir: Optional[str],
6060
load_format: str,
61-
dtype: str,
61+
dtype: Union[str, torch.dtype],
6262
seed: int,
6363
revision: Optional[str] = None,
6464
tokenizer_revision: Optional[str] = None,
@@ -331,25 +331,31 @@ def _verify_args(self) -> None:
331331

332332
def _get_and_verify_dtype(
333333
config: PretrainedConfig,
334-
dtype: str,
334+
dtype: Union[str, torch.dtype],
335335
) -> torch.dtype:
336336
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
337337
# because config.torch_dtype can be None.
338338
config_dtype = getattr(config, "torch_dtype", None)
339339
if config_dtype is None:
340340
config_dtype = torch.float32
341341

342-
dtype = dtype.lower()
343-
if dtype == "auto":
344-
if config_dtype == torch.float32:
345-
# Following the common practice, we use float16 for float32 models.
346-
torch_dtype = torch.float16
342+
if isinstance(dtype, str):
343+
dtype = dtype.lower()
344+
if dtype == "auto":
345+
if config_dtype == torch.float32:
346+
# Following the common practice, we use float16 for float32
347+
# models.
348+
torch_dtype = torch.float16
349+
else:
350+
torch_dtype = config_dtype
347351
else:
348-
torch_dtype = config_dtype
352+
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
353+
raise ValueError(f"Unknown dtype: {dtype}")
354+
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
355+
elif isinstance(dtype, torch.dtype):
356+
torch_dtype = dtype
349357
else:
350-
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
351-
raise ValueError(f"Unknown dtype: {dtype}")
352-
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
358+
raise ValueError(f"Unknown dtype: {dtype}")
353359

354360
# Verify the dtype.
355361
if torch_dtype != config_dtype:

0 commit comments

Comments
 (0)