44# This source code is licensed under the license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import argparse
78import os
89import sys
9- import time
1010from dataclasses import dataclass
1111from pathlib import Path
1212from typing import Any , Dict , Optional , Tuple , Union
2121except ImportError :
2222 pass
2323
24- from distributed import (
25- init_distributed ,
26- launch_distributed ,
27- ParallelDims ,
28- parallelize_llama ,
29- )
24+ from distributed import launch_distributed , ParallelDims , parallelize_llama
3025
3126from torch .distributed .device_mesh import DeviceMesh
3227
@@ -101,7 +96,7 @@ def __post_init__(self):
10196 self .prefill_possible = True
10297
10398 @classmethod
104- def from_args (cls , args ): # -> BuilderArgs:
99+ def from_args (cls , args : argparse . Namespace ) -> " BuilderArgs" :
105100 # Handle disabled checkpoint_dir option
106101 checkpoint_dir = None
107102 if hasattr (args , "checkpoint_dir" ):
@@ -183,7 +178,7 @@ def from_args(cls, args): # -> BuilderArgs:
183178 )
184179
185180 @classmethod
186- def from_speculative_args (cls , args ): # -> BuilderArgs:
181+ def from_speculative_args (cls , args : argparse . Namespace ) -> " BuilderArgs" :
187182 speculative_builder_args = BuilderArgs .from_args (args )
188183 # let's limit multi-checkpoint to checker
189184 speculative_builder_args .checkpoint_dir = None
@@ -229,7 +224,7 @@ def __post_init__(self):
229224
230225 def validate_model (
231226 self ,
232- model : Model ,
227+ model : Optional [ Model ] ,
233228 model_description : str = "model" ,
234229 ) -> None :
235230 if model is None :
@@ -250,10 +245,21 @@ def validate_model(
250245 return
251246
252247 @classmethod
253- def from_args (cls , args ): # -> TokenizerArgs:
254- is_sentencepiece = False
255- is_tiktoken = False
256-
248+ def from_args (cls , args : argparse .Namespace ) -> "TokenizerArgs" :
249+ """
250+ Create a TokenizerArgs object from command line arguments.
251+ Specifically, `tokenizer_path` is resolved with precedence:
252+ * From Explicitly provided tokenizer_path
253+ * Resolve via model_config identified by args.model
254+ * Look in the directory of args.checkpoint_path for tokenizer.model
255+ * Look in the directory of args.checkpoint_dir for tokenizer.model
256+
257+ Args:
258+ args (argparse.Namespace): The command line arguments.
259+
260+ Returns:
261+ TokenizerArgs: A TokenizerArgs object.
262+ """
257263 if args .tokenizer_path :
258264 tokenizer_path = args .tokenizer_path
259265 elif args .model : # Using a named, well-known model
@@ -263,7 +269,6 @@ def from_args(cls, args): # -> TokenizerArgs:
263269 / model_config .name
264270 / model_config .tokenizer_file
265271 )
266-
267272 elif args .checkpoint_path :
268273 tokenizer_path = args .checkpoint_path .parent / "tokenizer.model"
269274 elif hasattr (args , "checkpoint_dir" ) and args .checkpoint_dir :
@@ -276,12 +281,7 @@ def from_args(cls, args): # -> TokenizerArgs:
276281 f"did not find tokenizer at { os .path .abspath (tokenizer_path )} "
277282 )
278283
279- return cls (
280- tokenizer_path = tokenizer_path ,
281- is_sentencepiece = is_sentencepiece ,
282- is_tiktoken = is_tiktoken ,
283- t = None ,
284- )
284+ return cls (tokenizer_path = tokenizer_path )
285285
286286
287287def _initialize_tokenizer (tokenizer_args : TokenizerArgs ):
@@ -299,7 +299,7 @@ def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
299299
300300
301301# TODO: remove these once ET supports _weight_int4pack_mm
302- def _set_gguf_kwargs (builder_args , is_et , context : str ):
302+ def _set_gguf_kwargs (builder_args : BuilderArgs , is_et : bool , context : str ) -> None :
303303 assert context in ["export" , "generate" ]
304304 assert builder_args .gguf_kwargs is None
305305
@@ -312,11 +312,11 @@ def _set_gguf_kwargs(builder_args, is_et, context: str):
312312 builder_args .gguf_kwargs ["load_as_quantized" ] = False
313313
314314
315- def _unset_gguf_kwargs (builder_args ) :
315+ def _unset_gguf_kwargs (builder_args : BuilderArgs ) -> None :
316316 builder_args .gguf_kwargs = None
317317
318318
319- def _init_model_on_meta_device (builder_args ) :
319+ def _init_model_on_meta_device (builder_args : BuilderArgs ) -> Model :
320320 with torch .device ("meta" ):
321321 if builder_args .params_path :
322322 return Model .from_params (builder_args .params_path )
@@ -326,7 +326,7 @@ def _init_model_on_meta_device(builder_args):
326326 return Model .from_name (builder_args .checkpoint_path .parent .name )
327327
328328
329- def _load_model_gguf (builder_args , only_config = False ) :
329+ def _load_model_gguf (builder_args : BuilderArgs ) -> Model :
330330 assert builder_args .gguf_path
331331 if builder_args .gguf_kwargs is None :
332332 kwargs = {}
@@ -336,10 +336,10 @@ def _load_model_gguf(builder_args, only_config=False):
336336 return model
337337
338338
339- def _load_model_default (builder_args , only_config = False ) :
339+ def _load_model_default (builder_args : BuilderArgs ) -> Model :
340340 assert not builder_args .gguf_path
341341
342- model = _init_model_on_meta_device (builder_args )
342+ model : Model = _init_model_on_meta_device (builder_args )
343343
344344 if builder_args .params_table and builder_args .params_table .endswith ("Tune" ):
345345 print ("Loading Tune checkpoint" )
@@ -459,7 +459,7 @@ def _maybe_parellelize_model(
459459 return load_checkpoints_to_model (model , builder_args , world_mesh )
460460
461461
462- def _load_model (builder_args , only_config = False ) :
462+ def _load_model (builder_args : BuilderArgs ) -> Model :
463463 world_mesh , parallel_dims = _maybe_init_distributed (builder_args )
464464 if builder_args .gguf_path :
465465 model = _load_model_gguf (builder_args )
@@ -474,12 +474,12 @@ def _load_model(builder_args, only_config=False):
474474
475475
476476def _initialize_model (
477- builder_args ,
477+ builder_args : BuilderArgs ,
478478 quantize ,
479479 tokenizer = None ,
480480 max_seq_length = None ,
481481 support_tensor_subclass : bool = True ,
482- ):
482+ ) -> Model :
483483 print ("Loading model..." )
484484
485485 if builder_args .gguf_path and (builder_args .dso_path or builder_args .pte_path ):
@@ -505,7 +505,7 @@ def _initialize_model(
505505 # ), "quantize not valid for exported DSO model. Specify quantization during export."
506506
507507 with measure_time ("Time to load model: {time:.02f} seconds" ):
508- model = _load_model (builder_args , only_config = True )
508+ model = _load_model (builder_args )
509509 device_sync (device = builder_args .device )
510510
511511 try :
0 commit comments