This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit f7aef8a

🦙🦙 Llama3.2 Release: Support for 1B, 3B, and 11B (Multimodal) (#1204)

1 parent 0b8ca05 commit f7aef8a

File tree: 14 files changed, +272 −21 lines

README.md
Lines changed: 14 additions & 0 deletions

```diff
@@ -2,6 +2,12 @@
 
 torchchat is a small codebase showcasing the ability to run large language models (LLMs) seamlessly. With torchchat, you can run LLMs using Python, within your own (C/C++) application (desktop or server) and on iOS and Android.
 
+> [!IMPORTANT]
+> Update September 25, 2024: torchchat has multimodal support for **Llama3.2 11B**!!
+>
+> To try it out, finish the [Installation](#Installation) section below, then hop
+> over to our [multimodal guide](docs/multimodal.md) to learn more.
+
 
 ## What can you do with torchchat?
 - [Run models via PyTorch / Python](#running-via-pytorch--python)
@@ -18,6 +24,7 @@ torchchat is a small codebase showcasing the ability to run large language model
 
 
 ## Highlights
+
 - Command line interaction with popular LLMs such as Llama 3, Llama 2, Stories, Mistral and more
 - PyTorch-native execution with performance
 - Supports popular hardware and OS
@@ -514,6 +521,13 @@ aliases.
 
 | Model | Mobile Friendly | Notes |
 |------------------|---|---------------------|
+|[meta-llama/Meta-Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct)||Tuned for `chat`. Alias to `llama3.2-3b`.|
+|[meta-llama/Meta-Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)||Best for `generate`. Alias to `llama3.2-3b-base`.|
+|[meta-llama/Llama-Guard-3-1B](https://huggingface.co/meta-llama/Llama-Guard-3-1B)||Tuned for classification. Alias to `llama3-1b-guard`.|
+|[meta-llama/Meta-Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct)||Tuned for `chat`. Alias to `llama3.2-1b`.|
+|[meta-llama/Meta-Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)||Best for `generate`. Alias to `llama3.2-1b-base`.|
+|[meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)||Multimodal (Image + Text). Tuned for `chat`. Alias to `llama3.2-11B`.|
+|[meta-llama/Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision)||Multimodal (Image + Text). Tuned for `generate`. Alias to `llama3.2-11B-base`.|
 |[meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)||Tuned for `chat`. Alias to `llama3.1`.|
 |[meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B)||Best for `generate`. Alias to `llama3.1-base`.|
 |[meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)||Tuned for `chat`. Alias to `llama3`.|
```
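The aliases in the rightmost column can stand in for the full model name in any torchchat command. As an illustrative example (assuming you have accepted the model license and authenticated with Hugging Face):

```bash
# Download the 1B instruct checkpoint by alias, then chat with it.
python3 torchchat.py download llama3.2-1b
python3 torchchat.py chat llama3.2-1b
```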

assets/dog.jpg

43.6 KB (new binary file)

docs/multimodal.md
Lines changed: 73 additions & 0 deletions (new file)

# Multimodal Models

Released on September 25th, 2024, **Llama3.2 11B Vision** is torchchat's first multimodal model.

This page goes over the different commands you can run with Llama 3.2 11B Vision.

## Model Access

> [!NOTE]
> While the commands refer to the model as some variant of "Llama 3.2 11B Vision",
> the underlying checkpoint used is based on the "Instruct" variant of the model.

**Llama3.2 11B Vision** is available via both [Hugging Face](https://huggingface.co/meta-llama) and [directly from Meta](https://www.llama.com/).

While we strongly encourage you to use the Hugging Face checkpoint (the default for torchchat commands that take the `llama3.2-11B` argument), we also support manually supplied checkpoints. To use one, replace the `llama3.2-11B` argument in the commands below with the following:

```
--checkpoint-path <file.pth> --tokenizer-path <tokenizer.model> --params-path torchchat/model_params/Llama-3.2-11B-Vision.json
```
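Putting that together, a full `generate` invocation with a manually supplied checkpoint might look like the sketch below; the `.pth` and tokenizer paths are placeholders for wherever your Meta-format files live:

```bash
python3 torchchat.py generate \
  --checkpoint-path model.pth \
  --tokenizer-path tokenizer.model \
  --params-path torchchat/model_params/Llama-3.2-11B-Vision.json \
  --prompt "What's in this image?" --image-prompt assets/dog.jpg
```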
## Generation

**We are currently debugging multimodal inference on MPS and will have updates soon. In the meantime, when testing on Mac, please set `--device cpu`.**

This generates text output based on a text prompt and an (optional) image prompt.

```
python torchchat.py generate llama3.2-11B --prompt "What's in this image?" --image-prompt assets/dog.jpg
```
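On a Mac, that means appending the device flag to the same command:

```bash
python torchchat.py generate llama3.2-11B --prompt "What's in this image?" --image-prompt assets/dog.jpg --device cpu
```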
## Server

This mode exposes a REST API for interacting with a model.
The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.

To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.

In one terminal, start the server:

[skip default]: begin

```bash
python3 torchchat.py server llama3.2-11B
```
[skip default]: end

In another terminal, query the server using `curl`. This query might take a few minutes to respond.

**We are currently debugging the server integration and will have updated examples shortly.**
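In the meantime, a request shaped like OpenAI's chat-completions examples conveys the idea. The sketch below assumes the server's default local port and an OpenAI-style route, and `<BASE64_IMAGE>` is a placeholder for a base64-encoded image; adjust all three to your setup:

```bash
curl http://127.0.0.1:5000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llama3.2-11B",
    "max_tokens": 300,
    "messages": [
      {
        "role": "user",
        "content": [
          {"type": "text", "text": "What is in this image?"},
          {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<BASE64_IMAGE>"}}
        ]
      }
    ]
  }'
```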
## Browser

This command opens a basic browser interface for local chat by querying a local server.

First, follow the steps in the Server section above to start a local server. Then, in another terminal, launch the interface. Running the following will open a tab in your browser:

[skip default]: begin

```
streamlit run torchchat/usages/browser.py
```

**We are currently debugging the browser integration and will have updated examples shortly.**
---

# Future Work

One of the goals of torchchat is to support various execution modes for every model. The following execution modes will be supported for **Llama3.2 11B Vision** in the near future:

- **[torch.compile](https://pytorch.org/docs/stable/torch.compiler.html)**: Optimize inference via JIT compilation
- **[AOTI](https://pytorch.org/blog/pytorch2-2/)**: Enable pre-compiled and C++ inference
- **[ExecuTorch](https://github.com/pytorch/executorch)**: On-device (edge) inference

In addition, we are in the process of integrating with [lm_evaluation_harness](https://github.com/EleutherAI/lm-evaluation-harness) for multimodal model evaluation.

torchchat/cli/builder.py
Lines changed: 2 additions & 5 deletions

```diff
@@ -16,10 +16,7 @@
 import torch._inductor.config
 import torch.nn as nn
 
-try:
-    from _torchchat_test_script import flamingo_meta_to_tune
-except ImportError:
-    pass
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
 
 from distributed import launch_distributed, ParallelDims, parallelize_llama
 
@@ -404,7 +401,7 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         for submodule in model.modules():
             if isinstance(submodule, Llama3ScaledRoPE):
                 submodule.__init__(head_dim, max_seq_len, rope_base)
-        state_dict = flamingo_meta_to_tune(checkpoint)
+        state_dict = llama3_vision_meta_to_tune(checkpoint)
         model.model.load_state_dict(state_dict, assign=True, strict=False)
     else:
         checkpoint = {"model." + k: v for k, v in checkpoint.items()}
```

torchchat/cli/convert_hf_checkpoint.py
Lines changed: 22 additions & 0 deletions

```diff
@@ -34,6 +34,8 @@ def convert_hf_checkpoint(
     if model_name is None:
         model_name = model_dir.name
 
+    # TODO: This is an incongruent way of resolving config_args
+    # See https://github.com/pytorch/torchchat/issues/1179
     config_args = ModelArgs.from_name(model_name).transformer_args['text']
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")
@@ -132,6 +134,26 @@ def permute(w, n_heads):
         os.remove(file)
 
 
+@torch.inference_mode()
+def convert_hf_checkpoint_to_tune(
+    *,
+    model_dir: Optional[Path] = None,
+    model_name: str,
+) -> None:
+    assert model_dir is not None
+
+    consolidated_pth = model_dir / "original" / "consolidated.pth"
+    tokenizer_pth = model_dir / "original" / "tokenizer.model"
+    if consolidated_pth.is_file() and tokenizer_pth.is_file():
+        print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
+        os.rename(consolidated_pth, model_dir / "model.pth")
+        print(f"Moving tokenizer to {model_dir / 'tokenizer.model'}.")
+        os.rename(tokenizer_pth, model_dir / "tokenizer.model")
+        print("Done.")
+    else:
+        raise RuntimeError(f"Could not find {consolidated_pth}")
+
+
 if __name__ == "__main__":
     import argparse
 
```
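The new helper rewrites no tensors; it only renames the Meta-format artifacts into the layout the rest of torchchat expects. A hypothetical direct invocation (the cache directory shown is illustrative):

```python
from pathlib import Path

from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint_to_tune

# Expects <model_dir>/original/consolidated.pth and
# <model_dir>/original/tokenizer.model to exist.
convert_hf_checkpoint_to_tune(
    model_dir=Path("~/.torchchat/model-cache/meta-llama/Llama-3.2-11B-Vision-Instruct").expanduser(),
    model_name="meta-llama/Llama-3.2-11B-Vision-Instruct",
)
```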

torchchat/cli/download.py
Lines changed: 12 additions & 6 deletions

```diff
@@ -10,7 +10,7 @@
 from pathlib import Path
 from typing import Optional
 
-from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint
+from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint, convert_hf_checkpoint_to_tune
 from torchchat.model_config.model_config import (
     load_model_configs,
     ModelConfig,
@@ -50,11 +50,17 @@ def _download_hf_snapshot(
         else:
             raise e
 
-    # Convert the model to the torchchat format.
-    print(f"Converting {model_config.name} to torchchat format...", file=sys.stderr)
-    convert_hf_checkpoint(
-        model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True
-    )
+    # Convert the Multimodal Llama model to the torchtune format.
+    if model_config.name in {"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision"}:
+        print(f"Converting {model_config.name} to torchtune format...", file=sys.stderr)
+        convert_hf_checkpoint_to_tune(model_dir=artifact_dir, model_name=model_config.name)
+
+    else:
+        # Convert the model to the torchchat format.
+        print(f"Converting {model_config.name} to torchchat format...", file=sys.stderr)
+        convert_hf_checkpoint(
+            model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True
+        )
 
 
 def _download_direct(
```
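In practice this branch fires when either vision model is fetched, for example:

```bash
# Downloads the snapshot, then moves it into the torchtune layout.
python3 torchchat.py download llama3.2-11B
```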

torchchat/generate.py
Lines changed: 2 additions & 5 deletions

```diff
@@ -20,10 +20,7 @@
 import torch._dynamo.config
 import torch._inductor.config
 
-try:
-    from _torchchat_test_script import flamingo_transform
-except ImportError:
-    pass
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
 
 from PIL import Image
 
@@ -753,7 +750,7 @@ def chat(
                 Message(role="assistant", content=""),
             ]
 
-            transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path))
+            transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
 
             with torch.device(device=self.builder_args.device), set_default_dtype(self.dtype):
                 data = transform({"messages": messages}, inference=True)
```
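For reference, torchtune's transform turns a list of `Message` objects into model-ready inputs. A minimal sketch, assuming torchtune is installed and a tokenizer file is at hand (the path is illustrative):

```python
from torchtune.data import Message
from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform

# Build the combined tokenizer/image transform from a tokenizer file.
transform = llama3_2_vision_transform("tokenizer.model")

messages = [
    Message(role="user", content="What's in this image?"),
    Message(role="assistant", content=""),
]
data = transform({"messages": messages}, inference=True)
```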

torchchat/model_config/models.json
Lines changed: 38 additions & 0 deletions

```diff
@@ -69,6 +69,44 @@
     "distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct",
     "transformer_params_key": "Meta-Llama-3.1-70B-Tune"
   },
+  "meta-llama/Meta-Llama-3.2-1B": {
+    "aliases": ["llama3.2-1b-base"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Llama-3.2-1B"
+  },
+  "meta-llama/Meta-Llama-3.2-1B-Instruct": {
+    "aliases": ["llama3.2-1b", "llama3.2-1b-chat", "llama3.2-1b-instruct"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Llama-3.2-1B-Instruct",
+    "transformer_params_key": "Meta-Llama-3.2-1B"
+  },
+  "meta-llama/Llama-Guard-3-1B": {
+    "aliases": ["llama3-1b-guard", "llama3.2-1b-guard"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Llama-Guard-3-1B"
+  },
+  "meta-llama/Meta-Llama-3.2-3B": {
+    "aliases": ["llama3.2-3b-base"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Llama-3.2-3B"
+  },
+  "meta-llama/Meta-Llama-3.2-3B-Instruct": {
+    "aliases": ["llama3.2-3b", "llama3.2-3b-chat", "llama3.2-3b-instruct"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Llama-3.2-3B-Instruct",
+    "transformer_params_key": "Meta-Llama-3.2-3B"
+  },
+  "meta-llama/Llama-3.2-11B-Vision": {
+    "aliases": ["llama3.2-11B-base", "Llama-3.2-11B-Vision-base"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Llama-3.2-11B-Vision"
+  },
+  "meta-llama/Llama-3.2-11B-Vision-Instruct": {
+    "aliases": ["llama3.2-11B", "Llama-3.2-11B-Vision", "Llama-3.2-mm"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    "transformer_params_key": "Llama-3.2-11B-Vision"
+  },
   "meta-llama/CodeLlama-7b-Python-hf": {
     "aliases": ["codellama", "codellama-7b"],
     "distribution_channel": "HuggingFaceSnapshot",
```
torchchat/model_params/Llama-3.2-11B-Vision.json
Lines changed: 29 additions & 0 deletions (new file; this is the params path referenced in docs/multimodal.md)

```json
{
    "model_type": "flamingo",
    "use_tiktoken": true,
    "encoder": {
        "patch_size": 14,
        "num_heads": 16,
        "clip_embed_dim": 1280,
        "clip_num_layers": 32,
        "clip_hidden_states": [3, 7, 15, 23, 30],
        "decoder_embed_dim": 4096,
        "num_layers_projection": 8,
        "tile_size": 560,
        "max_num_tiles": 4,
        "in_channels": 3
    },
    "decoder": {
        "vocab_size": 128256,
        "num_layers": 32,
        "fusion_interval": 4,
        "num_special_tokens": 8,
        "num_heads": 32,
        "num_kv_heads": 8,
        "embed_dim": 4096,
        "max_seq_len": 131072,
        "encoder_max_seq_len": 128080,
        "rope_base": 500000.0,
        "intermediate_dim": 14336
    }
}
```
torchchat/model_params/Llama-Guard-3-1B.json
Lines changed: 20 additions & 0 deletions (new file; name inferred — the pruned 12-layer, dim-2048 configuration matches Llama Guard 3 1B, and the commit adds the `llama3-1b-guard` alias)

```json
{
    "block_size": 131072,
    "dim": 2048,
    "hidden_dim": 6400,
    "n_layers": 12,
    "n_heads": 32,
    "n_kv_heads": 8,
    "vocab_size": 128256,
    "ffn_dim_multiplier": 1.5,
    "multiple_of": 256,
    "norm_eps": 1e-05,
    "rope_theta": 500000.0,
    "rope_scaling": {
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192
    },
    "use_tiktoken": true
}
```
