Commit 6c944db

Runner changes for TorchTune Llama3.2 vision text decoder (pytorch#6610)

1 parent 27f31cd commit 6c944db
4 files changed, +71 -24 lines changed

examples/models/llama/runner/eager.py
Lines changed: 6 additions & 7 deletions

@@ -9,11 +9,12 @@
 from typing import Optional
 
 import torch
+
 from executorch.examples.models.llama.export_llama_lib import (
     _prepare_for_llama_export,
     build_args_parser as _build_args_parser,
+    TORCHTUNE_DEFINED_MODELS,
 )
-from executorch.examples.models.llama.llama_transformer import ModelArgs
 from executorch.examples.models.llama.runner.generation import LlamaRunner
 from executorch.extension.llm.export.builder import LLMEdgeManager
 
@@ -26,15 +27,13 @@ class EagerLlamaRunner(LlamaRunner):
     def __init__(self, args):
         with open(args.params, "r") as f:
             params = json.loads(f.read())
-        model_args: ModelArgs = ModelArgs(
+        super().__init__(
+            tokenizer_path=args.tokenizer_path,
             max_seq_len=args.max_seq_length,
             max_batch_size=1,
             use_kv_cache=args.use_kv_cache,
-            **params,
-        )
-        super().__init__(
-            tokenizer_path=args.tokenizer_path,
-            model_args=model_args,
+            vocab_size=params["vocab_size"],
+            has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
             device="cuda" if torch.cuda.is_available() else "cpu",
         )
         manager: LLMEdgeManager = _prepare_for_llama_export(args)
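
For context: this change replaces the old ModelArgs-based wiring with plain scalar arguments on LlamaRunner. A minimal sketch of what a subclass looks like under the new signature (the EchoRunner class and its toy forward are hypothetical, for illustration only):

from typing import Optional

import torch

from executorch.examples.models.llama.runner.generation import LlamaRunner


class EchoRunner(LlamaRunner):
    """Hypothetical subclass showing the new constructor signature."""

    def __init__(self, tokenizer_path: str, vocab_size: int):
        super().__init__(
            tokenizer_path=tokenizer_path,
            max_seq_len=128,
            max_batch_size=1,
            use_kv_cache=False,
            vocab_size=vocab_size,
            # True only for TorchTune-defined models, which return [batch, seq, vocab].
            has_full_logits=False,
        )

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Toy stand-in for running an eager or exported model.
        return torch.zeros(tokens.shape[0], self.tokenizer.n_words)

Note that the constructor asserts vocab_size == self.tokenizer.n_words, so the value pulled from params["vocab_size"] must agree with the tokenizer file.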

examples/models/llama/runner/generation.py
Lines changed: 43 additions & 10 deletions

@@ -9,7 +9,6 @@
 
 import torch
 
-from executorch.examples.models.llama.llama_transformer import ModelArgs
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 
 
@@ -47,11 +46,35 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 
 
 class LlamaRunner(ABC):
-    def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
-        self.params = model_args
+    def __init__(
+        self,
+        tokenizer_path: str,
+        max_seq_len: int,
+        max_batch_size: int,
+        use_kv_cache: bool,
+        vocab_size: int,
+        has_full_logits: bool = False,
+        device: str = "cpu",
+    ):
+        """
+        Constructor.
+
+        Args:
+            tokenizer_path: path to tokenizer.model file.
+            max_seq_len: max length of the output sequence, after which the output will be clipped.
+            max_batch_size: max batch size.
+            use_kv_cache: whether to use a KV cache.
+            vocab_size: number of items in the vocab.
+            has_full_logits: whether the model returns the full logits or only returns the last logit.
+            device: device to run the runner on.
+        """
+        self.max_seq_len = max_seq_len
+        self.max_batch_size = max_batch_size
+        self.use_kv_cache = use_kv_cache
         self.tokenizer = get_tokenizer(tokenizer_path)
-        assert model_args.vocab_size == self.tokenizer.n_words
+        self.has_full_logits = has_full_logits
         self.device = device
+        assert vocab_size == self.tokenizer.n_words
 
     @abstractmethod
     def forward(
@@ -75,17 +98,20 @@ def generate( # noqa: C901
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
                 torch.tensor([pos_base], dtype=torch.long, device=self.device)
-                if self.params.use_kv_cache
+                if self.use_kv_cache
                 else None
             ),
         )
 
-        current_token = next_token(logits, temperature, top_p)
+        if self.has_full_logits:
+            current_token = next_token(logits[:, -1, :], temperature, top_p)
+        else:
+            current_token = next_token(logits, temperature, top_p)
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]
 
         while len(tokens) < max_seq_len:
-            if self.params.use_kv_cache:
+            if self.use_kv_cache:
                 logits = self.forward(
                     tokens=torch.tensor(
                         [[current_token]], dtype=torch.long, device=self.device
@@ -100,13 +126,20 @@ def generate( # noqa: C901
                 logits = self.forward(
                     tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
                 )
-            current_token = next_token(logits, temperature, top_p)
+
+            # If the logits aren't already clipped to only contain the last logit, clip them.
+            if self.has_full_logits:
+                current_token = next_token(logits[:, -1, :], temperature, top_p)
+            else:
+                current_token = next_token(logits, temperature, top_p)
             tokens.append(current_token)
+
             if current_token == self.tokenizer.eos_id or (
                 hasattr(self.tokenizer, "stop_tokens")
                 and current_token in self.tokenizer.stop_tokens
             ):
                 break
+
             print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         print("\n")
@@ -136,7 +169,7 @@ def text_completion(
         """
         return self.generate(
             prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
-            max_seq_len=self.params.max_seq_len,
+            max_seq_len=self.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=echo,
@@ -171,7 +204,7 @@ def chat_completion(
             prompt_tokens=self.tokenizer.encode(
                 self._format_prompt(prompt), bos=True, eos=False
             ),
-            max_seq_len=self.params.max_seq_len,
+            max_seq_len=self.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=True,
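
The has_full_logits branches above exist because TorchTune-defined models return logits for every input position, shape [batch, seq_len, vocab], while the ExecuTorch-defined Llama models return only the final position, shape [batch, vocab]. A small self-contained sketch of the clipping that generate() performs before sampling (shapes are illustrative):

import torch

vocab_size = 32

# Full logits, as a TorchTune-defined decoder returns them: one row per position.
full_logits = torch.randn(1, 5, vocab_size)  # [batch, seq_len, vocab]

# Last-position logits, as the ExecuTorch-defined models return them.
last_logits = torch.randn(1, vocab_size)  # [batch, vocab]

# generate() clips full logits to the final position so both cases feed
# the same next_token() sampling path.
clipped = full_logits[:, -1, :]
assert clipped.shape == last_logits.shape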

examples/models/llama/runner/native.py
Lines changed: 18 additions & 7 deletions

@@ -10,18 +10,22 @@
 
 import torch
 
-from examples.models.llama.llama_transformer import ModelArgs
+from executorch.examples.models.llama.export_llama_lib import (
+    EXECUTORCH_DEFINED_MODELS,
+    TORCHTUNE_DEFINED_MODELS,
+)
+
 from executorch.extension.pybindings.portable_lib import _load_for_executorch
 
 # Load custom ops and quantized ops.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
 
+from executorch.examples.models.llama.runner.generation import LlamaRunner
+
 # Note: import this after portable_lib
 from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
 from executorch.kernels import quantized  # noqa
 
-from .generation import LlamaRunner
-
 
 
 class NativeLlamaRunner(LlamaRunner):
@@ -31,13 +35,14 @@ class NativeLlamaRunner(LlamaRunner):
     def __init__(self, args):
         with open(args.params, "r") as f:
             params = json.loads(f.read())
-        model_args: ModelArgs = ModelArgs(
+        super().__init__(
+            tokenizer_path=args.tokenizer,
             max_seq_len=args.max_len,
             max_batch_size=1,
             use_kv_cache=args.kv_cache,
-            **params,
+            vocab_size=params["vocab_size"],
+            has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
         )
-        super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
         self.model = _load_for_executorch(args.pte)
 
     def forward(
@@ -53,8 +58,15 @@ def forward(
 
 
 def build_args_parser() -> argparse.ArgumentParser:
+    # TODO: merge these with build_args_parser from export_llama_lib.
     parser = argparse.ArgumentParser()
 
+    parser.add_argument(
+        "--model",
+        default="llama3",
+        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
+    )
+
     parser.add_argument(
         "-f",
         "--pte",
@@ -89,7 +101,6 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-kv",
        "--kv_cache",
-        default=True,
         action="store_true",
     )
 
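
A behavioral fix hides in the last hunk: combining default=True with action="store_true" makes a flag unconditionally true, because store_true can only ever set the value to True. Removing the default turns --kv_cache into a real opt-in. A standalone sketch of the difference, with hypothetical flag names:

import argparse

parser = argparse.ArgumentParser()
# Old pattern: args.old_kv_cache is True whether or not the flag is passed.
parser.add_argument("--old_kv_cache", default=True, action="store_true")
# New pattern: False by default, True only when the flag is passed.
parser.add_argument("--new_kv_cache", action="store_true")

args = parser.parse_args([])
print(args.old_kv_cache, args.new_kv_cache)  # True False

args = parser.parse_args(["--new_kv_cache"])
print(args.new_kv_cache)  # True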

extension/llm/export/builder.py
Lines changed: 4 additions & 0 deletions

@@ -194,6 +194,10 @@ def export(self) -> "LLMEdgeManager":
                 strict=True,
             )
         else:
+            logging.info("Exporting with:")
+            logging.info(f"inputs: {self.example_inputs}")
+            logging.info(f"kwargs: {self.example_kwarg_inputs}")
+            logging.info(f"dynamic shapes: {dynamic_shape}")
             exported_module = export_for_training(
                 self.model,
                 self.example_inputs,
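
The builder.py hunk only adds logging around this export path, recording what is handed to export_for_training. A minimal sketch of that call pattern outside the manager, assuming a recent PyTorch that ships torch.export.export_for_training (the TinyDecoder module and all shapes are illustrative):

import logging

import torch
from torch.export import Dim, export_for_training

logging.basicConfig(level=logging.INFO)


class TinyDecoder(torch.nn.Module):
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens * 2


example_inputs = (torch.zeros(1, 8, dtype=torch.long),)
dynamic_shape = {"tokens": {1: Dim("seq_len", min=2, max=128)}}

# Mirrors the added logging: record what export will see before it runs.
logging.info("Exporting with:")
logging.info(f"inputs: {example_inputs}")
logging.info(f"dynamic shapes: {dynamic_shape}")

exported_module = export_for_training(
    TinyDecoder(), example_inputs, dynamic_shapes=dynamic_shape
)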
