Skip to content

Commit 005ba45

Browse files
authored
Set torch default dtype in a context manager (#971)
Signed-off-by: Antoni Baum <[email protected]>
1 parent 320a622 commit 005ba45

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

vllm/model_executor/model_loader.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Utilities for selecting and loading models."""
2+
import contextlib
23
from typing import Type
34

45
import torch
@@ -30,6 +31,15 @@
3031
}
3132

3233

34+
@contextlib.contextmanager
35+
def _set_default_torch_dtype(dtype: torch.dtype):
36+
"""Sets the default torch dtype to the given dtype."""
37+
old_dtype = torch.get_default_dtype()
38+
torch.set_default_dtype(dtype)
39+
yield
40+
torch.set_default_dtype(old_dtype)
41+
42+
3343
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
3444
architectures = getattr(config, "architectures", [])
3545
for arch in architectures:
@@ -42,19 +52,18 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
4252

4353
def get_model(model_config: ModelConfig) -> nn.Module:
4454
model_class = _get_model_architecture(model_config.hf_config)
45-
torch.set_default_dtype(model_config.dtype)
46-
47-
# Create a model instance.
48-
# The weights will be initialized as empty tensors.
49-
model = model_class(model_config.hf_config)
50-
if model_config.use_dummy_weights:
51-
model = model.cuda()
52-
# NOTE(woosuk): For accurate performance evaluation, we assign
53-
# random values to the weights.
54-
initialize_dummy_weights(model)
55-
else:
56-
# Load the weights from the cached or downloaded files.
57-
model.load_weights(model_config.model, model_config.download_dir,
58-
model_config.use_np_weights)
59-
model = model.cuda()
55+
with _set_default_torch_dtype(model_config.dtype):
56+
# Create a model instance.
57+
# The weights will be initialized as empty tensors.
58+
model = model_class(model_config.hf_config)
59+
if model_config.use_dummy_weights:
60+
model = model.cuda()
61+
# NOTE(woosuk): For accurate performance evaluation, we assign
62+
# random values to the weights.
63+
initialize_dummy_weights(model)
64+
else:
65+
# Load the weights from the cached or downloaded files.
66+
model.load_weights(model_config.model, model_config.download_dir,
67+
model_config.use_np_weights)
68+
model = model.cuda()
6069
return model.eval()

0 commit comments

Comments
 (0)