11"""Utilities for selecting and loading models."""
2+ import contextlib
23from typing import Type
34
45import torch
3031}
3132
3233
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype to the given dtype.

    The default dtype that was active on entry is put back when the
    ``with`` block exits, whether it finishes normally or raises.

    Args:
        dtype: The dtype to install as the torch default while the
            ``with`` block runs.

    Yields:
        None.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # The default dtype is process-wide state; restore it even when the
        # body raises so a failure here does not leak the temporary dtype
        # into unrelated code.
        torch.set_default_dtype(old_dtype)
41+
42+
3343def _get_model_architecture (config : PretrainedConfig ) -> Type [nn .Module ]:
3444 architectures = getattr (config , "architectures" , [])
3545 for arch in architectures :
@@ -42,19 +52,18 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
4252
def get_model(model_config: ModelConfig) -> nn.Module:
    """Build the model described by ``model_config`` and return it in eval mode.

    The model class is looked up from ``model_config.hf_config``. The model
    is constructed and its weights are populated while the default torch
    dtype is temporarily set to ``model_config.dtype`` through
    ``_set_default_torch_dtype``.

    Args:
        model_config: Configuration providing ``hf_config``, ``dtype``,
            ``use_dummy_weights``, ``model``, ``download_dir`` and
            ``use_np_weights``.

    Returns:
        The model after ``.cuda()`` and ``.eval()`` have been applied.
    """
    model_class = _get_model_architecture(model_config.hf_config)
    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        model = model_class(model_config.hf_config)
        if model_config.use_dummy_weights:
            model = model.cuda()
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.use_np_weights)
            model = model.cuda()
    return model.eval()
0 commit comments