|
15 | 15 | import torch._dynamo.config |
16 | 16 | import torch._inductor.config |
17 | 17 | import torch.nn as nn |
| 18 | +try: |
| 19 | + from _torchchat_test_script import flamingo_meta_to_tune |
| 20 | +except ImportError: |
| 21 | + pass |
| 22 | + |
18 | 23 | from distributed import ( |
19 | 24 | init_distributed, |
20 | 25 | launch_distributed, |
21 | 26 | ParallelDims, |
22 | 27 | parallelize_llama, |
23 | 28 | ) |
| 29 | + |
24 | 30 | from torch.distributed.device_mesh import DeviceMesh |
25 | 31 |
|
26 | | -from torchchat.model import Model |
| 32 | +from torchtune.models.convert_weights import meta_to_tune |
| 33 | + |
| 34 | +from torchtune.training import set_default_dtype |
| 35 | + |
| 36 | +from torchchat.model import Model, ModelType |
27 | 37 |
|
28 | 38 | from torchchat.model_config.model_config import resolve_model_config |
29 | 39 | from torchchat.utils.build_utils import ( |
|
35 | 45 | from torchchat.utils.measure_time import measure_time |
36 | 46 | from torchchat.utils.quantize import quantize_model |
37 | 47 |
|
38 | | -from torchtune.models.convert_weights import meta_to_tune |
39 | | - |
40 | | - |
41 | | - |
42 | 48 |
|
43 | 49 | @dataclass |
44 | 50 | class BuilderArgs: |
@@ -143,7 +149,6 @@ def from_args(cls, args): # -> BuilderArgs: |
143 | 149 | if "chat" in path_basename or "instruct" in path_basename: |
144 | 150 | is_chat_model = True |
145 | 151 |
|
146 | | - |
147 | 152 | output_pte_path = getattr(args, "output_pte_path", None) |
148 | 153 | output_dso_path = getattr(args, "output_dso_path", None) |
149 | 154 | if output_pte_path and args.dtype.startswith("fast"): |
@@ -234,7 +239,12 @@ def validate_model( |
234 | 239 |
|
235 | 240 | is_tiktoken = self.is_tiktoken |
236 | 241 | is_sentencepiece = self.is_sentencepiece |
237 | | - use_tiktoken = model.config.transformer_args["text"].use_tiktoken |
| 242 | + text_args = model.config.transformer_args.get("text") |
| 243 | + if text_args is None: |
| 244 | + # TODO: Will be refactored: Currently, the only model that doesn't have text in transformer_args is Flamingo |
| 245 | + use_tiktoken = model.config.model_type == ModelType.Flamingo |
| 246 | + else: |
| 247 | + use_tiktoken = text_args.use_tiktoken |
238 | 248 |
|
239 | 249 | if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): |
240 | 250 | raise RuntimeError( |
@@ -266,7 +276,9 @@ def from_args(cls, args): # -> TokenizerArgs: |
266 | 276 | raise RuntimeError("cannot find tokenizer model") |
267 | 277 |
|
268 | 278 | if not tokenizer_path.is_file(): |
269 | | - raise RuntimeError(f"did not find tokenizer at {tokenizer_path}") |
| 279 | + raise RuntimeError( |
| 280 | + f"did not find tokenizer at {os.path.abspath(tokenizer_path)}" |
| 281 | + ) |
270 | 282 |
|
271 | 283 | return cls( |
272 | 284 | tokenizer_path=tokenizer_path, |
@@ -335,7 +347,9 @@ def _load_model_default(builder_args, only_config=False): |
335 | 347 |
|
336 | 348 | if builder_args.params_table and builder_args.params_table.endswith("Tune"): |
337 | 349 | print("Loading Tune checkpoint") |
338 | | - meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True) |
| 350 | + meta_checkpoint = torch.load( |
| 351 | + str(builder_args.checkpoint_path), mmap=True, weights_only=True |
| 352 | + ) |
339 | 353 | checkpoint = meta_to_tune(meta_checkpoint) |
340 | 354 | elif builder_args.checkpoint_dir is not None: |
341 | 355 | # Load multiple checkpoint; ignore the single path. |
@@ -372,8 +386,17 @@ def _load_model_default(builder_args, only_config=False): |
372 | 386 | if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): |
373 | 387 | checkpoint = checkpoint["model"] |
374 | 388 |
|
375 | | - checkpoint = {"model." + k: v for k, v in checkpoint.items()} |
376 | | - model.load_state_dict(checkpoint, assign=True, strict=True) |
| 389 | + if model.config.model_type == ModelType.Flamingo: |
| 390 | + # TODO: Refactor this. For now, overwrite the model with model loaded from params_path |
| 391 | + with set_default_dtype(builder_args.precision), torch.device( |
| 392 | + builder_args.device |
| 393 | + ): |
| 394 | + model = Model.from_params(builder_args.params_path) |
| 395 | + state_dict = flamingo_meta_to_tune(checkpoint) |
| 396 | + model.model.load_state_dict(state_dict) |
| 397 | + else: |
| 398 | + checkpoint = {"model." + k: v for k, v in checkpoint.items()} |
| 399 | + model.load_state_dict(checkpoint, assign=True, strict=True) |
377 | 400 |
|
378 | 401 | return model |
379 | 402 |
|
|
0 commit comments