
Commit 616104d

Merge branch 'main' into patch-10
2 parents: 656fd94 + 6eae887

File tree: 11 files changed, +157 -176 lines

README.md (1 addition, 1 deletion)

@@ -575,7 +575,7 @@ We really value our community and the contributions made by our wonderful users.
 
 To connect with us and other community members, we invite you to join our Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can:
 * Head to the `#torchchat-general` channel for general questions, discussion, and community support.
-* Join the `#torchchat-contribution` channel if you're interested in contributing directly to project development.
+* Join the `#torchchat-contributors` channel if you're interested in contributing directly to project development.
 
 Looking forward to discussing with you about torchchat future!
 
install/install_requirements.sh (19 additions, 6 deletions)

@@ -9,26 +9,39 @@ set -eou pipefail
 
 # Install required python dependencies for developing
 # Dependencies are defined in .pyproject.toml
-PYTHON_EXECUTABLE=${PYTHON_EXECUTABLE:-python}
-if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]];
+if [ -z "${PYTHON_EXECUTABLE:-}" ];
 then
-  PYTHON_EXECUTABLE=python3
+  if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]];
+  then
+    PYTHON_EXECUTABLE=python3
+  fi
 fi
+echo "Using python executable: $PYTHON_EXECUTABLE"
 
+PYTHON_SYS_VERSION="$($PYTHON_EXECUTABLE -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")"
 # Check python version. Expect 3.10.x or 3.11.x
-printf "import sys\nif sys.version_info.major != 3 or sys.version_info.minor < 10 :\n\tprint('Please use Python >=3.10');sys.exit(1)\n" | $PYTHON_EXECUTABLE
-if [[ $? -ne 0 ]]
+if ! $PYTHON_EXECUTABLE -c "
+import sys
+if sys.version_info < (3, 10) or sys.version_info >= (3, 12):
+    sys.exit(1)
+";
 then
+  echo "Python version must be 3.10.x or 3.11.x. Detected version: $PYTHON_SYS_VERSION"
   exit 1
 fi
 
 if [[ "$PYTHON_EXECUTABLE" == "python" ]];
 then
   PIP_EXECUTABLE=pip
-else
+elif [[ "$PYTHON_EXECUTABLE" == "python3" ]];
+then
   PIP_EXECUTABLE=pip3
+else
+  PIP_EXECUTABLE=pip${PYTHON_SYS_VERSION}
 fi
 
+echo "Using pip executable: $PIP_EXECUTABLE"
+
 #
 # First install requirements in install/requirements.txt. Older torch may be
 # installed from the dependency of other models. It will be overridden by
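For context on the new version gate above: the script now feeds a small Python snippet to `$PYTHON_EXECUTABLE` and relies on tuple comparison against `sys.version_info`. A minimal standalone sketch of the same check (the `python_version_ok` helper is illustrative, not part of the commit):

```python
import sys


def python_version_ok(version_info=sys.version_info) -> bool:
    # Accept 3.10.x and 3.11.x only, mirroring the shell gate above:
    # tuples compare element-wise, so (3, 11, x) passes and (3, 12, x) is rejected.
    return (3, 10) <= tuple(version_info[:2]) < (3, 12)


if __name__ == "__main__":
    if not python_version_ok():
        print("Python version must be 3.10.x or 3.11.x. "
              f"Detected: {sys.version_info.major}.{sys.version_info.minor}")
        sys.exit(1)
```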

tokenizer/base64.h (1 addition, 0 deletions)

@@ -25,6 +25,7 @@
 #pragma once
 
 #include <cassert>
+#include <cstdint>
 #include <string>
 #include <string_view>
 
torchchat.py (9 additions, 1 deletion)

@@ -6,7 +6,7 @@
 
 import argparse
 import logging
-import subprocess
+import signal
 import sys
 
 # MPS ops missing with Multimodal torchtune
@@ -25,7 +25,15 @@
 default_device = "cpu"
 
 
+def signal_handler(sig, frame):
+    print("\nInterrupted by user. Bye!\n")
+    sys.exit(0)
+
+
 if __name__ == "__main__":
+    # Set the signal handler for SIGINT
+    signal.signal(signal.SIGINT, signal_handler)
+
     # Initialize the top-level parser
     parser = argparse.ArgumentParser(
         prog="torchchat",

torchchat/cli/builder.py (1 addition, 73 deletions)

@@ -16,12 +16,6 @@
 import torch._inductor.config
 import torch.nn as nn
 
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.elastic.utils.distributed import get_free_port
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torchchat.model import Model, ModelArgs, ModelType
 
 from torchchat.model_config.model_config import resolve_model_config
@@ -464,77 +458,11 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
     return model
 
 
-def _maybe_init_distributed(
-    builder_args: BuilderArgs,
-) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-    """
-    Initialize distributed related setups if the user specified
-    using distributed inference. If not, this is a no-op.
-
-    Args:
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-    Returns:
-        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-            - The first element is an optional DeviceMesh object,
-              which which describes the mesh topology of devices for the DTensor.
-            - The second element is an optional ParallelDims object,
-              which represents the parallel dimensions configuration.
-    """
-    if not builder_args.use_distributed:
-        return None, None
-    dist_config = "llama3_8B.toml" # TODO - integrate with chat cmd line
-
-    world_mesh, parallel_dims = launch_distributed(dist_config)
-
-    assert (
-        world_mesh is not None and parallel_dims is not None
-    ), f"failed to launch distributed using {dist_config}"
-
-    return world_mesh, parallel_dims
-
-
-def _maybe_parallelize_model(
-    model: nn.Module,
-    builder_args: BuilderArgs,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-) -> nn.Module:
-    """
-    We parallelize the module and load the distributed checkpoint to the model
-    if the user specifies using distributed inference. If not, this is a no-op.
-
-    Args:
-        model (:class:`nn.Module`):
-            Module to be parallelized.
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-        world_mesh (:class:`DeviceMesh`):
-            Object which describes the mesh topology
-            of devices for the DTensor.
-        parallel_dims (:class:`ParallelDims`):
-            Object which represents the parallel dimensions configuration.
-    Returns:
-        A :class:`nn.Module` object which is parallelized and checkpoint loaded
-        if the user specifies using distributed inference.
-    """
-    if world_mesh is None:
-        return model
-    assert parallel_dims is not None
-    print("Applying model parallel to model ...")
-    parallelize_llama(model, world_mesh, parallel_dims)
-    return load_checkpoints_to_model(model, builder_args, world_mesh)
-
-
 def _load_model(builder_args: BuilderArgs) -> Model:
-    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    # elif builder_args.use_distributed:
-    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-        # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
     if builder_args.dso_path or builder_args.aoti_package_path:
         # AOTI-compoiled model will load its own weights.
@@ -706,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
+    return "SentencePiece"

torchchat/cli/cli.py (39 additions, 17 deletions)

@@ -5,26 +5,24 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import importlib.metadata
 import json
 import logging
 import os
 import sys
 from pathlib import Path
 
-import torch
-
-from torchchat.cli.download import download_and_convert, is_model_downloaded
-
 from torchchat.utils.build_utils import (
     allowable_dtype_names,
     allowable_params_table,
-    get_device_str,
 )
 
 logging.basicConfig(level=logging.INFO, format="%(message)s")
 logger = logging.getLogger(__name__)
 
 default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
+default_dtype = os.getenv("TORCHCHAT_PRECISION", "fast")
+
 default_model_dir = Path(
     os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
 ).expanduser()
@@ -42,6 +40,9 @@
 
 # Handle CLI arguments that are common to a majority of subcommands.
 def check_args(args, verb: str) -> None:
+    # Local import to avoid unnecessary expensive imports
+    from torchchat.cli.download import download_and_convert, is_model_downloaded
+
     # Handle model download. Skip this for download, since it has slightly
     # different semantics.
     if (
@@ -150,9 +151,9 @@ def _add_model_config_args(parser, verb: str) -> None:
 
     model_config_parser.add_argument(
         "--dtype",
-        default="fast",
+        default=None,
         choices=allowable_dtype_names(),
-        help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32, fast16, fast",
+        help="Override the dtype of the model. Options: bf16, fp16, fp32, fast16, fast",
     )
     model_config_parser.add_argument(
         "--quantize",
@@ -166,9 +167,9 @@ def _add_model_config_args(parser, verb: str) -> None:
     model_config_parser.add_argument(
         "--device",
         type=str,
-        default=default_device,
+        default=None,
         choices=["fast", "cpu", "cuda", "mps"],
-        help="Hardware device to use. Options: cpu, cuda, mps",
+        help="Hardware device to use. Options: fast, cpu, cuda, mps",
     )
 
 
@@ -498,9 +499,10 @@ def _add_speculative_execution_args(parser) -> None:
 
 
 def arg_init(args):
-    if not (torch.__version__ > "2.3"):
+    torch_version = importlib.metadata.version("torch")
+    if not torch_version or (torch_version <= "2.3"):
         raise RuntimeError(
-            f"You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
+            f"You are using PyTorch {torch_version}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
         )
 
     if sys.version_info.major != 3 or sys.version_info.minor < 10:
@@ -513,17 +515,34 @@ def arg_init(args):
     if isinstance(args.quantize, str):
         args.quantize = json.loads(args.quantize)
 
-    # if we specify dtype in quantization recipe, replicate it as args.dtype
-    args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype)
+    # if we specify dtype in quantization recipe, allow args.dtype top override if specified
+    if args.dtype is None:
+        args.dtype = args.quantize.get("precision", {}).get("dtype", default_dtype)
+    else:
+        precision_handler = args.quantize.get("precision", None)
+        if precision_handler:
+            if precision_handler["dtype"] != args.dtype:
+                print('overriding json-specified dtype {precision_handler["dtype"]} with cli dtype {args.dtype}')
+                precision_handler["dtype"] = args.dtype
 
     if getattr(args, "output_pte_path", None):
-        if args.device not in ["cpu", "fast"]:
+        if args.device not in [None, "cpu", "fast"]:
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
     else:
-        args.device = get_device_str(
-            args.quantize.get("executor", {}).get("accelerator", args.device)
-        )
+        # Localized import to minimize expensive imports
+        from torchchat.utils.build_utils import get_device_str
+
+        if args.device is None:
+            args.device = get_device_str(
+                args.quantize.get("executor", {}).get("accelerator", default_device)
+            )
+        else:
+            executor_handler = args.quantize.get("executor", None)
+            if executor_handler:
+                if executor_handler["accelerator"] != args.device:
+                    print('overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}')
+                    executor_handler["accelerator"] = args.device
 
     if "mps" in args.device:
         if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):
@@ -534,5 +553,8 @@ def arg_init(args):
         vars(args)["compile_prefill"] = False
 
     if hasattr(args, "seed") and args.seed:
+        # Localized import to minimize expensive imports
+        import torch
+
         torch.manual_seed(args.seed)
     return args
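Two things are going on in the cli.py diff: the PyTorch version check now reads the installed version via `importlib.metadata.version("torch")` so the CLI no longer imports `torch` up front, and `--dtype`/`--device` now default to `None` so that an explicit CLI flag can override the value in a `--quantize` JSON recipe, falling back to the `TORCHCHAT_PRECISION`/`TORCHCHAT_DEVICE` environment defaults when neither is given. A hypothetical distillation of that precedence rule (the `resolve_setting` helper is illustrative, not torchchat API):

```python
from typing import Optional


def resolve_setting(cli_value: Optional[str], recipe: dict, key: str, env_default: str) -> str:
    """CLI flag > --quantize recipe value > environment default."""
    if cli_value is None:
        # No CLI override: take the recipe value, else the env default.
        return recipe.get(key, env_default)
    if key in recipe and recipe[key] != cli_value:
        # CLI wins; keep the recipe consistent with what will actually run.
        print(f"overriding json-specified {key} {recipe[key]} with cli {key} {cli_value}")
        recipe[key] = cli_value
    return cli_value


# Example: --dtype bf16 on the command line beats {"precision": {"dtype": "fp16"}}.
print(resolve_setting("bf16", {"dtype": "fp16"}, "dtype", "fast"))  # -> bf16
print(resolve_setting(None, {}, "dtype", "fast"))                   # -> fast
```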

torchchat/cli/convert_hf_checkpoint.py (12 additions, 13 deletions)

@@ -11,25 +11,23 @@
 from pathlib import Path
 from typing import Optional
 
-import torch
-
-from torchchat.model import TransformerArgs
-
 # support running without installing as a package
 wd = Path(__file__).parent.parent
 sys.path.append(str(wd.resolve()))
 sys.path.append(str((wd / "build").resolve()))
 
-from torchchat.model import ModelArgs
-
 
-@torch.inference_mode()
 def convert_hf_checkpoint(
     *,
     model_dir: Optional[Path] = None,
     model_name: Optional[str] = None,
    remove_bin_files: bool = False,
 ) -> None:
+
+    # Local imports to avoid expensive imports
+    from torchchat.model import ModelArgs, TransformerArgs
+    import torch
+
     if model_dir is None:
         model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
     if model_name is None:
@@ -58,10 +56,11 @@ def convert_hf_checkpoint(
     tokenizer_pth = model_dir / "original" / "tokenizer.model"
     if consolidated_pth.is_file() and tokenizer_pth.is_file():
         # Confirm we can load it
-        loaded_result = torch.load(
-            str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
-        )
-        del loaded_result # No longer needed
+        with torch.inference_mode():
+            loaded_result = torch.load(
+                str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
+            )
+            del loaded_result # No longer needed
         print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
         os.rename(consolidated_pth, model_dir / "model.pth")
         os.rename(tokenizer_pth, model_dir / "tokenizer.model")
@@ -130,7 +129,8 @@ def load_safetensors():
     state_dict = None
     for loader in loaders:
         try:
-            state_dict = loader()
+            with torch.inference_mode():
+                state_dict = loader()
             break
         except Exception:
             continue
@@ -173,7 +173,6 @@ def load_safetensors():
         os.remove(file)
 
 
-@torch.inference_mode()
 def convert_hf_checkpoint_to_tune(
     *,
     model_dir: Optional[Path] = None,
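The convert_hf_checkpoint.py changes drop the module-level `torch` import and the `@torch.inference_mode()` decorators in favor of local imports and narrowly scoped `with torch.inference_mode():` blocks, which keeps importing the module cheap. A minimal sketch of the pattern (the `load_checkpoint_cheaply` helper and its path argument are illustrative, not part of the commit):

```python
def load_checkpoint_cheaply(path: str):
    # Local import: the cost of importing torch is only paid when a conversion runs.
    import torch

    # The context-manager form is equivalent to decorating the whole function
    # with @torch.inference_mode(), but scoped to just the tensor-producing call.
    with torch.inference_mode():
        state_dict = torch.load(
            path, map_location="cpu", mmap=True, weights_only=True
        )
    return state_dict
```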
