@@ -56,6 +56,7 @@ class BuilderArgs:
     gguf_kwargs: Optional[Dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
     aoti_package_path: Optional[Union[Path, str]] = None
+    snapshot_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
@@ -87,6 +88,7 @@ def __post_init__(self):
             or (self.dso_path and Path(self.dso_path).is_file())
             or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
+            or (self.snapshot_path and Path(self.snapshot_path).is_file())
         ):
             raise RuntimeError(
                 "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path"
@@ -142,6 +144,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
         aoti_package_path = getattr(args, "aoti_package_path", None)
+        snapshot_path = getattr(args, "snapshot_path", None)
 
         is_chat_model = False
         if args.is_chat_model:
@@ -169,6 +172,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         output_pte_path = getattr(args, "output_pte_path", None)
         output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
         output_dso_path = getattr(args, "output_dso_path", None)
+        output_snapshot_path = getattr(args, "output_snapshot_path", None)
         if output_pte_path and args.dtype.startswith("fast"):
             if args.dtype == "fast":
                 # As per Kimish, float32 should be faster on ET XNNPACK
@@ -206,6 +210,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             dso_path=dso_path,
             aoti_package_path=aoti_package_path,
             pte_path=pte_path,
+            snapshot_path=snapshot_path,
             device=args.device,
             precision=dtype,
             setup_caches=(
@@ -631,6 +636,34 @@ def do_nothing(max_batch_size, max_seq_length):
             model = PTEModel(config, builder_args.pte_path)
         except Exception:
             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
+    elif builder_args.snapshot_path:
+        # Resolve a ModelArgs config to validate the loaded snapshot against.
+        # If a manual params_path is provided, use that.
+        if builder_args.params_path:
+            config: ModelArgs = ModelArgs.from_params(builder_args.params_path)
+        else:
+            # TODO: Instead of loading the whole model, refactor to call a
+            # helper that generates just model.config
+            with measure_time("Time to load model: {time:.02f} seconds"):
+                model = _load_model(builder_args)
+                device_sync(device=builder_args.device)
+            config = model.config
+            model = None
+        try:
+            model = torch.load(builder_args.snapshot_path, weights_only=False)
+        except Exception:
+            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+        # _active_backend() does not allow both DSO and AOTI to be true.
+        # Choose one.
+        from torchchat.utils.build_utils import set_backend
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
+            raise RuntimeError("loaded model architecture mismatch")
+        ##
+        ## import all libraries with custom kernels and custom operators
+        ## that quantize may be pulling in
+        ##
+
     elif builder_args.distributed:
         pp_degree = builder_args.pp
         tp_degree = builder_args.tp
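
Review note: the new branch relies on `torch.load(..., weights_only=False)` returning a whole `nn.Module` whose embedded `config` can be compared against a freshly resolved one. A minimal sketch of that round trip, using placeholder classes; the writer side is not part of this diff, so the `torch.save` call here is an assumption about how the snapshot gets produced.

```python
import torch
import torch.nn as nn
from dataclasses import dataclass


@dataclass
class _Config:
    # Hypothetical stand-in for torchchat's ModelArgs.
    dim: int = 16
    n_layers: int = 2


class _TinyModel(nn.Module):
    def __init__(self, config: _Config):
        super().__init__()
        self.config = config
        self.proj = nn.Linear(config.dim, config.dim)

    def forward(self, x):
        return self.proj(x)


# Writer side (assumed): persist the full module, not just a state_dict.
model = _TinyModel(_Config())
torch.save(model, "snapshot.pt")

# Reader side, mirroring the new elif branch: load the whole object and
# reject it if its embedded config disagrees with the expected one.
expected = _Config()
try:
    loaded = torch.load("snapshot.pt", weights_only=False)
except Exception:
    raise RuntimeError("Failed to load torchchat snapshot snapshot.pt")
if loaded.config != expected:
    raise RuntimeError("loaded model architecture mismatch")
```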