12 | 12 | import json |
13 | 13 | from torch.nn import Module |
14 | 14 | from typing import Any, Dict, Tuple, Set, Optional |
| 15 | +from pathlib import Path |
15 | 16 |
16 | 17 | from torch.distributed._tensor import DTensor |
17 | 18 | from torchchat.distributed.dtensor_utils import convert_to_dtensor |
| 19 | +from torchchat.cli.builder import BuilderArgs, _load_checkpoint |
18 | 20 |
19 | 21 |
20 | 22 | _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json" |
@@ -182,10 +184,10 @@ def load_safetensor_weights( |
182 | 184 | update_state_dict( |
183 | 185 | stage_state_dict, |
184 | 186 | checkpoint, |
185 | | - new_to_old_keymap, |
186 | | - updated_states, |
187 | 187 | device, |
188 | | - model_config, |
| 188 | + model_config=model_config, |
| 189 | + new_to_old_keymap=new_to_old_keymap, |
| 190 | + updated_states=updated_states, |
189 | 191 | ) |
190 | 192 | except FileNotFoundError: |
191 | 193 | logger.error(f"File not found: {full_path}") |
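For context on the mapping being threaded through this call: in the safetensors/HF path, `new_to_old_keymap` translates torchchat-style parameter names back to the original HF names stored in the checkpoint files. A rough sketch of what such a mapping might contain; the exact FQN strings below are assumptions for illustration, not taken from this PR:

# Hypothetical keymap entries (illustrative only; actual FQNs may differ).
# Keys are torchchat-style names (after the "model." prefix is applied),
# values are the HF parameter names as they appear in the safetensors files.
new_to_old_keymap = {
    "model.tok_embeddings.weight": "model.embed_tokens.weight",
    "model.layers.0.attention.wq.weight": "model.layers.0.self_attn.q_proj.weight",
    "output.weight": "lm_head.weight",
}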
@@ -264,24 +266,36 @@ def permute_weight_to_attn_heads(w, n_heads, head_dim, model_dim): |
264 | 266 | def update_state_dict( |
265 | 267 | state_dict: Dict[str, torch.Tensor], |
266 | 268 | checkpoint: Dict[str, torch.Tensor], |
267 | | - new_to_old_keymap: Dict[str, str], |
268 | | - updated_states: Set[str], |
269 | 269 | device: torch.device, |
270 | 270 | model_config: Optional[Dict] = None, |
| 271 | + new_to_old_keymap: Optional[Dict[str, str]] = None, |
| 272 | +    updated_states: Optional[Set[str]] = None, |
271 | 273 | ): |
| 274 | + """ |
| 275 | + Update the state dict with the checkpoint tensors. |
| 276 | + Note: |
| 277 | + - For HF format, `new_to_old_keymap` is a mapping from the new key to the old |
| 278 | + key. |
| 279 | + - For torchchat format, `new_to_old_keymap` is None (because FQN conversion |
| 280 | +      has been done by the torchchat download script). |
| 281 | + """ |
272 | 282 | # for handling attn head permuting |
273 | 283 | num_heads = model_config.n_heads |
274 | 284 | dim = model_config.dim |
275 | 285 | num_local_heads = model_config.n_local_heads |
276 | 286 | head_dim = model_config.head_dim |
277 | 287 |
278 | 288 | for param in state_dict.keys(): |
279 | | - # TODO: clean this up together with `purge_fqn_prefix` when we switch |
280 | | - # from creating Transformer to creating model |
281 | | - model_param = ( |
282 | | - "output.weight" if param == "output.weight" else f"model.{param}" |
283 | | - ) |
284 | | - old_param = new_to_old_keymap.get(model_param) |
| 289 | + if new_to_old_keymap is not None: |
| 290 | + # TODO: clean the following manual prefix together with |
| 291 | + # `purge_fqn_prefix` when we switch from creating Transformer to |
| 292 | + # creating model |
| 293 | + model_param = ( |
| 294 | + "output.weight" if param == "output.weight" else f"model.{param}" |
| 295 | + ) |
| 296 | + old_param = new_to_old_keymap[model_param] |
| 297 | + else: |
| 298 | + old_param = param |
285 | 299 |
286 | 300 | if old_param not in checkpoint: |
287 | 301 | # Maybe this param is in other files |
@@ -309,7 +323,9 @@ def update_state_dict( |
309 | 323 |
310 | 324 | # Update model state dict with checkpoint tensor |
311 | 325 | state_dict[param] = checkpoint_tensor |
312 | | - updated_states.add(param) |
| 326 | + |
| 327 | + if updated_states is not None: |
| 328 | + updated_states.add(param) |
313 | 329 |
314 | 330 |
315 | 331 | def format_tensor_info(tensor: torch.Tensor) -> str: |
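To make the reordered signature concrete, a minimal sketch of the two call patterns `update_state_dict` now supports, mirroring the call sites in this diff (variable names are the ones used there):

# HF/safetensors path: pass the keymap so FQNs can be translated, plus a set
# for tracking which params were filled in.
update_state_dict(
    stage_state_dict,
    checkpoint,
    device,
    model_config=model_config,
    new_to_old_keymap=new_to_old_keymap,
    updated_states=updated_states,
)

# torchchat path: FQNs already match the model, so the keymap is omitted
# (updated_states stays optional as well).
update_state_dict(
    stage_state_dict,
    checkpoint,
    device,
    model_config=model_config,
    updated_states=updated_states,
)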
@@ -378,3 +394,59 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config |
378 | 394 | ) |
379 | 395 | if num_missing_weights > 0: |
380 | 396 | raise ValueError(f"Missing {num_missing_weights} weights") |
| 397 | + |
| 398 | + |
| 399 | +# HACK: assuming single file for torchchat's converted checkpoints. We should |
| 400 | +# remove this after converging to torchchat's model building process. |
| 401 | +# In particular, |
| 402 | +# builder_args = BuilderArgs.from_args(args) |
| 403 | +# will tell us if there is a single file or a directory. |
| 404 | +TORCHCHAT_SINGLE_FILE_CHECKPOINT = True |
| 405 | + |
| 406 | +def load_weights_from_torchchat_format(stage_module, distribution, device, model_config): |
| 407 | + """ |
| 408 | + Load the weights from torchchat format (single binary file), and fill into |
| 409 | + `stage_module`. Model config is needed b/c we permute wq and wk weights |
| 410 | + based on attn heads. |
| 411 | + """ |
| 412 | + stage_state_dict = stage_module.state_dict() |
| 413 | +    # TODO: clean this up together with `purge_fqn_prefix` when we switch from creating Transformer to creating model |
| 414 | + stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.") |
| 415 | + |
| 416 | + # Load checkpoint from torchchat cache |
| 417 | + default_cache_dir = Path( |
| 418 | + os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache") |
| 419 | + ).expanduser() |
| 420 | + # Distribution is like "meta-llama/Meta-Llama-3-8B-Instruct" |
| 421 | + # Join it with the default cache dir to get the checkpoint dir |
| 422 | + checkpoint_dir = default_cache_dir / distribution |
| 423 | + # Provide path in single-file case, provide dir in multi-file case. See |
| 424 | + # `_load_checkpoint`. |
| 425 | +    if TORCHCHAT_SINGLE_FILE_CHECKPOINT: |
| 426 | + checkpoint_path = checkpoint_dir / "model.pth" |
| 427 | + checkpoint_dir = None |
| 428 | + else: |
| 429 | + checkpoint_path = None |
| 430 | + # First, construct BuilderArgs |
| 431 | + args_dict = { |
| 432 | + "device": device, |
| 433 | + "checkpoint_dir": checkpoint_dir, |
| 434 | + "checkpoint_path": checkpoint_path, |
| 435 | + } |
| 436 | + builder_args = BuilderArgs(**args_dict) |
| 437 | + # Then, load the checkpoint using torchchat util |
| 438 | + checkpoint = _load_checkpoint(builder_args) |
| 439 | + |
| 440 | + updated_states: Set[str] = set() |
| 441 | + # This step converts full tensor into DTensor |
| 442 | + update_state_dict( |
| 443 | + stage_state_dict, |
| 444 | + checkpoint, |
| 445 | + device, |
| 446 | + model_config=model_config, |
| 447 | + updated_states=updated_states, |
| 448 | + ) |
| 449 | + |
| 450 | + # Fill state dict into stage module |
| 451 | + stage_module.load_state_dict(stage_state_dict, strict=False, assign=True) |
| 452 | + logger.info(f"Successfully loaded {len(updated_states)} weights into stage module") |
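A minimal sketch of how a caller might choose between the two loaders; the `use_hf_format` flag and the surrounding call site are assumptions for illustration, not part of this change:

# Hypothetical dispatch between the HF (safetensors) loader and the torchchat
# (single model.pth) loader added above.
if use_hf_format:
    load_weights_from_hf_format(stage_module, distribution, device, model_config)
else:
    load_weights_from_torchchat_format(stage_module, distribution, device, model_config)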