|
15 | 15 | import torch._dynamo.config |
16 | 16 | import torch._inductor.config |
17 | 17 | import torch.nn as nn |
| 18 | + |
| 19 | +try: |
| 20 | + from _torchchat_test_script import flamingo_meta_to_tune |
| 21 | +except ImportError: |
| 22 | + pass |
| 23 | + |
18 | 24 | from distributed import ( |
19 | 25 | init_distributed, |
20 | 26 | launch_distributed, |
21 | 27 | ParallelDims, |
22 | 28 | parallelize_llama, |
23 | 29 | ) |
| 30 | + |
24 | 31 | from torch.distributed.device_mesh import DeviceMesh |
25 | 32 |
|
26 | | -from torchchat.model import Model |
| 33 | +from torchtune.models.convert_weights import meta_to_tune |
| 34 | + |
| 35 | +from torchtune.training import set_default_dtype |
| 36 | + |
| 37 | +from torchchat.model import Model, ModelType |
27 | 38 |
|
28 | 39 | from torchchat.model_config.model_config import resolve_model_config |
29 | 40 | from torchchat.utils.build_utils import ( |
|
35 | 46 | from torchchat.utils.measure_time import measure_time |
36 | 47 | from torchchat.utils.quantize import quantize_model |
37 | 48 |
|
38 | | -from torchtune.models.convert_weights import meta_to_tune |
39 | | - |
40 | | - |
41 | | - |
42 | 49 |
|
43 | 50 | @dataclass |
44 | 51 | class BuilderArgs: |
@@ -143,7 +150,6 @@ def from_args(cls, args): # -> BuilderArgs: |
143 | 150 | if "chat" in path_basename or "instruct" in path_basename: |
144 | 151 | is_chat_model = True |
145 | 152 |
|
146 | | - |
147 | 153 | output_pte_path = getattr(args, "output_pte_path", None) |
148 | 154 | output_dso_path = getattr(args, "output_dso_path", None) |
149 | 155 | if output_pte_path and args.dtype.startswith("fast"): |
@@ -234,7 +240,12 @@ def validate_model( |
234 | 240 |
|
235 | 241 | is_tiktoken = self.is_tiktoken |
236 | 242 | is_sentencepiece = self.is_sentencepiece |
237 | | - use_tiktoken = model.config.transformer_args["text"].use_tiktoken |
| 243 | + text_args = model.config.transformer_args.get("text") |
| 244 | + if text_args is None: |
| 245 | + # TODO: Will be refactored: currently, the only model without "text" in transformer_args is Flamingo |
| 246 | + use_tiktoken = model.config.model_type == ModelType.Flamingo |
| 247 | + else: |
| 248 | + use_tiktoken = text_args.use_tiktoken |
238 | 249 |
|
239 | 250 | if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): |
240 | 251 | raise RuntimeError( |
@@ -266,7 +277,9 @@ def from_args(cls, args): # -> TokenizerArgs: |
266 | 277 | raise RuntimeError("cannot find tokenizer model") |
267 | 278 |
|
268 | 279 | if not tokenizer_path.is_file(): |
269 | | - raise RuntimeError(f"did not find tokenizer at {tokenizer_path}") |
| 280 | + raise RuntimeError( |
| 281 | + f"did not find tokenizer at {os.path.abspath(tokenizer_path)}" |
| 282 | + ) |
270 | 283 |
|
271 | 284 | return cls( |
272 | 285 | tokenizer_path=tokenizer_path, |
@@ -335,7 +348,9 @@ def _load_model_default(builder_args, only_config=False): |
335 | 348 |
|
336 | 349 | if builder_args.params_table and builder_args.params_table.endswith("Tune"): |
337 | 350 | print("Loading Tune checkpoint") |
338 | | - meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True) |
| 351 | + meta_checkpoint = torch.load( |
| 352 | + str(builder_args.checkpoint_path), mmap=True, weights_only=True |
| 353 | + ) |
339 | 354 | checkpoint = meta_to_tune(meta_checkpoint) |
340 | 355 | elif builder_args.checkpoint_dir is not None: |
341 | 356 | # Load multiple checkpoint; ignore the single path. |
@@ -372,8 +387,17 @@ def _load_model_default(builder_args, only_config=False): |
372 | 387 | if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): |
373 | 388 | checkpoint = checkpoint["model"] |
374 | 389 |
|
375 | | - checkpoint = {"model." + k: v for k, v in checkpoint.items()} |
376 | | - model.load_state_dict(checkpoint, assign=True, strict=True) |
| 390 | + if model.config.model_type == ModelType.Flamingo: |
| 391 | + # TODO: Refactor this. For now, overwrite the model with model loaded from params_path |
| 392 | + with set_default_dtype(builder_args.precision), torch.device( |
| 393 | + builder_args.device |
| 394 | + ): |
| 395 | + model = Model.from_params(builder_args.params_path) |
| 396 | + state_dict = flamingo_meta_to_tune(checkpoint) |
| 397 | + model.model.load_state_dict(state_dict) |
| 398 | + else: |
| 399 | + checkpoint = {"model." + k: v for k, v in checkpoint.items()} |
| 400 | + model.load_state_dict(checkpoint, assign=True, strict=True) |
377 | 401 |
|
378 | 402 | return model |
379 | 403 |
|
|
0 commit comments