|  | 
| 15 | 15 | import torch._dynamo.config | 
| 16 | 16 | import torch._inductor.config | 
| 17 | 17 | import torch.nn as nn | 
|  | 18 | + | 
|  | 19 | +try: | 
|  | 20 | +    from _torchchat_test_script import flamingo_meta_to_tune | 
|  | 21 | +except ImportError: | 
|  | 22 | +    pass | 
|  | 23 | + | 
| 18 | 24 | from distributed import ( | 
| 19 | 25 |     init_distributed, | 
| 20 | 26 |     launch_distributed, | 
| 21 | 27 |     ParallelDims, | 
| 22 | 28 |     parallelize_llama, | 
| 23 | 29 | ) | 
|  | 30 | + | 
| 24 | 31 | from torch.distributed.device_mesh import DeviceMesh | 
| 25 | 32 | 
 | 
| 26 |  | -from torchchat.model import Model | 
|  | 33 | +from torchtune.models.convert_weights import meta_to_tune | 
|  | 34 | + | 
|  | 35 | +from torchtune.training import set_default_dtype | 
|  | 36 | + | 
|  | 37 | +from torchchat.model import Model, ModelType | 
| 27 | 38 | 
 | 
| 28 | 39 | from torchchat.model_config.model_config import resolve_model_config | 
| 29 | 40 | from torchchat.utils.build_utils import ( | 
|  | 
| 35 | 46 | from torchchat.utils.measure_time import measure_time | 
| 36 | 47 | from torchchat.utils.quantize import quantize_model | 
| 37 | 48 | 
 | 
| 38 |  | -from torchtune.models.convert_weights import meta_to_tune | 
| 39 |  | - | 
| 40 |  | - | 
| 41 |  | - | 
| 42 | 49 | 
 | 
| 43 | 50 | @dataclass | 
| 44 | 51 | class BuilderArgs: | 
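One hazard of the new guarded import at the top of this hunk: on `ImportError` the `except` body is `pass`, so `flamingo_meta_to_tune` is simply left undefined, and the Flamingo branch added later in this diff would fail with a `NameError` rather than a clear message. A minimal hardening sketch (assuming no other definition of the name exists):

```python
try:
    from _torchchat_test_script import flamingo_meta_to_tune
except ImportError:
    # Bind the name so callers can raise a descriptive error instead of
    # hitting a NameError when the test script is unavailable.
    flamingo_meta_to_tune = None
```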
| @@ -143,7 +150,6 @@ def from_args(cls, args):  # -> BuilderArgs: | 
| 143 | 150 |                     if "chat" in path_basename or "instruct" in path_basename: | 
| 144 | 151 |                         is_chat_model = True | 
| 145 | 152 | 
 | 
| 146 |  | - | 
| 147 | 153 |         output_pte_path = getattr(args, "output_pte_path", None) | 
| 148 | 154 |         output_dso_path = getattr(args, "output_dso_path", None) | 
| 149 | 155 |         if output_pte_path and args.dtype.startswith("fast"): | 
| @@ -234,7 +240,12 @@ def validate_model( | 
| 234 | 240 | 
 | 
| 235 | 241 |         is_tiktoken = self.is_tiktoken | 
| 236 | 242 |         is_sentencepiece = self.is_sentencepiece | 
| 237 |  | -        use_tiktoken = model.config.transformer_args["text"].use_tiktoken | 
|  | 243 | +        text_args = model.config.transformer_args.get("text") | 
|  | 244 | +        if text_args is None: | 
|  | 245 | +            # TODO: Will be refactored. Currently, the only model that doesn't have "text" in transformer_args is Flamingo | 
|  | 246 | +            use_tiktoken = model.config.model_type == ModelType.Flamingo | 
|  | 247 | +        else: | 
|  | 248 | +            use_tiktoken = text_args.use_tiktoken | 
| 238 | 249 | 
 | 
| 239 | 250 |         if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken): | 
| 240 | 251 |             raise RuntimeError( | 
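The `.get("text")` fallback keeps validation working for models whose `transformer_args` has no `"text"` entry; since Flamingo uses tiktoken, `use_tiktoken` is inferred from the model type. The check above is the De Morgan form of requiring both flags to be consistent; an equivalent flattened spelling (illustrative, not part of the diff):

```python
# Equivalent to: not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken)
consistent = (is_tiktoken == use_tiktoken) and (is_sentencepiece != use_tiktoken)
if not consistent:
    raise RuntimeError("no tokenizer matching the model's expected format")
```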
| @@ -266,7 +277,9 @@ def from_args(cls, args):  # -> TokenizerArgs: | 
| 266 | 277 |             raise RuntimeError("cannot find tokenizer model") | 
| 267 | 278 | 
 | 
| 268 | 279 |         if not tokenizer_path.is_file(): | 
| 269 |  | -            raise RuntimeError(f"did not find tokenizer at {tokenizer_path}") | 
|  | 280 | +            raise RuntimeError( | 
|  | 281 | +                f"did not find tokenizer at {os.path.abspath(tokenizer_path)}" | 
|  | 282 | +            ) | 
| 270 | 283 | 
 | 
| 271 | 284 |         return cls( | 
| 272 | 285 |             tokenizer_path=tokenizer_path, | 
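Printing the absolute path makes the error actionable when a relative path was passed in. Since `tokenizer_path` exposes `.is_file()`, it is presumably a `pathlib.Path`, so `Path.resolve()` would be an equivalent spelling that avoids the `os` dependency (a sketch, assuming `pathlib` usage):

```python
# Sketch: same message via pathlib alone; resolve() works even when the
# file does not exist (strict=False is the default).
if not tokenizer_path.is_file():
    raise RuntimeError(f"did not find tokenizer at {tokenizer_path.resolve()}")
```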
| @@ -335,7 +348,9 @@ def _load_model_default(builder_args, only_config=False): | 
| 335 | 348 | 
 | 
| 336 | 349 |     if builder_args.params_table and builder_args.params_table.endswith("Tune"): | 
| 337 | 350 |         print("Loading Tune checkpoint") | 
| 338 |  | -        meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True) | 
|  | 351 | +        meta_checkpoint = torch.load( | 
|  | 352 | +            str(builder_args.checkpoint_path), mmap=True, weights_only=True | 
|  | 353 | +        ) | 
| 339 | 354 |         checkpoint = meta_to_tune(meta_checkpoint) | 
| 340 | 355 |     elif builder_args.checkpoint_dir is not None: | 
| 341 | 356 |         # Load multiple checkpoints; ignore the single path. | 
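The reflowed `torch.load` call is worth a note on its flags: `mmap=True` memory-maps the checkpoint file so tensor data is paged in on demand rather than read eagerly, and `weights_only=True` restricts unpickling to tensors and plain containers, which is safer for untrusted files. A standalone sketch with a placeholder path:

```python
import torch

# "model.pt" is a placeholder; the real code loads builder_args.checkpoint_path.
state = torch.load("model.pt", mmap=True, weights_only=True)
```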
| @@ -372,8 +387,17 @@ def _load_model_default(builder_args, only_config=False): | 
| 372 | 387 |     if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): | 
| 373 | 388 |         checkpoint = checkpoint["model"] | 
| 374 | 389 | 
 | 
| 375 |  | -    checkpoint = {"model." + k: v for k, v in checkpoint.items()} | 
| 376 |  | -    model.load_state_dict(checkpoint, assign=True, strict=True) | 
|  | 390 | +    if model.config.model_type == ModelType.Flamingo: | 
|  | 391 | +        # TODO: Refactor this. For now, overwrite the model with model loaded from params_path | 
|  | 392 | +        with set_default_dtype(builder_args.precision), torch.device( | 
|  | 393 | +            builder_args.device | 
|  | 394 | +        ): | 
|  | 395 | +            model = Model.from_params(builder_args.params_path) | 
|  | 396 | +        state_dict = flamingo_meta_to_tune(checkpoint) | 
|  | 397 | +        model.model.load_state_dict(state_dict) | 
|  | 398 | +    else: | 
|  | 399 | +        checkpoint = {"model." + k: v for k, v in checkpoint.items()} | 
|  | 400 | +        model.load_state_dict(checkpoint, assign=True, strict=True) | 
| 377 | 401 | 
 | 
| 378 | 402 |     return model | 
| 379 | 403 | 
 | 
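The Flamingo path rebuilds the model from `params_path` instead of reusing the one constructed earlier, depends on `flamingo_meta_to_tune` from the guarded import at the top of the file, and, unlike the generic branch, loads into `model.model` without `assign=True`. A minimal sketch of the construction pattern, with placeholder dtype/device/path values standing in for the `builder_args` fields:

```python
import torch
from torchchat.model import Model  # as imported in the diff
from torchtune.training import set_default_dtype

params_path = "model_params.json"  # placeholder; real code uses builder_args.params_path

# Modules instantiated inside these context managers are created with the
# given default dtype and on the given device, avoiding a CPU round-trip.
with set_default_dtype(torch.bfloat16), torch.device("cuda"):
    model = Model.from_params(params_path)
```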
|  | 