1
1
"""Utilities for selecting and loading models."""
2
+ import contextlib
2
3
from typing import Type
3
4
4
5
import torch
30
31
}
31
32
32
33
34
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype to ``dtype``.

    The dtype that was the default on entry is reinstated when the ``with``
    block exits, whether it exits normally or by raising.

    Args:
        dtype: The dtype to install as the default for the duration of the
            ``with`` block.

    Yields:
        None.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # An exception raised inside the ``with`` body is re-raised at the
        # ``yield``; without ``finally`` the restore below would be skipped
        # and the process-wide default dtype would stay modified.
        torch.set_default_dtype(old_dtype)
41
+
42
+
33
43
def _get_model_architecture (config : PretrainedConfig ) -> Type [nn .Module ]:
34
44
architectures = getattr (config , "architectures" , [])
35
45
for arch in architectures :
@@ -42,19 +52,18 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
42
52
43
53
def get_model(model_config: ModelConfig) -> nn.Module:
    """Build the model described by ``model_config`` and return it in eval mode.

    The model class is resolved from ``model_config.hf_config`` and
    instantiated while ``model_config.dtype`` is the default torch dtype.
    Weights are either filled in by ``initialize_dummy_weights`` or loaded via
    the model's ``load_weights`` method, and the model is moved with
    ``.cuda()``.

    Args:
        model_config: Supplies ``hf_config``, ``dtype``, ``use_dummy_weights``,
            ``model``, ``download_dir`` and ``use_np_weights``.

    Returns:
        The constructed model after ``.eval()`` has been called on it.
    """
    hf_config = model_config.hf_config
    model_class = _get_model_architecture(hf_config)

    with _set_default_torch_dtype(model_config.dtype):
        # Instantiate the model; its weights start out as empty tensors.
        model = model_class(hf_config)

        if not model_config.use_dummy_weights:
            # Real weights: read them from the cached or downloaded files
            # first, then move the populated model over with .cuda().
            model.load_weights(
                model_config.model,
                model_config.download_dir,
                model_config.use_np_weights,
            )
            model = model.cuda()
        else:
            model = model.cuda()
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)

    return model.eval()
0 commit comments