@@ -56,6 +56,7 @@ class BuilderArgs:
5656    gguf_kwargs : Optional [Dict [str , Any ]] =  None 
5757    dso_path : Optional [Union [Path , str ]] =  None 
5858    aoti_package_path : Optional [Union [Path , str ]] =  None 
59+     snapshot_path : Optional [Union [Path , str ]] =  None 
5960    pte_path : Optional [Union [Path , str ]] =  None 
6061    device : Optional [str ] =  None 
6162    precision : torch .dtype  =  torch .float32 
@@ -87,6 +88,7 @@ def __post_init__(self):
8788            or  (self .dso_path  and  Path (self .dso_path ).is_file ())
8889            or  (self .aoti_package_path  and  Path (self .aoti_package_path ).is_file ())
8990            or  (self .pte_path  and  Path (self .pte_path ).is_file ())
91+             or  (self .snapshot_path  and  Path (self .snapshot_path ).is_file ())
9092        ):
9193            raise  RuntimeError (
9294                "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path" 
@@ -142,6 +144,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
142144        dso_path  =  getattr (args , "dso_path" , None )
143145        pte_path  =  getattr (args , "pte_path" , None )
144146        aoti_package_path  =  getattr (args , "aoti_package_path" , None )
147+         snapshot_path  =  getattr (args , "snapshot_path" , None )
145148
146149        is_chat_model  =  False 
147150        if  args .is_chat_model :
@@ -169,6 +172,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
169172        output_pte_path  =  getattr (args , "output_pte_path" , None )
170173        output_aoti_package_path  =  getattr (args , "output_aoti_package_path" , None )
171174        output_dso_path  =  getattr (args , "output_dso_path" , None )
175+         output_snapshot_path  =  getattr (args , "output_snapshot_path" , None )
172176        if  output_pte_path  and  args .dtype .startswith ("fast" ):
173177            if  args .dtype  ==  "fast" :
174178                # As per Kimish, float32 should be faster on ET XNNPACK 
@@ -206,6 +210,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
206210            dso_path = dso_path ,
207211            aoti_package_path = aoti_package_path ,
208212            pte_path = pte_path ,
213+             snapshot_path = snapshot_path ,
209214            device = args .device ,
210215            precision = dtype ,
211216            setup_caches = (
@@ -631,6 +636,34 @@ def do_nothing(max_batch_size, max_seq_length):
631636            model  =  PTEModel (config , builder_args .pte_path )
632637        except  Exception :
633638            raise  RuntimeError (f"Failed to load ET compiled { builder_args .pte_path }  )
639+     elif  builder_args .snapshot_path :
640+         # Resolve ModelArgs for constructing the PTEModel 
641+         # If a manual params_path is provided, use that 
642+         if  builder_args .params_path :
643+             config : ModelArgs  =  ModelArgs .from_params (builder_args .params_path )
644+         else :
645+             # TODO: Instead of loading the whole model, refactor to call a 
646+             # helper that generate just model.config 
647+             with  measure_time ("Time to load model: {time:.02f} seconds" ):
648+                 model  =  _load_model (builder_args )
649+                 device_sync (device = builder_args .device )
650+                 config  =  model .config 
651+                 model  =  None 
652+         try :
653+             model  =  torch .load (builder_args .snapshot_path , weights_only = False )
654+         except  Exception :
655+             raise  RuntimeError (f"Failed to load torchchat snapshot { builder_args .snapshot_path }  )
656+         # _active_backend() does not allow DSO & AOTI to be true.  
657+         # Choose either. 
658+         from  torchchat .utils .build_utils  import  set_backend 
659+         set_backend  (dso = True , pte = False , aoti_package = False )
660+         if  (model .config  !=  config ):
661+             raise  RuntimeError ("loaded model architecture mismatch" )
662+         ##         
663+         ## import all libraries with custom kernels ans custom operators 
664+         ## that quantize may be pulling in 
665+         ## 
666+ 
634667    elif  builder_args .distributed :
635668        pp_degree  =  builder_args .pp 
636669        tp_degree  =  builder_args .tp 
0 commit comments