@@ -16,20 +16,12 @@
 import torch._inductor.config
 import torch.nn as nn
 
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torch.distributed.device_mesh import DeviceMesh
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.training import set_default_dtype
+from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
 
 from torchchat.model import Model, ModelArgs, ModelType
 
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -40,6 +32,14 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model
 
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
 
 @dataclass
 class BuilderArgs:
@@ -61,7 +61,7 @@ class BuilderArgs:
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
 
-    quantized_state_path: Optional[Union[Path, str]] = None
+    state_dict_path: Optional[Union[Path, str]] = None
 
     def __post_init__(self):
         if self.device is None:
@@ -89,7 +89,9 @@ def __post_init__(self):
             ]
             for param, param_msg in ignored_params:
                 if param:
-                    print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified")
+                    print(
+                        f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified"
+                    )
         else:
             self.prefill_possible = True
 
@@ -173,7 +175,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
-            quantized_state_path=args.quantized_state_path,
+            state_dict_path=args.state_dict_path,
         )
 
     @classmethod
@@ -400,10 +402,10 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
             # does not host any actual values, need to reinitialize them in the actual
             # device. Only do those buffer initialization, without initializing the entire
             # model.
-            decoder_config = model.config.transformer_args['decoder']
-            head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
-            max_seq_len = decoder_config['max_seq_len']
-            rope_base = decoder_config['rope_base']
+            decoder_config = model.config.transformer_args["decoder"]
+            head_dim = decoder_config["embed_dim"] // decoder_config["num_heads"]
+            max_seq_len = decoder_config["max_seq_len"]
+            rope_base = decoder_config["rope_base"]
             for submodule in model.modules():
                 if isinstance(submodule, Llama3ScaledRoPE):
                     submodule.__init__(head_dim, max_seq_len, rope_base)
@@ -491,6 +493,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
+
 def _initialize_model(
     builder_args: BuilderArgs,
     quantize,
@@ -568,17 +571,19 @@ def _initialize_model(
             model = _load_model(builder_args)
             device_sync(device=builder_args.device)
 
-        cache_path = builder_args.quantized_state_path
-        quant_checkpoint_exists: bool = cache_path and os.path.isfile(cache_path)
-        if quantize or quant_checkpoint_exists:
+        state_dict_path = builder_args.state_dict_path
+        state_dict_exists: bool = state_dict_path and os.path.isfile(state_dict_path)
+        if quantize or state_dict_exists:
 
-            if quantize and quant_checkpoint_exists:
-                print("WARNING: Both a quantized checkpoint and quantize arg were provided; Ignoring quantize arg")
+            if quantize and state_dict_exists:
+                print(
+                    "WARNING: Both a state_dict and quantize arg were provided; Ignoring quantize arg"
+                )
 
-            if quant_checkpoint_exists:
+            if state_dict_exists:
                 with measure_time("Time to load quantized state: {time:.02f} seconds"):
-                    print(f"Loading the model_state in: {cache_path}")
-                    model.load_state_dict(cache_path)
+                    print(f"Loading the model_state in: {state_dict_path}")
+                    model.load_state_dict(state_dict_path)
                 device_sync(device=builder_args.device)
             else:
                 with measure_time("Time to quantize model: {time:.02f} seconds"):
@@ -592,11 +597,12 @@ def _initialize_model(
                     )
                 device_sync(device=builder_args.device)
 
-                if cache_path:
-                    with measure_time("Time to save quantized state: {time:.02f} seconds"):
+                if state_dict_path:
+                    with measure_time(
+                        "Time to save quantized state: {time:.02f} seconds"
+                    ):
                         print(f"Saving the quantized state dict")
-                        torch.save(model.state_dict(), cache_path)
-
+                        torch.save(model.state_dict(), state_dict_path)
 
     if builder_args.setup_caches:
         with torch.device(builder_args.device):
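
Taken together, the renamed state_dict_path works as a save/load cache for the (possibly quantized) model state: if the file already exists it is loaded and the quantize argument is ignored; otherwise the model is quantized and, when a path was supplied, the resulting state dict is written back to it. Below is a minimal standalone sketch of that flow, with simplified names, the standard torch.load call in place of torchchat's path-accepting Model.load_state_dict, and a hypothetical quantize_in_place standing in for quantize_model.

import os
from typing import Optional

import torch
import torch.nn as nn


def quantize_in_place(model: nn.Module, spec: dict) -> None:
    # Hypothetical placeholder for the real quantization step (quantize_model in torchchat).
    pass


def load_or_quantize(
    model: nn.Module,
    quantize: Optional[dict] = None,
    state_dict_path: Optional[str] = None,
) -> nn.Module:
    # A previously saved state dict takes precedence over re-quantizing.
    if state_dict_path and os.path.isfile(state_dict_path):
        model.load_state_dict(torch.load(state_dict_path))
    elif quantize:
        quantize_in_place(model, quantize)
        if state_dict_path:
            # Cache the quantized weights so later runs can skip quantization.
            torch.save(model.state_dict(), state_dict_path)
    return model.eval()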