64 changes: 32 additions & 32 deletions torchchat/cli/builder.py
@@ -4,9 +4,9 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
@@ -21,12 +21,7 @@
except ImportError:
pass

from distributed import (
init_distributed,
launch_distributed,
ParallelDims,
parallelize_llama,
)
from distributed import launch_distributed, ParallelDims, parallelize_llama

from torch.distributed.device_mesh import DeviceMesh

@@ -101,7 +96,7 @@ def __post_init__(self):
self.prefill_possible = True

@classmethod
def from_args(cls, args): # -> BuilderArgs:
def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
# Handle disabled checkpoint_dir option
checkpoint_dir = None
if hasattr(args, "checkpoint_dir"):
@@ -183,7 +178,7 @@ def from_args(cls, args): # -> BuilderArgs:
)

@classmethod
def from_speculative_args(cls, args): # -> BuilderArgs:
def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
speculative_builder_args = BuilderArgs.from_args(args)
# let's limit multi-checkpoint to checker
speculative_builder_args.checkpoint_dir = None
@@ -229,7 +224,7 @@ def __post_init__(self):

def validate_model(
self,
model: Model,
model: Optional[Model],
model_description: str = "model",
) -> None:
if model is None:
@@ -250,10 +245,21 @@
return

@classmethod
def from_args(cls, args): # -> TokenizerArgs:
is_sentencepiece = False
is_tiktoken = False

def from_args(cls, args: argparse.Namespace) -> "TokenizerArgs":
"""
Create a TokenizerArgs object from command line arguments.
Specifically, `tokenizer_path` is resolved with the following precedence:
* Explicitly provided via args.tokenizer_path
* Resolved via the model_config identified by args.model
* tokenizer.model found in the directory of args.checkpoint_path
* tokenizer.model found in the directory of args.checkpoint_dir

Args:
args (argparse.Namespace): The command line arguments.

Returns:
TokenizerArgs: A TokenizerArgs object.
"""
if args.tokenizer_path:
tokenizer_path = args.tokenizer_path
elif args.model: # Using a named, well-known model
@@ -263,7 +269,6 @@ def from_args(cls, args): # -> TokenizerArgs:
/ model_config.name
/ model_config.tokenizer_file
)

elif args.checkpoint_path:
tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
elif hasattr(args, "checkpoint_dir") and args.checkpoint_dir:
@@ -276,12 +281,7 @@ def from_args(cls, args): # -> TokenizerArgs:
f"did not find tokenizer at {os.path.abspath(tokenizer_path)}"
)

return cls(
tokenizer_path=tokenizer_path,
is_sentencepiece=is_sentencepiece,
is_tiktoken=is_tiktoken,
t=None,
)
return cls(tokenizer_path=tokenizer_path)


def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
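
For reference only, a minimal standalone sketch of the same resolution precedence described in the new docstring follows. The helper name resolve_tokenizer_path is hypothetical (not part of torchchat), and the model_config lookup is deliberately stubbed out because it depends on torchchat's internal model registry; this is an illustration of the precedence order, not the PR's implementation.

from pathlib import Path
from typing import Optional


def resolve_tokenizer_path(args) -> Optional[Path]:
    """Hypothetical helper mirroring the precedence in TokenizerArgs.from_args."""
    # 1. An explicitly provided tokenizer path always wins.
    if getattr(args, "tokenizer_path", None):
        return Path(args.tokenizer_path)
    # 2. A named, well-known model would be resolved through its model_config;
    #    elided here because it requires torchchat's model-config registry.
    if getattr(args, "model", None):
        raise NotImplementedError("model_config lookup not sketched here")
    # 3. Otherwise look for tokenizer.model next to the checkpoint file ...
    if getattr(args, "checkpoint_path", None):
        return Path(args.checkpoint_path).parent / "tokenizer.model"
    # 4. ... and finally inside the checkpoint directory.
    if getattr(args, "checkpoint_dir", None):
        return Path(args.checkpoint_dir) / "tokenizer.model"
    return None

As in the diff itself, a path that resolves but does not exist on disk would then surface as the RuntimeError raised in TokenizerArgs.from_args.
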
@@ -299,7 +299,7 @@ def _initialize_tokenizer(tokenizer_args: TokenizerArgs):


# TODO: remove these once ET supports _weight_int4pack_mm
def _set_gguf_kwargs(builder_args, is_et, context: str):
def _set_gguf_kwargs(builder_args: BuilderArgs, is_et: bool, context: str) -> None:
assert context in ["export", "generate"]
assert builder_args.gguf_kwargs is None

@@ -312,11 +312,11 @@ def _set_gguf_kwargs(builder_args, is_et, context: str):
builder_args.gguf_kwargs["load_as_quantized"] = False


def _unset_gguf_kwargs(builder_args):
def _unset_gguf_kwargs(builder_args: BuilderArgs) -> None:
builder_args.gguf_kwargs = None


def _init_model_on_meta_device(builder_args):
def _init_model_on_meta_device(builder_args: BuilderArgs) -> Model:
with torch.device("meta"):
if builder_args.params_path:
return Model.from_params(builder_args.params_path)
@@ -326,7 +326,7 @@ def _init_model_on_meta_device(builder_args):
return Model.from_name(builder_args.checkpoint_path.parent.name)


def _load_model_gguf(builder_args, only_config=False):
def _load_model_gguf(builder_args: BuilderArgs) -> Model:
assert builder_args.gguf_path
if builder_args.gguf_kwargs is None:
kwargs = {}
@@ -336,10 +336,10 @@ def _load_model_gguf(builder_args, only_config=False):
return model


def _load_model_default(builder_args, only_config=False):
def _load_model_default(builder_args: BuilderArgs) -> Model:
assert not builder_args.gguf_path

model = _init_model_on_meta_device(builder_args)
model: Model = _init_model_on_meta_device(builder_args)

if builder_args.params_table and builder_args.params_table.endswith("Tune"):
print("Loading Tune checkpoint")
@@ -459,7 +459,7 @@ def _maybe_parellelize_model(
return load_checkpoints_to_model(model, builder_args, world_mesh)


def _load_model(builder_args, only_config=False):
def _load_model(builder_args: BuilderArgs) -> Model:
world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
if builder_args.gguf_path:
model = _load_model_gguf(builder_args)
@@ -474,12 +474,12 @@


def _initialize_model(
builder_args,
builder_args: BuilderArgs,
quantize,
tokenizer=None,
max_seq_length=None,
support_tensor_subclass: bool = True,
):
) -> Model:
print("Loading model...")

if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
@@ -505,7 +505,7 @@
# ), "quantize not valid for exported DSO model. Specify quantization during export."

with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args, only_config=True)
model = _load_model(builder_args)
device_sync(device=builder_args.device)

try:
@@ -532,7 +532,7 @@
# ), "quantize not valid for exported PTE model. Specify quantization during export."

with measure_time("Time to load model: {time:.02f} seconds"):
model = _load_model(builder_args, only_config=True)
model = _load_model(builder_args)
device_sync(device=builder_args.device)

try: