
Commit 7047d79: Merge branch 'main' into skip-setup
Parents: 0163f61, 93f713f

5 files changed: +39, -98 lines


tokenizer/base64.h (1 addition, 0 deletions)

@@ -25,6 +25,7 @@
 #pragma once
 
 #include <cassert>
+#include <cstdint>
 #include <string>
 #include <string_view>
 

torchchat.py (9 additions, 1 deletion)

@@ -6,7 +6,7 @@
 
 import argparse
 import logging
-import subprocess
+import signal
 import sys
 
 # MPS ops missing with Multimodal torchtune
@@ -25,7 +25,15 @@
 default_device = "cpu"
 
 
+def signal_handler(sig, frame):
+    print("\nInterrupted by user. Bye!\n")
+    sys.exit(0)
+
+
 if __name__ == "__main__":
+    # Set the signal handler for SIGINT
+    signal.signal(signal.SIGINT, signal_handler)
+
     # Initialize the top-level parser
     parser = argparse.ArgumentParser(
         prog="torchchat",
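The torchchat.py change above installs a SIGINT handler so that Ctrl-C during an interactive run exits cleanly instead of surfacing a KeyboardInterrupt traceback. Below is a minimal, self-contained sketch of the same pattern; the sleep loop is an illustrative stand-in for torchchat's real argument parsing and dispatch, not code from the repository.

```python
import signal
import sys
import time


def signal_handler(sig, frame):
    # Exit quietly on Ctrl-C instead of printing a KeyboardInterrupt traceback.
    print("\nInterrupted by user. Bye!\n")
    sys.exit(0)


if __name__ == "__main__":
    # Register the handler before starting any long-running work.
    signal.signal(signal.SIGINT, signal_handler)
    while True:
        time.sleep(1)  # stand-in for the CLI's actual work loop
```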

torchchat/cli/builder.py (1 addition, 73 deletions)

@@ -16,12 +16,6 @@
 import torch._inductor.config
 import torch.nn as nn
 
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.elastic.utils.distributed import get_free_port
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torchchat.model import Model, ModelArgs, ModelType
 
 from torchchat.model_config.model_config import resolve_model_config
@@ -464,77 +458,11 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
     return model
 
 
-def _maybe_init_distributed(
-    builder_args: BuilderArgs,
-) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-    """
-    Initialize distributed related setups if the user specified
-    using distributed inference. If not, this is a no-op.
-
-    Args:
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-    Returns:
-        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-            - The first element is an optional DeviceMesh object,
-              which describes the mesh topology of devices for the DTensor.
-            - The second element is an optional ParallelDims object,
-              which represents the parallel dimensions configuration.
-    """
-    if not builder_args.use_distributed:
-        return None, None
-    dist_config = "llama3_8B.toml"  # TODO - integrate with chat cmd line
-
-    world_mesh, parallel_dims = launch_distributed(dist_config)
-
-    assert (
-        world_mesh is not None and parallel_dims is not None
-    ), f"failed to launch distributed using {dist_config}"
-
-    return world_mesh, parallel_dims
-
-
-def _maybe_parallelize_model(
-    model: nn.Module,
-    builder_args: BuilderArgs,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-) -> nn.Module:
-    """
-    We parallelize the module and load the distributed checkpoint to the model
-    if the user specifies using distributed inference. If not, this is a no-op.
-
-    Args:
-        model (:class:`nn.Module`):
-            Module to be parallelized.
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-        world_mesh (:class:`DeviceMesh`):
-            Object which describes the mesh topology
-            of devices for the DTensor.
-        parallel_dims (:class:`ParallelDims`):
-            Object which represents the parallel dimensions configuration.
-    Returns:
-        A :class:`nn.Module` object which is parallelized and checkpoint loaded
-        if the user specifies using distributed inference.
-    """
-    if world_mesh is None:
-        return model
-    assert parallel_dims is not None
-    print("Applying model parallel to model ...")
-    parallelize_llama(model, world_mesh, parallel_dims)
-    return load_checkpoints_to_model(model, builder_args, world_mesh)
-
-
 def _load_model(builder_args: BuilderArgs) -> Model:
-    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    # elif builder_args.use_distributed:
-    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-    # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
     if builder_args.dso_path or builder_args.aoti_package_path:
         # AOTI-compoiled model will load its own weights.
@@ -706,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
+    return "SentencePiece"

torchchat/cli/cli.py (26 additions, 10 deletions)

@@ -21,6 +21,8 @@
 logger = logging.getLogger(__name__)
 
 default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
+default_dtype = os.getenv("TORCHCHAT_PRECISION", "fast")
+
 default_model_dir = Path(
     os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
 ).expanduser()
@@ -149,9 +151,9 @@ def _add_model_config_args(parser, verb: str) -> None:
 
     model_config_parser.add_argument(
         "--dtype",
-        default="fast",
+        default=None,
         choices=allowable_dtype_names(),
-        help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32, fast16, fast",
+        help="Override the dtype of the model. Options: bf16, fp16, fp32, fast16, fast",
     )
     model_config_parser.add_argument(
         "--quantize",
@@ -165,9 +167,9 @@ def _add_model_config_args(parser, verb: str) -> None:
     model_config_parser.add_argument(
         "--device",
         type=str,
-        default=default_device,
+        default=None,
         choices=["fast", "cpu", "cuda", "mps"],
-        help="Hardware device to use. Options: cpu, cuda, mps",
+        help="Hardware device to use. Options: fast, cpu, cuda, mps",
     )
 
 
@@ -513,20 +515,34 @@ def arg_init(args):
     if isinstance(args.quantize, str):
         args.quantize = json.loads(args.quantize)
 
-    # if we specify dtype in quantization recipe, replicate it as args.dtype
-    args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype)
+    # if we specify dtype in the quantization recipe, allow args.dtype to override it if specified
+    if args.dtype is None:
+        args.dtype = args.quantize.get("precision", {}).get("dtype", default_dtype)
+    else:
+        precision_handler = args.quantize.get("precision", None)
+        if precision_handler:
+            if precision_handler["dtype"] != args.dtype:
+                print(f'overriding json-specified dtype {precision_handler["dtype"]} with cli dtype {args.dtype}')
+                precision_handler["dtype"] = args.dtype
 
     if getattr(args, "output_pte_path", None):
-        if args.device not in ["cpu", "fast"]:
+        if args.device not in [None, "cpu", "fast"]:
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
     else:
         # Localized import to minimize expensive imports
         from torchchat.utils.build_utils import get_device_str
 
-        args.device = get_device_str(
-            args.quantize.get("executor", {}).get("accelerator", args.device)
-        )
+        if args.device is None:
+            args.device = get_device_str(
+                args.quantize.get("executor", {}).get("accelerator", default_device)
+            )
+        else:
+            executor_handler = args.quantize.get("executor", None)
+            if executor_handler:
+                if executor_handler["accelerator"] != args.device:
+                    print(f'overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}')
+                    executor_handler["accelerator"] = args.device
 
     if "mps" in args.device:
         if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):
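The cli.py change above switches --dtype and --device to default to None so that an explicit CLI flag, a value in the --quantize JSON recipe, and the TORCHCHAT_PRECISION / TORCHCHAT_DEVICE environment defaults can be resolved in that order of precedence. The standalone sketch below illustrates that resolution rule; resolve_setting is a hypothetical helper written for this note, not a torchchat function.

```python
import os


def resolve_setting(cli_value, quantize_cfg, section, key, env_var):
    """Illustrative precedence: explicit CLI flag > quantize-recipe JSON > environment default."""
    env_default = os.getenv(env_var, "fast")
    section_cfg = quantize_cfg.get(section, {})
    if cli_value is None:
        # No CLI override: fall back to the recipe value, then the env default.
        return section_cfg.get(key, env_default)
    if section_cfg.get(key) not in (None, cli_value):
        # The CLI flag wins; note the override and keep the recipe consistent.
        print(f"overriding json-specified {key} {section_cfg[key]} with cli {key} {cli_value}")
        section_cfg[key] = cli_value
    return cli_value


# Example: the recipe asks for bf16, but --dtype fp16 was passed, so fp16 wins.
quantize = {"precision": {"dtype": "bf16"}}
print(resolve_setting("fp16", quantize, "precision", "dtype", "TORCHCHAT_PRECISION"))
```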

torchchat/generate.py (2 additions, 14 deletions)

@@ -917,13 +917,6 @@ def chat(
             ]
         )
         if generator_args.compile:
-            if (
-                self.is_speculative and self.builder_args.use_distributed
-            ):  # and ("cuda" in builder_args.device):
-                torch._inductor.config.triton.cudagraph_trees = (
-                    False  # Bug with cudagraph trees in this case
-                )
-
             if self.builder_args.device == "cpu":
                 if generator_args.max_autotune:
                     kwargs = {"mode": "max-autotune"}
@@ -1094,9 +1087,7 @@ def callback(x, *, done_generating=False):
 
             torch._inductor.config.profiler_mark_wrapper_call = True
             torch._inductor.config.cpp.enable_kernel_profile = True
-            if (i != generator_args.num_samples - 1 or not self.profile) or (
-                self.builder_args.use_distributed and self.rank != 0
-            ):
+            if i != generator_args.num_samples - 1 or not self.profile:
                 import contextlib
 
                 prof = contextlib.nullcontext()
@@ -1140,10 +1131,7 @@ def callback(x, *, done_generating=False):
                 print(prof.key_averages().table(sort_by="self_cpu_time_total"))
             else:
                 print(prof.key_averages().table(sort_by="self_cuda_time_total"))
-            if self.builder_args.use_distributed:
-                prof.export_chrome_trace(f"{self.profile}_rank_{self.rank}.json")
-            else:
-                prof.export_chrome_trace(f"{self.profile}.json")
+            prof.export_chrome_trace(f"{self.profile}.json")
 
         if start_pos >= max_seq_length:
             print(
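With the distributed-rank branches removed, the remaining profiling logic in generate.py amounts to choosing either a real profiler or a no-op context per sample and exporting a single Chrome trace. A hedged sketch of that pattern follows; run_samples and the matmul are illustrative stand-ins written for this note, not torchchat code.

```python
import contextlib

import torch


def run_samples(num_samples, profile_path=None):
    for i in range(num_samples):
        # Profile only the last sample, and only when a profile path was requested;
        # otherwise use a no-op context so the loop body stays identical.
        if i != num_samples - 1 or not profile_path:
            prof = contextlib.nullcontext()
        else:
            prof = torch.profiler.profile()
        with prof:
            torch.randn(64, 64) @ torch.randn(64, 64)  # stand-in for token generation
        if hasattr(prof, "export_chrome_trace"):
            prof.export_chrome_trace(f"{profile_path}.json")


run_samples(num_samples=3, profile_path=None)
```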
