From aa3eb82ec78dcfb17f534d5857917807376e6212 Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Sat, 21 Sep 2024 05:33:16 -0700
Subject: [PATCH 1/2] Removed unused only_config arg; Added typehints to builder

---
 torchchat/cli/builder.py | 62 ++++++++++++++++++++--------------------
 1 file changed, 31 insertions(+), 31 deletions(-)

diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 2933edade..54874ba15 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -4,9 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
 import os
 import sys
-import time
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple, Union
@@ -21,12 +21,7 @@
 except ImportError:
     pass
 
-from distributed import (
-    init_distributed,
-    launch_distributed,
-    ParallelDims,
-    parallelize_llama,
-)
+from distributed import launch_distributed, ParallelDims, parallelize_llama
 
 from torch.distributed.device_mesh import DeviceMesh
 
@@ -101,7 +96,7 @@ def __post_init__(self):
         self.prefill_possible = True
 
     @classmethod
-    def from_args(cls, args):  # -> BuilderArgs:
+    def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         # Handle disabled checkpoint_dir option
         checkpoint_dir = None
         if hasattr(args, "checkpoint_dir"):
@@ -183,7 +178,7 @@ def from_args(cls, args):  # -> BuilderArgs:
         )
 
     @classmethod
-    def from_speculative_args(cls, args):  # -> BuilderArgs:
+    def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args = BuilderArgs.from_args(args)
         # let's limit multi-checkpoint to checker
         speculative_builder_args.checkpoint_dir = None
@@ -229,7 +224,7 @@ def __post_init__(self):
 
     def validate_model(
         self,
-        model: Model,
+        model: Optional[Model],
         model_description: str = "model",
     ) -> None:
         if model is None:
@@ -250,10 +245,21 @@ def validate_model(
         return
 
     @classmethod
-    def from_args(cls, args):  # -> TokenizerArgs:
-        is_sentencepiece = False
-        is_tiktoken = False
-
+    def from_args(cls, args: argparse.Namespace) -> "TokenizerArgs":
+        """
+        Create a TokenizerArgs object from command line arguments.
+        Specifically, `tokenizer_path` is resolved with precedence:
+          * Explicitly provided tokenizer_path
+          * Resolve via model_config identified by args.model
+          * Look in the directory of args.checkpoint_path for tokenizer.model
+          * Look in the directory of args.checkpoint_dir for tokenizer.model
+
+        Args:
+            args (argparse.Namespace): The command line arguments.
+
+        Returns:
+            TokenizerArgs: A TokenizerArgs object.
+ """ if args.tokenizer_path: tokenizer_path = args.tokenizer_path elif args.model: # Using a named, well-known model @@ -263,7 +269,6 @@ def from_args(cls, args): # -> TokenizerArgs: / model_config.name / model_config.tokenizer_file ) - elif args.checkpoint_path: tokenizer_path = args.checkpoint_path.parent / "tokenizer.model" elif hasattr(args, "checkpoint_dir") and args.checkpoint_dir: @@ -276,12 +281,7 @@ def from_args(cls, args): # -> TokenizerArgs: f"did not find tokenizer at {os.path.abspath(tokenizer_path)}" ) - return cls( - tokenizer_path=tokenizer_path, - is_sentencepiece=is_sentencepiece, - is_tiktoken=is_tiktoken, - t=None, - ) + return cls(tokenizer_path=tokenizer_path) def _initialize_tokenizer(tokenizer_args: TokenizerArgs): @@ -299,7 +299,7 @@ def _initialize_tokenizer(tokenizer_args: TokenizerArgs): # TODO: remove these once ET supports _weight_int4pack_mm -def _set_gguf_kwargs(builder_args, is_et, context: str): +def _set_gguf_kwargs(builder_args: BuilderArgs, is_et: bool, context: str) -> None: assert context in ["export", "generate"] assert builder_args.gguf_kwargs is None @@ -312,11 +312,11 @@ def _set_gguf_kwargs(builder_args, is_et, context: str): builder_args.gguf_kwargs["load_as_quantized"] = False -def _unset_gguf_kwargs(builder_args): +def _unset_gguf_kwargs(builder_args: BuilderArgs) -> None: builder_args.gguf_kwargs = None -def _init_model_on_meta_device(builder_args): +def _init_model_on_meta_device(builder_args: BuilderArgs) -> Model: with torch.device("meta"): if builder_args.params_path: return Model.from_params(builder_args.params_path) @@ -326,7 +326,7 @@ def _init_model_on_meta_device(builder_args): return Model.from_name(builder_args.checkpoint_path.parent.name) -def _load_model_gguf(builder_args, only_config=False): +def _load_model_gguf(builder_args: BuilderArgs) -> Model: assert builder_args.gguf_path if builder_args.gguf_kwargs is None: kwargs = {} @@ -336,10 +336,10 @@ def _load_model_gguf(builder_args, only_config=False): return model -def _load_model_default(builder_args, only_config=False): +def _load_model_default(builder_args: BuilderArgs) -> Model: assert not builder_args.gguf_path - model = _init_model_on_meta_device(builder_args) + model: Model = _init_model_on_meta_device(builder_args) if builder_args.params_table and builder_args.params_table.endswith("Tune"): print("Loading Tune checkpoint") @@ -459,7 +459,7 @@ def _maybe_parellelize_model( return load_checkpoints_to_model(model, builder_args, world_mesh) -def _load_model(builder_args, only_config=False): +def _load_model(builder_args: BuilderArgs) -> Model: world_mesh, parallel_dims = _maybe_init_distributed(builder_args) if builder_args.gguf_path: model = _load_model_gguf(builder_args) @@ -474,12 +474,12 @@ def _load_model(builder_args, only_config=False): def _initialize_model( - builder_args, + builder_args: BuilderArgs, quantize, tokenizer=None, max_seq_length=None, support_tensor_subclass: bool = True, -): +) -> Model: print("Loading model...") if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path): @@ -505,7 +505,7 @@ def _initialize_model( # ), "quantize not valid for exported DSO model. Specify quantization during export." 
From a44d22b6f025b99a95ed11d4893868b04eea5a72 Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Sun, 22 Sep 2024 01:06:45 -0700
Subject: [PATCH 2/2] Remove missed arg

---
 torchchat/cli/builder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 54874ba15..5dbf48529 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -532,7 +532,7 @@ def _initialize_model(
     #     ), "quantize not valid for exported PTE model. Specify quantization during export."
 
     with measure_time("Time to load model: {time:.02f} seconds"):
-        model = _load_model(builder_args, only_config=True)
+        model = _load_model(builder_args)
         device_sync(device=builder_args.device)
 
     try:
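
Taken together, the two commits converge both `measure_time` load sites on
the simplified signature. A minimal before/after sketch (a paraphrase of
the diff, not an excerpt):

    # Before: callers threaded only_config=True, which _load_model ignored.
    model = _load_model(builder_args, only_config=True)

    # After: _load_model(builder_args: BuilderArgs) -> Model takes only the args.
    model = _load_model(builder_args)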