|
2 | 2 | import importlib |
3 | 3 | import inspect |
4 | 4 | import re |
5 | | -from typing import Any, Dict, Optional, Tuple |
| 5 | +from typing import Any, Callable, Dict, Optional, Tuple |
6 | 6 | import torch |
7 | 7 | import transformers |
8 | 8 | from ...cache_helpers import make_dynamic_cache |
@@ -46,6 +46,104 @@ def _update_config(config: Any, kwargs: Dict[str, Any]): |
46 | 46 | setattr(config, k, v) |
47 | 47 |
|
48 | 48 |
|
def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
    """Shrinks a model configuration in place so the model is cheap to build.

    :param config: a transformers configuration object, mutated in place
    :param task: task name, one of ``"text-generation"`` or
        ``"image-classification"``
    :return: the dictionary of attributes written onto ``config``
    :raises NotImplementedError: for any other task
    """
    if task == "text-generation":
        # Every value is derived from the *original* config before any
        # attribute is overwritten (head_dim depends on the pre-reduction
        # hidden_size), so build the full dict first, mutate afterwards.
        default_head_dim = config.hidden_size // config.num_attention_heads
        kv_heads = (
            config.num_key_value_heads
            if hasattr(config, "num_key_value_heads")
            else config.num_attention_heads
        )
        inter = config.intermediate_size
        if inter % 4 == 0:
            inter = min(inter, 24576 // 4)
        hidden = config.hidden_size
        if hidden % 4 == 0:
            hidden = min(hidden, 3072 // 4)
        changes = {
            "head_dim": getattr(config, "head_dim", default_head_dim),
            "num_hidden_layers": min(config.num_hidden_layers, 2),
            "num_key_value_heads": kv_heads,
            "intermediate_size": inter,
            "hidden_size": hidden,
        }
    elif task == "image-classification":
        size = config.image_size
        if isinstance(size, int):
            width = height = size
        else:
            # NOTE(review): index 0 is treated as width and index 1 as height;
            # many HF configs use (height, width) — confirm against callers.
            width, height = size[0], size[1]
        changes = {
            "batch_size": 2,
            "input_width": width,
            "input_height": height,
            "input_channels": config.num_channels,
        }
    else:
        raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.")

    for name, value in changes.items():
        setattr(config, name, value)
    return changes
| 94 | + |
| 95 | + |
def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
    """Builds the keyword arguments used to generate dummy inputs for a task.

    :param config: a transformers configuration object (read only)
    :param task: task name, one of ``"text-generation"`` or
        ``"image-classification"``
    :return: a pair ``(kwargs, fct)`` where ``fct(model, config, **kwargs)``
        produces the random inputs
    :raises NotImplementedError: for any other task
    """
    if task == "text-generation":
        kv_heads = (
            config.num_key_value_heads
            if hasattr(config, "num_key_value_heads")
            else config.num_attention_heads
        )
        inter = config.intermediate_size
        if inter % 4 == 0:
            inter = min(inter, 24576 // 4)
        hidden = config.hidden_size
        if hidden % 4 == 0:
            hidden = min(hidden, 3072 // 4)
        kwargs = {
            "batch_size": 2,
            "sequence_length": 30,
            "sequence_length2": 3,
            "head_dim": getattr(
                config, "head_dim", config.hidden_size // config.num_attention_heads
            ),
            "dummy_max_token_id": config.vocab_size - 1,
            "num_hidden_layers": min(config.num_hidden_layers, 2),
            "num_key_value_heads": kv_heads,
            "intermediate_size": inter,
            "hidden_size": hidden,
        }
        fct: Callable = get_inputs_for_text_generation
    elif task == "image-classification":
        size = config.image_size
        if isinstance(size, int):
            width = height = size
        else:
            # NOTE(review): index 0 is treated as width and index 1 as height;
            # many HF configs use (height, width) — confirm against callers.
            width, height = size[0], size[1]
        kwargs = {
            "batch_size": 2,
            "input_width": width,
            "input_height": height,
            "input_channels": config.num_channels,
        }
        fct = get_inputs_for_image_classification
    else:
        raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.")

    return kwargs, fct
| 145 | + |
| 146 | + |
49 | 147 | def get_untrained_model_with_inputs( |
50 | 148 | model_id: str, |
51 | 149 | config: Optional[Any] = None, |
@@ -114,63 +212,26 @@ def get_untrained_model_with_inputs( |
114 | 212 | config.rope_scaling = ( |
115 | 213 | {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None |
116 | 214 | ) |
| 215 | + |
| 216 | + # updating the configuration |
| 217 | + if not same_as_pretrained: |
| 218 | + mkwargs = reduce_model_config(config, task) |
| 219 | + else: |
| 220 | + mkwargs = {} |
117 | 221 | if model_kwargs: |
118 | 222 | for k, v in model_kwargs.items(): |
119 | 223 | setattr(config, k, v) |
120 | | - |
121 | | - if task == "text-generation": |
122 | | - kwargs = dict( |
123 | | - batch_size=2, |
124 | | - sequence_length=30, |
125 | | - sequence_length2=3, |
126 | | - head_dim=getattr( |
127 | | - config, "head_dim", config.hidden_size // config.num_attention_heads |
128 | | - ), |
129 | | - dummy_max_token_id=config.vocab_size - 1, |
130 | | - num_hidden_layers=min(config.num_hidden_layers, 2), |
131 | | - num_key_value_heads=( |
132 | | - config.num_key_value_heads |
133 | | - if hasattr(config, "num_key_value_heads") |
134 | | - else config.num_attention_heads |
135 | | - ), |
136 | | - intermediate_size=( |
137 | | - min(config.intermediate_size, 24576 // 4) |
138 | | - if config.intermediate_size % 4 == 0 |
139 | | - else config.intermediate_size |
140 | | - ), |
141 | | - hidden_size=( |
142 | | - min(config.hidden_size, 3072 // 4) |
143 | | - if config.hidden_size % 4 == 0 |
144 | | - else config.hidden_size |
145 | | - ), |
146 | | - ) |
147 | | - |
148 | | - fct = get_inputs_for_text_generation |
149 | | - elif task == "image-classification": |
150 | | - if isinstance(config.image_size, int): |
151 | | - kwargs = dict( |
152 | | - batch_size=2, |
153 | | - input_width=config.image_size, |
154 | | - input_height=config.image_size, |
155 | | - input_channels=config.num_channels, |
156 | | - ) |
157 | | - else: |
158 | | - kwargs = dict( |
159 | | - batch_size=2, |
160 | | - input_width=config.image_size[0], |
161 | | - input_height=config.image_size[1], |
162 | | - input_channels=config.num_channels, |
163 | | - ) |
164 | | - fct = get_inputs_for_image_classification |
165 | | - else: |
166 | | - raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.") |
167 | | - |
| 224 | + mkwargs[k] = v |
| 225 | + # input kwargs |
| 226 | + kwargs, fct = random_input_kwargs(config, task) |
168 | 227 | if inputs_kwargs: |
169 | 228 | kwargs.update(inputs_kwargs) |
170 | | - true_kwargs = (inputs_kwargs or {}) if same_as_pretrained else kwargs |
171 | | - _update_config(config, true_kwargs) |
| 229 | + |
172 | 230 | model = getattr(transformers, arch)(config) |
173 | | - return fct(model, config, **true_kwargs) |
| 231 | + res = fct(model, config, **kwargs) |
| 232 | + res["input_kwargs"] = kwargs |
| 233 | + res["model_kwargs"] = mkwargs |
| 234 | + return res |
174 | 235 |
|
175 | 236 |
|
176 | 237 | def compute_model_size(model: torch.nn.Module) -> Tuple[int, int]: |
|
0 commit comments