@@ -56,6 +56,7 @@ class BuilderArgs:
5656 gguf_kwargs : Optional [Dict [str , Any ]] = None
5757 dso_path : Optional [Union [Path , str ]] = None
5858 aoti_package_path : Optional [Union [Path , str ]] = None
59+ snapshot_path : Optional [Union [Path , str ]] = None
5960 pte_path : Optional [Union [Path , str ]] = None
6061 device : Optional [str ] = None
6162 precision : torch .dtype = torch .float32
@@ -81,6 +82,7 @@ def __post_init__(self):
8182 or (self .dso_path and Path (self .dso_path ).is_file ())
8283 or (self .aoti_package_path and Path (self .aoti_package_path ).is_file ())
8384 or (self .pte_path and Path (self .pte_path ).is_file ())
85+ or (self .snapshot_path and Path (self .snapshot_path ).is_file ())
8486 ):
8587 raise RuntimeError (
8688 "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path"
@@ -136,6 +138,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
136138 dso_path = getattr (args , "dso_path" , None )
137139 pte_path = getattr (args , "pte_path" , None )
138140 aoti_package_path = getattr (args , "aoti_package_path" , None )
141+ snapshot_path = getattr (args , "snapshot_path" , None )
139142
140143 is_chat_model = False
141144 if args .is_chat_model :
@@ -163,6 +166,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
163166 output_pte_path = getattr (args , "output_pte_path" , None )
164167 output_aoti_package_path = getattr (args , "output_aoti_package_path" , None )
165168 output_dso_path = getattr (args , "output_dso_path" , None )
169+ output_snapshot_path = getattr (args , "output_snapshot_path" , None )
166170 if output_pte_path and args .dtype .startswith ("fast" ):
167171 if args .dtype == "fast" :
168172 # As per Kimish, float32 should be faster on ET XNNPACK
@@ -189,6 +193,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
189193 dso_path = dso_path ,
190194 aoti_package_path = aoti_package_path ,
191195 pte_path = pte_path ,
196+ snapshot_path = snapshot_path ,
192197 device = args .device ,
193198 precision = dtype ,
194199 setup_caches = (
@@ -614,6 +619,33 @@ def do_nothing(max_batch_size, max_seq_length):
614619 model = PTEModel (config , builder_args .pte_path )
615620 except Exception :
616621 raise RuntimeError (f"Failed to load ET compiled { builder_args .pte_path } " )
622+ elif builder_args .snapshot_path :
623+ # Resolve ModelArgs for constructing the PTEModel
624+ # If a manual params_path is provided, use that
625+ if builder_args .params_path :
626+ config : ModelArgs = ModelArgs .from_params (builder_args .params_path )
627+ else :
628+ # TODO: Instead of loading the whole model, refactor to call a
629+ # helper that generate just model.config
630+ with measure_time ("Time to load model: {time:.02f} seconds" ):
631+ model = _load_model (builder_args )
632+ device_sync (device = builder_args .device )
633+ config = model .config
634+ model = None
635+ try :
636+ model = torch .load (builder_args .snapshot_path , weights_only = False )
637+ except Exception :
638+ raise RuntimeError (f"Failed to load torchchat snapshot { builder_args .snapshot_path } " )
639+ # _active_backend() does not allow DSO & AOTI to be true.
640+ # Choose either.
641+ set_backend (dso = True , pte = False , aoti_package = False )
642+ if (model .config != config ):
643+ raise RuntimeError ("loaded model architecture mismatch" )
644+ ##
645+ ## import all libraries with custom kernels ans custom operators
646+ ## that quantize may be pulling in
647+ ##
648+
617649 elif builder_args .distributed :
618650 pp_degree = builder_args .pp
619651 tp_degree = builder_args .tp
0 commit comments