This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 4a977a5

Merge branch 'main' into pinbump1111
2 parents: 7aa96d7 + 4697764

File tree: 7 files changed, +69 −123 lines

install/install_requirements.sh

Lines changed: 5 additions & 3 deletions
@@ -14,19 +14,21 @@ then
 if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]];
 then
   PYTHON_EXECUTABLE=python3
+else
+  PYTHON_EXECUTABLE=python
 fi
 fi
 echo "Using python executable: $PYTHON_EXECUTABLE"

 PYTHON_SYS_VERSION="$($PYTHON_EXECUTABLE -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")"
-# Check python version. Expect 3.10.x or 3.11.x
+# Check python version. Expect at least 3.10.x
 if ! $PYTHON_EXECUTABLE -c "
 import sys
-if sys.version_info < (3, 10) or sys.version_info >= (3, 12):
+if sys.version_info < (3, 10):
     sys.exit(1)
 ";
 then
-  echo "Python version must be 3.10.x or 3.11.x. Detected version: $PYTHON_SYS_VERSION"
+  echo "Python version must be at least 3.10.x. Detected version: $PYTHON_SYS_VERSION"
   exit 1
 fi
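The check now enforces only a lower bound, so interpreters newer than 3.11 are accepted. A minimal standalone sketch of the equivalent gate (the message text mirrors the script; the final status line is illustrative):

import sys

# Equivalent of the inline check invoked by install_requirements.sh:
# accept any interpreter from Python 3.10 upward.
if sys.version_info < (3, 10):
    detected = f"{sys.version_info.major}.{sys.version_info.minor}"
    print(f"Python version must be at least 3.10.x. Detected version: {detected}")
    sys.exit(1)

print(f"Using Python {sys.version_info.major}.{sys.version_info.minor}")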

install/requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@ snakeviz
 sentencepiece
 # numpy version range required by GGUF util
 numpy >= 1.17, < 2.0
-gguf
 blobfile
 tomli >= 1.1.0 ; python_version < "3.11"
 openai

tokenizer/base64.h

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 #pragma once

 #include <cassert>
+#include <cstdint>
 #include <string>
 #include <string_view>

torchchat.py

Lines changed: 9 additions & 1 deletion
@@ -6,7 +6,7 @@

 import argparse
 import logging
-import subprocess
+import signal
 import sys

 # MPS ops missing with Multimodal torchtune
@@ -25,7 +25,15 @@
 default_device = "cpu"


+def signal_handler(sig, frame):
+    print("\nInterrupted by user. Bye!\n")
+    sys.exit(0)
+
+
 if __name__ == "__main__":
+    # Set the signal handler for SIGINT
+    signal.signal(signal.SIGINT, signal_handler)
+
     # Initialize the top-level parser
     parser = argparse.ArgumentParser(
         prog="torchchat",

torchchat/cli/builder.py

Lines changed: 0 additions & 72 deletions
@@ -16,12 +16,6 @@
 import torch._inductor.config
 import torch.nn as nn

-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.elastic.utils.distributed import get_free_port
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torchchat.model import Model, ModelArgs, ModelType

 from torchchat.model_config.model_config import resolve_model_config
@@ -465,77 +459,11 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
     return model


-def _maybe_init_distributed(
-    builder_args: BuilderArgs,
-) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-    """
-    Initialize distributed related setups if the user specified
-    using distributed inference. If not, this is a no-op.
-
-    Args:
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-    Returns:
-        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-            - The first element is an optional DeviceMesh object,
-              which which describes the mesh topology of devices for the DTensor.
-            - The second element is an optional ParallelDims object,
-              which represents the parallel dimensions configuration.
-    """
-    if not builder_args.use_distributed:
-        return None, None
-    dist_config = "llama3_8B.toml"  # TODO - integrate with chat cmd line
-
-    world_mesh, parallel_dims = launch_distributed(dist_config)
-
-    assert (
-        world_mesh is not None and parallel_dims is not None
-    ), f"failed to launch distributed using {dist_config}"
-
-    return world_mesh, parallel_dims
-
-
-def _maybe_parallelize_model(
-    model: nn.Module,
-    builder_args: BuilderArgs,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-) -> nn.Module:
-    """
-    We parallelize the module and load the distributed checkpoint to the model
-    if the user specifies using distributed inference. If not, this is a no-op.
-
-    Args:
-        model (:class:`nn.Module`):
-            Module to be parallelized.
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-        world_mesh (:class:`DeviceMesh`):
-            Object which describes the mesh topology
-            of devices for the DTensor.
-        parallel_dims (:class:`ParallelDims`):
-            Object which represents the parallel dimensions configuration.
-    Returns:
-        A :class:`nn.Module` object which is parallelized and checkpoint loaded
-        if the user specifies using distributed inference.
-    """
-    if world_mesh is None:
-        return model
-    assert parallel_dims is not None
-    print("Applying model parallel to model ...")
-    parallelize_llama(model, world_mesh, parallel_dims)
-    return load_checkpoints_to_model(model, builder_args, world_mesh)
-
-
 def _load_model(builder_args: BuilderArgs) -> Model:
-    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    # elif builder_args.use_distributed:
-    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-    # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)

     if builder_args.dso_path or builder_args.aoti_package_path:
         # AOTI-compoiled model will load its own weights.
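With the distributed helpers and their imports gone, model loading reduces to a plain two-way dispatch. A self-contained sketch of the resulting control flow (BuilderArgs and the loader stubs below are illustrative stand-ins, not torchchat's real classes):

from dataclasses import dataclass
from typing import Optional


@dataclass
class BuilderArgs:  # illustrative stand-in for torchchat's BuilderArgs
    gguf_path: Optional[str] = None


def _load_model_gguf(builder_args: BuilderArgs) -> str:  # stub for illustration
    return "model loaded from GGUF"


def _load_model_default(builder_args: BuilderArgs) -> str:  # stub for illustration
    return "model loaded via default path"


def _load_model(builder_args: BuilderArgs) -> str:
    # After this commit there is no mesh or parallelism setup here:
    # GGUF checkpoints take their dedicated path, everything else the default one.
    if builder_args.gguf_path:
        return _load_model_gguf(builder_args)
    return _load_model_default(builder_args)


print(_load_model(BuilderArgs()))              # -> model loaded via default path
print(_load_model(BuilderArgs("llama.gguf")))  # -> model loaded from GGUF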

torchchat/cli/cli.py

Lines changed: 26 additions & 10 deletions
@@ -21,6 +21,8 @@
 logger = logging.getLogger(__name__)

 default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
+default_dtype = os.getenv("TORCHCHAT_PRECISION", "fast")
+
 default_model_dir = Path(
     os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
 ).expanduser()
@@ -149,9 +151,9 @@ def _add_model_config_args(parser, verb: str) -> None:

     model_config_parser.add_argument(
         "--dtype",
-        default="fast",
+        default=None,
         choices=allowable_dtype_names(),
-        help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32, fast16, fast",
+        help="Override the dtype of the model. Options: bf16, fp16, fp32, fast16, fast",
     )
     model_config_parser.add_argument(
         "--quantize",
@@ -165,9 +167,9 @@ def _add_model_config_args(parser, verb: str) -> None:
     model_config_parser.add_argument(
         "--device",
         type=str,
-        default=default_device,
+        default=None,
         choices=["fast", "cpu", "cuda", "mps"],
-        help="Hardware device to use. Options: cpu, cuda, mps",
+        help="Hardware device to use. Options: fast, cpu, cuda, mps",
     )


@@ -513,20 +515,34 @@ def arg_init(args):
     if isinstance(args.quantize, str):
         args.quantize = json.loads(args.quantize)

-    # if we specify dtype in quantization recipe, replicate it as args.dtype
-    args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype)
+    # if we specify dtype in quantization recipe, allow args.dtype top override if specified
+    if args.dtype is None:
+        args.dtype = args.quantize.get("precision", {}).get("dtype", default_dtype)
+    else:
+        precision_handler = args.quantize.get("precision", None)
+        if precision_handler:
+            if precision_handler["dtype"] != args.dtype:
+                print('overriding json-specified dtype {precision_handler["dtype"]} with cli dtype {args.dtype}')
+            precision_handler["dtype"] = args.dtype

     if getattr(args, "output_pte_path", None):
-        if args.device not in ["cpu", "fast"]:
+        if args.device not in [None, "cpu", "fast"]:
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
     else:
         # Localized import to minimize expensive imports
         from torchchat.utils.build_utils import get_device_str

-        args.device = get_device_str(
-            args.quantize.get("executor", {}).get("accelerator", args.device)
-        )
+        if args.device is None or args.device == "fast":
+            args.device = get_device_str(
+                args.quantize.get("executor", {}).get("accelerator", default_device)
+            )
+        else:
+            executor_handler = args.quantize.get("executor", None)
+            if executor_handler:
+                if executor_handler["accelerator"] != args.device:
+                    print('overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}')
+                executor_handler["accelerator"] = args.device

     if "mps" in args.device:
         if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):

torchchat/generate.py

Lines changed: 28 additions & 36 deletions
@@ -591,6 +591,7 @@ def generate(
             Dict[str, Any]
         ] = None,  # List of Image prompt tensors for multimodal models
         start_pos: int = 0,
+        skip_cache_setup: bool = False,
         draft_model: Model,
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
@@ -614,26 +615,27 @@
         max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
         # set up caches only if first inference
         if start_pos == 0:
-            model = model.to(device=device)
-            with torch.device(device):
-                if (
-                    self.is_torchtune_model
-                    or self.model.config.model_type == ModelType.Flamingo
-                ):
-                    # 6404 is one-gpu affordable max_seq_length for single image input
-                    model.setup_caches(
-                        batch_size=1,
-                        dtype=self.dtype,
-                        encoder_max_seq_len=6404,
-                        decoder_max_seq_len=max_seq_length,
-                    )
-                else:
-                    model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-                if is_speculative and draft_model is not model:
-                    draft_model.setup_caches(
-                        max_batch_size=1,
-                        max_seq_length=max_seq_length,
-                    )
+            if not skip_cache_setup:
+                model = model.to(device=device)
+                with torch.device(device):
+                    if (
+                        self.is_torchtune_model
+                        or self.model.config.model_type == ModelType.Flamingo
+                    ):
+                        # 6404 is one-gpu affordable max_seq_length for single image input
+                        model.setup_caches(
+                            batch_size=1,
+                            dtype=self.dtype,
+                            encoder_max_seq_len=6404,
+                            decoder_max_seq_len=max_seq_length,
+                        )
+                    else:
+                        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+                    if is_speculative and draft_model is not model:
+                        draft_model.setup_caches(
+                            max_batch_size=1,
+                            max_seq_length=max_seq_length,
+                        )
             if model.config.model_type == ModelType.Flamingo:
                 model.reset_caches()

@@ -915,13 +917,6 @@ def chat(
                 ]
             )
         if generator_args.compile:
-            if (
-                self.is_speculative and self.builder_args.use_distributed
-            ):  # and ("cuda" in builder_args.device):
-                torch._inductor.config.triton.cudagraph_trees = (
-                    False  # Bug with cudagraph trees in this case
-                )
-
             if self.builder_args.device == "cpu":
                 if generator_args.max_autotune:
                     kwargs = {"mode": "max-autotune"}
@@ -1020,6 +1015,7 @@ def chat(
         )
         for i in range(num_samples):
             device_sync(device=self.builder_args.device)
+            is_first_sample: bool = i == 0
             if generator_args.chat_mode:
                 prompt = input("User: ")
                 if prompt == "/bye":
@@ -1045,7 +1041,7 @@ def chat(
                         ]
                     )
                     self.system_prompt = None
-                elif i == 0:
+                elif is_first_sample:
                     encoded = self.chat_formatter.encode_dialog_prompt(
                         [{"role": "user", "content": prompt}]
                     )
@@ -1091,9 +1087,7 @@ def callback(x, *, done_generating=False):

                 torch._inductor.config.profiler_mark_wrapper_call = True
                 torch._inductor.config.cpp.enable_kernel_profile = True
-            if (i != generator_args.num_samples - 1 or not self.profile) or (
-                self.builder_args.use_distributed and self.rank != 0
-            ):
+            if i != generator_args.num_samples - 1 or not self.profile:
                 import contextlib

                 prof = contextlib.nullcontext()
@@ -1116,6 +1110,7 @@ def callback(x, *, done_generating=False):
                 top_k=generator_args.top_k,
                 sequential_prefill=generator_args.sequential_prefill,
                 start_pos=start_pos,
+                skip_cache_setup=not is_first_sample,
                 max_seq_length=max_seq_length,
             )
             for token_tensor, metrics in generator_func:
@@ -1125,7 +1120,7 @@ def callback(x, *, done_generating=False):
                 if metrics is not None:
                     aggregate_metrics.update(metrics)
                 yield token_tensor, metrics
-            jit_compile = (i == 0) and (
+            jit_compile = is_first_sample and (
                 generator_args.compile or generator_args.compile_prefill
             )
             compilation_time = time.perf_counter() - t0
@@ -1136,10 +1131,7 @@ def callback(x, *, done_generating=False):
                     print(prof.key_averages().table(sort_by="self_cpu_time_total"))
                 else:
                     print(prof.key_averages().table(sort_by="self_cuda_time_total"))
-                if self.builder_args.use_distributed:
-                    prof.export_chrome_trace(f"{self.profile}_rank_{self.rank}.json")
-                else:
-                    prof.export_chrome_trace(f"{self.profile}.json")
+                prof.export_chrome_trace(f"{self.profile}.json")

         if start_pos >= max_seq_length:
             print(
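The new skip_cache_setup flag lets the sampling loop build KV caches once for the first sample and reuse them on later samples, even though every call restarts at start_pos == 0. A simplified sketch of that calling pattern (the class below is an illustrative stand-in, not torchchat's generator):

class TinyGenerator:
    """Illustrative stand-in that mimics the cache-reuse pattern above."""

    def __init__(self):
        self.caches_ready = False

    def generate(self, prompt, start_pos=0, skip_cache_setup=False):
        if start_pos == 0 and not skip_cache_setup:
            # One-time expensive work: device moves, KV-cache allocation, ...
            self.caches_ready = True
            print("setting up caches")
        assert self.caches_ready, "caches must exist before generating"
        return f"response to {prompt!r}"


gen = TinyGenerator()
for i, prompt in enumerate(["hello", "again", "one more"]):
    is_first_sample = i == 0
    # Only the first sample pays the setup cost; later samples skip it.
    print(gen.generate(prompt, start_pos=0, skip_cache_setup=not is_first_sample))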

0 commit comments