This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 5113018

Merge branch 'main' of https://github.com/pytorch/torchchat into infil00p/missing_include

2 parents 0f7f442 + b809b69

File tree

12 files changed: +222 -156 lines

.ci/scripts/run-docs

Lines changed: 20 additions & 0 deletions
@@ -91,3 +91,23 @@ if [ "$1" == "evaluation" ]; then
   echo "*******************************************"
   bash -x ./run-evaluation.sh
 fi
+
+if [ "$1" == "multimodal" ]; then
+
+  # Expecting that this might fail this test as-is, because
+  # it's the first on-pr test depending on githib secrets for access with HF token access
+
+  echo "::group::Create script to run multimodal"
+  python3 torchchat/utils/scripts/updown.py --file docs/multimodal.md > ./run-multimodal.sh
+  # for good measure, if something happened to updown processor,
+  # and it did not error out, fail with an exit 1
+  echo "exit 1" >> ./run-multimodal.sh
+  echo "::endgroup::"
+
+  echo "::group::Run multimodal"
+  echo "*******************************************"
+  cat ./run-multimodal.sh
+  echo "*******************************************"
+  bash -x ./run-multimodal.sh
+  echo "::endgroup::"
+fi
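The multimodal branch added above follows the same generate-then-guard pattern as the other doc tests: updown.py turns docs/multimodal.md into a shell script, and an unconditional `exit 1` is appended so that a generated script which runs off the end without exiting on its own fails the CI step instead of passing silently. Below is a rough Python sketch of that pattern only; the `extract_commands` helper is hypothetical and stands in for the real updown.py processor.

```python
import re
from pathlib import Path


def extract_commands(markdown_text: str) -> list[str]:
    # Hypothetical stand-in for updown.py: pull command lines out of fenced
    # code blocks in a markdown file. Not the real implementation.
    blocks = re.findall(r"```(?:bash|shell)?\n(.*?)```", markdown_text, re.DOTALL)
    return [line for block in blocks for line in block.splitlines() if line.strip()]


def write_runner(doc: str, out: str) -> None:
    commands = extract_commands(Path(doc).read_text())
    script = "\n".join(commands)
    # Trailing guard, mirroring run-docs: if the generated script falls
    # through to the end without its own exit, the CI step fails loudly.
    Path(out).write_text(script + "\nexit 1\n")


if __name__ == "__main__":
    write_runner("docs/multimodal.md", "run-multimodal.sh")
```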

.github/workflows/run-readme-pr.yml

Lines changed: 44 additions & 1 deletion
@@ -243,4 +243,47 @@ jobs:
         echo "::group::Completion"
         echo "tests complete"
         echo "*******************************************"
-        echo "::endgroup::"
+        echo "::endgroup::"
+
+  test-multimodal-any:
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    with:
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.1"
+      timeout: 60
+      script: |
+        echo "::group::Print machine info"
+        uname -a
+        echo "::endgroup::"
+
+        echo "::group::Install newer objcopy that supports --set-section-alignment"
+        yum install -y devtoolset-10-binutils
+        export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
+        echo "::endgroup::"
+
+        .ci/scripts/run-docs multimodal
+
+        echo "::group::Completion"
+        echo "tests complete"
+        echo "*******************************************"
+        echo "::endgroup::"
+
+  test-multimodal-cpu:
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    with:
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.1"
+      timeout: 60
+      script: |
+        echo "::group::Print machine info"
+        uname -a
+        echo "::endgroup::"
+
+        echo "::group::Install newer objcopy that supports --set-section-alignment"
+        yum install -y devtoolset-10-binutils
+        export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
+        echo "::endgroup::"
+
+        TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs multimodal

docs/multimodal.md

Lines changed: 6 additions & 0 deletions
@@ -14,9 +14,11 @@ This page goes over the different commands you can run with LLama 3.2 11B Vision
 
 While we strongly encourage you to use the Hugging Face checkpoint (which is the default for torchchat when utilizing the commands with the argument `llama3.2-11B`), we also provide support for manually providing the checkpoint. This can be done by replacing the `llama3.2-11B` argument in the commands below with the following:
 
+[skip default]: begin
 ```
 --checkpoint-path <file.pth> --tokenizer-path <tokenizer.model> --params-path torchchat/model_params/Llama-3.2-11B-Vision.json
 ```
+[skip default]: end
 
 ## Generation
 This generates text output based on a text prompt and (optional) image prompt.
@@ -48,6 +50,7 @@ Setting `stream` to "true" in the request emits a response in chunks. If `stream
 
 **Example Input + Output**
 
+[skip default]: begin
 ```
 curl http://127.0.0.1:5000/v1/chat/completions \
   -H "Content-Type: application/json" \
@@ -75,6 +78,7 @@ curl http://127.0.0.1:5000/v1/chat/completions \
 ```
 {"id": "chatcmpl-cb7b39af-a22e-4f71-94a8-17753fa0d00c", "choices": [{"message": {"role": "assistant", "content": "The image depicts a simple black and white cartoon-style drawing of an animal face. It features a profile view, complete with two ears, expressive eyes, and a partial snout. The animal looks to the left, with its eye and mouth implied, suggesting that the drawn face might belong to a rabbit, dog, or pig. The graphic face has a bold black outline and a smaller, solid black nose. A small circle, forming part of the face, has a white background with two black quirkly short and long curved lines forming an outline of what was likely a mouth, complete with two teeth. The presence of the curve lines give the impression that the animal is smiling or speaking. Grey and black shadows behind the right ear and mouth suggest that this face is looking left and upwards. Given the prominent outline of the head and the outline of the nose, it appears that the depicted face is most likely from the side profile of a pig, although the ears make it seem like a dog and the shape of the nose makes it seem like a rabbit. Overall, it seems that this image, possibly part of a character illustration, is conveying a playful or expressive mood through its design and positioning."}, "finish_reason": "stop"}], "created": 1727487574, "model": "llama3.2", "system_fingerprint": "cpu_torch.float16", "object": "chat.completion"}%
 ```
+[skip default]: end
 
 </details>
 
@@ -90,6 +94,8 @@ First, follow the steps in the Server section above to start a local server. The
 streamlit run torchchat/usages/browser.py
 ```
 
+[skip default]: end
+
 ---
 
 # Future Work
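For readers who prefer Python over curl, the request shown in the doc above can be reproduced with the standard library alone. This is only a sketch: it assumes the torchchat server is already running locally on port 5000 and accepts the OpenAI-style payload shown in the doc; the model name and prompt are illustrative.

```python
import json
import urllib.request

# Illustrative payload; adjust the model name and messages to your setup.
payload = {
    "model": "llama3.2",
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "stream": False,
}
request = urllib.request.Request(
    "http://127.0.0.1:5000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    reply = json.loads(response.read())
print(reply["choices"][0]["message"]["content"])
```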

install/install_requirements.sh

Lines changed: 23 additions & 8 deletions
@@ -9,26 +9,41 @@ set -eou pipefail
 
 # Install required python dependencies for developing
 # Dependencies are defined in .pyproject.toml
-PYTHON_EXECUTABLE=${PYTHON_EXECUTABLE:-python}
-if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]];
+if [ -z "${PYTHON_EXECUTABLE:-}" ];
 then
-  PYTHON_EXECUTABLE=python3
+  if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]];
+  then
+    PYTHON_EXECUTABLE=python3
+  else
+    PYTHON_EXECUTABLE=python
+  fi
 fi
-
-# Check python version. Expect 3.10.x or 3.11.x
-printf "import sys\nif sys.version_info.major != 3 or sys.version_info.minor < 10 :\n\tprint('Please use Python >=3.10');sys.exit(1)\n" | $PYTHON_EXECUTABLE
-if [[ $? -ne 0 ]]
+echo "Using python executable: $PYTHON_EXECUTABLE"
+
+PYTHON_SYS_VERSION="$($PYTHON_EXECUTABLE -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")"
+# Check python version. Expect at least 3.10.x
+if ! $PYTHON_EXECUTABLE -c "
+import sys
+if sys.version_info < (3, 10):
+    sys.exit(1)
+";
 then
+  echo "Python version must be at least 3.10.x. Detected version: $PYTHON_SYS_VERSION"
   exit 1
 fi
 
 if [[ "$PYTHON_EXECUTABLE" == "python" ]];
 then
   PIP_EXECUTABLE=pip
-else
+elif [[ "$PYTHON_EXECUTABLE" == "python3" ]];
+then
   PIP_EXECUTABLE=pip3
+else
+  PIP_EXECUTABLE=pip${PYTHON_SYS_VERSION}
 fi
 
+echo "Using pip executable: $PIP_EXECUTABLE"
+
 #
 # First install requirements in install/requirements.txt. Older torch may be
 # installed from the dependency of other models. It will be overridden by
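The rewritten version gate calls the interpreter directly with `-c` and branches on its exit status, rather than piping a printf'd snippet into the interpreter. A standalone sketch of the equivalent check, standard library only:

```python
import sys

# Mirrors the inline check in install_requirements.sh: require Python >= 3.10.
detected = f"{sys.version_info.major}.{sys.version_info.minor}"
if sys.version_info < (3, 10):
    print(f"Python version must be at least 3.10.x. Detected version: {detected}")
    sys.exit(1)
print(f"Python {detected} is new enough")
```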

install/requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@ snakeviz
 sentencepiece
 # numpy version range required by GGUF util
 numpy >= 1.17, < 2.0
-gguf
 blobfile
 tomli >= 1.1.0 ; python_version < "3.11"
 openai

tokenizer/base64.h

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 #pragma once
 
 #include <cassert>
+#include <cstdint>
 #include <string>
 #include <string_view>
 #include <cstdint>

torchchat.py

Lines changed: 9 additions & 1 deletion
@@ -6,7 +6,7 @@
 
 import argparse
 import logging
-import subprocess
+import signal
 import sys
 
 # MPS ops missing with Multimodal torchtune
@@ -25,7 +25,15 @@
 default_device = "cpu"
 
 
+def signal_handler(sig, frame):
+    print("\nInterrupted by user. Bye!\n")
+    sys.exit(0)
+
+
 if __name__ == "__main__":
+    # Set the signal handler for SIGINT
+    signal.signal(signal.SIGINT, signal_handler)
+
     # Initialize the top-level parser
     parser = argparse.ArgumentParser(
         prog="torchchat",
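The new handler swaps the default KeyboardInterrupt traceback for a clean exit on Ctrl-C. A self-contained sketch of the same pattern; the idle loop exists only to give you something to interrupt.

```python
import signal
import sys
import time


def signal_handler(sig, frame):
    # Same pattern as the handler added in torchchat.py: exit quietly on Ctrl-C.
    print("\nInterrupted by user. Bye!\n")
    sys.exit(0)


if __name__ == "__main__":
    signal.signal(signal.SIGINT, signal_handler)
    while True:
        time.sleep(1)  # press Ctrl-C to trigger the handler
```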

torchchat/cli/builder.py

Lines changed: 3 additions & 75 deletions
@@ -16,12 +16,6 @@
 import torch._inductor.config
 import torch.nn as nn
 
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.elastic.utils.distributed import get_free_port
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torchchat.model import Model, ModelArgs, ModelType
 
 from torchchat.model_config.model_config import resolve_model_config
@@ -80,7 +74,7 @@ def __post_init__(self):
             or (self.pte_path and Path(self.pte_path).is_file())
         ):
             raise RuntimeError(
-                "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
+                "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, AOTI PACKAGE or PTE path"
             )
 
         if self.aoti_package_path and self.pte_path:
@@ -97,7 +91,7 @@ def __post_init__(self):
             for param, param_msg in ignored_params:
                 if param:
                     print(
-                        f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified"
+                        f"Warning: {param_msg} ignored because an exported model was specified using a DSO, AOTI PACKAGE or PTE path argument"
                     )
         else:
             self.prefill_possible = True
@@ -464,77 +458,11 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
     return model
 
 
-def _maybe_init_distributed(
-    builder_args: BuilderArgs,
-) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-    """
-    Initialize distributed related setups if the user specified
-    using distributed inference. If not, this is a no-op.
-
-    Args:
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-    Returns:
-        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-            - The first element is an optional DeviceMesh object,
-              which which describes the mesh topology of devices for the DTensor.
-            - The second element is an optional ParallelDims object,
-              which represents the parallel dimensions configuration.
-    """
-    if not builder_args.use_distributed:
-        return None, None
-    dist_config = "llama3_8B.toml"  # TODO - integrate with chat cmd line
-
-    world_mesh, parallel_dims = launch_distributed(dist_config)
-
-    assert (
-        world_mesh is not None and parallel_dims is not None
-    ), f"failed to launch distributed using {dist_config}"
-
-    return world_mesh, parallel_dims
-
-
-def _maybe_parallelize_model(
-    model: nn.Module,
-    builder_args: BuilderArgs,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-) -> nn.Module:
-    """
-    We parallelize the module and load the distributed checkpoint to the model
-    if the user specifies using distributed inference. If not, this is a no-op.
-
-    Args:
-        model (:class:`nn.Module`):
-            Module to be parallelized.
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-        world_mesh (:class:`DeviceMesh`):
-            Object which describes the mesh topology
-            of devices for the DTensor.
-        parallel_dims (:class:`ParallelDims`):
-            Object which represents the parallel dimensions configuration.
-    Returns:
-        A :class:`nn.Module` object which is parallelized and checkpoint loaded
-        if the user specifies using distributed inference.
-    """
-    if world_mesh is None:
-        return model
-    assert parallel_dims is not None
-    print("Applying model parallel to model ...")
-    parallelize_llama(model, world_mesh, parallel_dims)
-    return load_checkpoints_to_model(model, builder_args, world_mesh)
-
-
 def _load_model(builder_args: BuilderArgs) -> Model:
-    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    # elif builder_args.use_distributed:
-    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-    # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
     if builder_args.dso_path or builder_args.aoti_package_path:
         # AOTI-compoiled model will load its own weights.
@@ -706,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
+    return "SentencePiece"

torchchat/cli/cli.py

Lines changed: 26 additions & 10 deletions
@@ -21,6 +21,8 @@
 logger = logging.getLogger(__name__)
 
 default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
+default_dtype = os.getenv("TORCHCHAT_PRECISION", "fast")
+
 default_model_dir = Path(
     os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
 ).expanduser()
@@ -149,9 +151,9 @@ def _add_model_config_args(parser, verb: str) -> None:
 
     model_config_parser.add_argument(
         "--dtype",
-        default="fast",
+        default=None,
         choices=allowable_dtype_names(),
-        help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32, fast16, fast",
+        help="Override the dtype of the model. Options: bf16, fp16, fp32, fast16, fast",
     )
     model_config_parser.add_argument(
         "--quantize",
@@ -165,9 +167,9 @@ def _add_model_config_args(parser, verb: str) -> None:
     model_config_parser.add_argument(
         "--device",
         type=str,
-        default=default_device,
+        default=None,
         choices=["fast", "cpu", "cuda", "mps"],
-        help="Hardware device to use. Options: cpu, cuda, mps",
+        help="Hardware device to use. Options: fast, cpu, cuda, mps",
     )
 
 
@@ -513,20 +515,34 @@ def arg_init(args):
     if isinstance(args.quantize, str):
         args.quantize = json.loads(args.quantize)
 
-    # if we specify dtype in quantization recipe, replicate it as args.dtype
-    args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype)
+    # if we specify dtype in quantization recipe, allow args.dtype top override if specified
+    if args.dtype is None:
+        args.dtype = args.quantize.get("precision", {}).get("dtype", default_dtype)
+    else:
+        precision_handler = args.quantize.get("precision", None)
+        if precision_handler:
+            if precision_handler["dtype"] != args.dtype:
+                print('overriding json-specified dtype {precision_handler["dtype"]} with cli dtype {args.dtype}')
+                precision_handler["dtype"] = args.dtype
 
     if getattr(args, "output_pte_path", None):
-        if args.device not in ["cpu", "fast"]:
+        if args.device not in [None, "cpu", "fast"]:
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
     else:
         # Localized import to minimize expensive imports
         from torchchat.utils.build_utils import get_device_str
 
-        args.device = get_device_str(
-            args.quantize.get("executor", {}).get("accelerator", args.device)
-        )
+        if args.device is None:
+            args.device = get_device_str(
+                args.quantize.get("executor", {}).get("accelerator", default_device)
+            )
+        else:
+            args.device = get_device_str(args.device)
+            executor_handler = args.quantize.get("executor", None)
+            if executor_handler and executor_handler["accelerator"] != args.device:
+                print(f'overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}')
+                executor_handler["accelerator"] = args.device
 
     if "mps" in args.device:
         if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):
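With `--dtype` and `--device` now defaulting to `None`, `arg_init` can distinguish "the user passed a flag" from "use the default" and resolve each setting in a fixed order: explicit CLI flag first, then the value in the quantize recipe, then the `TORCHCHAT_PRECISION` / `TORCHCHAT_DEVICE` environment defaults (the recipe is also rewritten to match an explicit flag). A distilled sketch of that precedence, not the actual torchchat code:

```python
import os

# Environment-backed defaults, as in cli.py.
default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
default_dtype = os.getenv("TORCHCHAT_PRECISION", "fast")


def resolve(cli_value, recipe_value, env_default):
    # CLI flag wins; otherwise fall back to the quantize recipe;
    # otherwise use the environment-backed default.
    if cli_value is not None:
        return cli_value
    if recipe_value is not None:
        return recipe_value
    return env_default


quantize = {"precision": {"dtype": "bf16"}}  # hypothetical recipe
print(resolve(None, quantize["precision"].get("dtype"), default_dtype))    # -> bf16
print(resolve("fp16", quantize["precision"].get("dtype"), default_dtype))  # -> fp16
```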
