|
1 |
| -from typing import Optional |
| 1 | +from typing import Optional, Union |
2 | 2 |
|
3 | 3 | import torch
|
4 | 4 | from transformers import PretrainedConfig
|
@@ -58,7 +58,7 @@ def __init__(
|
58 | 58 | trust_remote_code: bool,
|
59 | 59 | download_dir: Optional[str],
|
60 | 60 | load_format: str,
|
61 |
| - dtype: str, |
| 61 | + dtype: Union[str, torch.dtype], |
62 | 62 | seed: int,
|
63 | 63 | revision: Optional[str] = None,
|
64 | 64 | tokenizer_revision: Optional[str] = None,
|
@@ -331,25 +331,31 @@ def _verify_args(self) -> None:
|
331 | 331 |
|
332 | 332 | def _get_and_verify_dtype(
|
333 | 333 | config: PretrainedConfig,
|
334 |
| - dtype: str, |
| 334 | + dtype: Union[str, torch.dtype], |
335 | 335 | ) -> torch.dtype:
|
336 | 336 | # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
337 | 337 | # because config.torch_dtype can be None.
|
338 | 338 | config_dtype = getattr(config, "torch_dtype", None)
|
339 | 339 | if config_dtype is None:
|
340 | 340 | config_dtype = torch.float32
|
341 | 341 |
|
342 |
| - dtype = dtype.lower() |
343 |
| - if dtype == "auto": |
344 |
| - if config_dtype == torch.float32: |
345 |
| - # Following the common practice, we use float16 for float32 models. |
346 |
| - torch_dtype = torch.float16 |
| 342 | + if isinstance(dtype, str): |
| 343 | + dtype = dtype.lower() |
| 344 | + if dtype == "auto": |
| 345 | + if config_dtype == torch.float32: |
| 346 | + # Following the common practice, we use float16 for float32 |
| 347 | + # models. |
| 348 | + torch_dtype = torch.float16 |
| 349 | + else: |
| 350 | + torch_dtype = config_dtype |
347 | 351 | else:
|
348 |
| - torch_dtype = config_dtype |
| 352 | + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: |
| 353 | + raise ValueError(f"Unknown dtype: {dtype}") |
| 354 | + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] |
| 355 | + elif isinstance(dtype, torch.dtype): |
| 356 | + torch_dtype = dtype |
349 | 357 | else:
|
350 |
| - if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: |
351 |
| - raise ValueError(f"Unknown dtype: {dtype}") |
352 |
| - torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] |
| 358 | + raise ValueError(f"Unknown dtype: {dtype}") |
353 | 359 |
|
354 | 360 | # Verify the dtype.
|
355 | 361 | if torch_dtype != config_dtype:
|
|
0 commit comments