This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 971ed93
Removed unused only_config arg; Added typehints to builder (#1175)

* Removed unused only_config arg; Added typehints to builder
* Remove missed arg

1 parent 8d01d9b · commit 971ed93

File tree

1 file changed: +32 −32 lines

torchchat/cli/builder.py

Lines changed: 32 additions & 32 deletions
```diff
@@ -4,9 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
 import os
 import sys
-import time
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple, Union
@@ -21,12 +21,7 @@
 except ImportError:
     pass
 
-from distributed import (
-    init_distributed,
-    launch_distributed,
-    ParallelDims,
-    parallelize_llama,
-)
+from distributed import launch_distributed, ParallelDims, parallelize_llama
 
 from torch.distributed.device_mesh import DeviceMesh
 
@@ -101,7 +96,7 @@ def __post_init__(self):
         self.prefill_possible = True
 
     @classmethod
-    def from_args(cls, args): # -> BuilderArgs:
+    def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
        # Handle disabled checkpoint_dir option
        checkpoint_dir = None
        if hasattr(args, "checkpoint_dir"):
```
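
The return annotation is quoted because `BuilderArgs` does not yet exist as a name while its own class body is being evaluated; the string makes it a forward reference (PEP 484). A minimal, self-contained sketch of the same pattern (toy field, not the real class):

```python
import argparse
from dataclasses import dataclass


@dataclass
class BuilderArgs:
    device: str = "cpu"

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
        # Quoted annotation: name resolution is deferred until the
        # class object is fully defined.
        return cls(device=getattr(args, "device", "cpu"))


args = argparse.Namespace(device="cuda")
print(BuilderArgs.from_args(args))  # BuilderArgs(device='cuda')
```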
```diff
@@ -183,7 +178,7 @@ def from_args(cls, args): # -> BuilderArgs:
         )
 
     @classmethod
-    def from_speculative_args(cls, args): # -> BuilderArgs:
+    def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args = BuilderArgs.from_args(args)
         # let's limit multi-checkpoint to checker
         speculative_builder_args.checkpoint_dir = None
@@ -229,7 +224,7 @@ def __post_init__(self):
 
     def validate_model(
         self,
-        model: Model,
+        model: Optional[Model],
         model_description: str = "model",
     ) -> None:
         if model is None:
```
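
Widening the parameter to `Optional[Model]` matches the `if model is None` guard already in the body, and lets static checkers flag any attribute access that happens before the guard. A hedged sketch of the narrowing pattern (stand-in `Model`, simplified body):

```python
from typing import Optional


class Model:  # stand-in for torchchat's Model
    pass


def validate_model(model: Optional[Model], model_description: str = "model") -> None:
    if model is None:
        raise RuntimeError(f"no {model_description} provided")
    # Past the guard, type checkers narrow Optional[Model] to Model.


validate_model(Model())   # fine
# validate_model(None)    # would raise RuntimeError
```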
```diff
@@ -250,10 +245,21 @@ def validate_model(
         return
 
     @classmethod
-    def from_args(cls, args): # -> TokenizerArgs:
-        is_sentencepiece = False
-        is_tiktoken = False
-
+    def from_args(cls, args: argparse.Namespace) -> "TokenizerArgs":
+        """
+        Create a TokenizerArgs object from command line arguments.
+        Specifically, `tokenizer_path` is resolved with precedence:
+          * From Explicitly provided tokenizer_path
+          * Resolve via model_config identified by args.model
+          * Look in the directory of args.checkpoint_path for tokenizer.model
+          * Look in the directory of args.checkpoint_dir for tokenizer.model
+
+        Args:
+            args (argparse.Namespace): The command line arguments.
+
+        Returns:
+            TokenizerArgs: A TokenizerArgs object.
+        """
         if args.tokenizer_path:
             tokenizer_path = args.tokenizer_path
         elif args.model: # Using a named, well-known model
@@ -263,7 +269,6 @@ def from_args(cls, args): # -> TokenizerArgs:
                 / model_config.name
                 / model_config.tokenizer_file
             )
-
         elif args.checkpoint_path:
             tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
         elif hasattr(args, "checkpoint_dir") and args.checkpoint_dir:
@@ -276,12 +281,7 @@ def from_args(cls, args): # -> TokenizerArgs:
                 f"did not find tokenizer at {os.path.abspath(tokenizer_path)}"
             )
 
-        return cls(
-            tokenizer_path=tokenizer_path,
-            is_sentencepiece=is_sentencepiece,
-            is_tiktoken=is_tiktoken,
-            t=None,
-        )
+        return cls(tokenizer_path=tokenizer_path)
 
 
 def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
```
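
The new docstring's precedence list corresponds to a plain if/elif chain. The sketch below restates it with hypothetical standalone names (the real method goes through a model_config lookup for named models and validates that the resolved file exists):

```python
from pathlib import Path
from typing import Optional


def resolve_tokenizer_path(
    tokenizer_path: Optional[Path],
    model_dir: Optional[Path],       # stands in for the model_config lookup
    checkpoint_path: Optional[Path],
    checkpoint_dir: Optional[Path],
) -> Path:
    # Highest precedence first, mirroring the docstring above.
    if tokenizer_path:
        return tokenizer_path
    if model_dir:
        return model_dir / "tokenizer.model"
    if checkpoint_path:
        return checkpoint_path.parent / "tokenizer.model"
    if checkpoint_dir:
        return checkpoint_dir / "tokenizer.model"
    raise RuntimeError("unable to determine a tokenizer path")
```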
```diff
@@ -299,7 +299,7 @@ def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
 
 
 # TODO: remove these once ET supports _weight_int4pack_mm
-def _set_gguf_kwargs(builder_args, is_et, context: str):
+def _set_gguf_kwargs(builder_args: BuilderArgs, is_et: bool, context: str) -> None:
     assert context in ["export", "generate"]
     assert builder_args.gguf_kwargs is None
 
@@ -312,11 +312,11 @@ def _set_gguf_kwargs(builder_args, is_et, context: str):
         builder_args.gguf_kwargs["load_as_quantized"] = False
 
 
-def _unset_gguf_kwargs(builder_args):
+def _unset_gguf_kwargs(builder_args: BuilderArgs) -> None:
     builder_args.gguf_kwargs = None
 
 
-def _init_model_on_meta_device(builder_args):
+def _init_model_on_meta_device(builder_args: BuilderArgs) -> Model:
     with torch.device("meta"):
         if builder_args.params_path:
             return Model.from_params(builder_args.params_path)
```
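
`_set_gguf_kwargs` asserts that `gguf_kwargs` is still `None` before mutating it, so every set is expected to be paired with an unset; the new `-> None` annotations make explicit that both helpers work purely by mutation. A hypothetical call-site sketch (not from this diff) showing the pairing:

```python
# Hypothetical usage; torchchat's actual call sites may differ.
def load_gguf_for_export(builder_args: "BuilderArgs") -> "Model":
    _set_gguf_kwargs(builder_args, is_et=True, context="export")
    try:
        return _load_model(builder_args)
    finally:
        # Restore gguf_kwargs to None so the next set's assertion holds.
        _unset_gguf_kwargs(builder_args)
```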
```diff
@@ -326,7 +326,7 @@ def _init_model_on_meta_device(builder_args):
             return Model.from_name(builder_args.checkpoint_path.parent.name)
 
 
-def _load_model_gguf(builder_args, only_config=False):
+def _load_model_gguf(builder_args: BuilderArgs) -> Model:
     assert builder_args.gguf_path
     if builder_args.gguf_kwargs is None:
         kwargs = {}
@@ -336,10 +336,10 @@ def _load_model_gguf(builder_args, only_config=False):
     return model
 
 
-def _load_model_default(builder_args, only_config=False):
+def _load_model_default(builder_args: BuilderArgs) -> Model:
     assert not builder_args.gguf_path
 
-    model = _init_model_on_meta_device(builder_args)
+    model: Model = _init_model_on_meta_device(builder_args)
 
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
         print("Loading Tune checkpoint")
```
```diff
@@ -459,7 +459,7 @@ def _maybe_parellelize_model(
     return load_checkpoints_to_model(model, builder_args, world_mesh)
 
 
-def _load_model(builder_args, only_config=False):
+def _load_model(builder_args: BuilderArgs) -> Model:
     world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
@@ -474,12 +474,12 @@ def _load_model(builder_args, only_config=False):
 
 
 def _initialize_model(
-    builder_args,
+    builder_args: BuilderArgs,
     quantize,
     tokenizer=None,
     max_seq_length=None,
     support_tensor_subclass: bool = True,
-):
+) -> Model:
     print("Loading model...")
 
     if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
@@ -505,7 +505,7 @@ def _initialize_model(
     # ), "quantize not valid for exported DSO model. Specify quantization during export."
 
     with measure_time("Time to load model: {time:.02f} seconds"):
-        model = _load_model(builder_args, only_config=True)
+        model = _load_model(builder_args)
     device_sync(device=builder_args.device)
 
     try:
@@ -532,7 +532,7 @@ def _initialize_model(
     # ), "quantize not valid for exported PTE model. Specify quantization during export."
 
     with measure_time("Time to load model: {time:.02f} seconds"):
-        model = _load_model(builder_args, only_config=True)
+        model = _load_model(builder_args)
     device_sync(device=builder_args.device)
 
     try:
```
