Commit bfc62dc

Author: anirudh (committed)
Added Llama3VisionTransform in TokenizerArgs and other changes
1 parent: 78bdacf
File tree (3 files changed, +58 −33 lines):
  torchchat/cli/builder.py
  torchchat/model_params/Llama-3.2-11B-Vision.json
  torchchat/usages/eval.py

torchchat/cli/builder.py

Lines changed: 25 additions & 4 deletions
@@ -252,13 +252,29 @@ class TokenizerArgs:
     is_sentencepiece: bool = False
     is_tiktoken: bool = False
     is_hf_tokenizer: bool = False
+    is_llama_3_2_mm: bool = False
     t: Optional[Any] = None
 
     def __post_init__(self):
+        # special handling for llama-3.2-mm
+        if "llama-3.2-11b-vision" in str(self.tokenizer_path).lower():
+            try:
+                from torchtune.models.llama3_2_vision import llama3_2_vision_transform
+
+                self.t = llama3_2_vision_transform(path=str(self.tokenizer_path))
+                self.is_llama_3_2_mm = True
+                self.is_tiktoken = False
+                self.is_sentencepiece = False
+                self.is_hf_tokenizer = False
+                return
+            except:
+                pass
+
         try:
             from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
 
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = True
             self.is_sentencepiece = False
             self.is_hf_tokenizer = False
@@ -270,6 +286,7 @@ def __post_init__(self):
             from sentencepiece import SentencePieceProcessor
 
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = True
             self.is_hf_tokenizer = False
@@ -281,13 +298,15 @@ def __post_init__(self):
             from tokenizer.hf_tokenizer import HFTokenizer
 
             self.t = HFTokenizer(str(self.tokenizer_path))
+            self.is_llama_3_2_mm = False
             self.is_tiktoken = False
             self.is_sentencepiece = False
             self.is_hf_tokenizer = True
             return
         except:
             pass
 
+        self.is_llama_3_2_mm = False
         self.is_tiktoken = False
         self.is_sentencepiece = False
         self.is_hf_tokenizer = False
@@ -302,20 +321,22 @@ def validate_model(
         if model is None:
             return
 
-        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
+        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece, self.is_llama_3_2_mm]) != 1:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
         is_hf_tokenizer = self.is_hf_tokenizer
+        is_llama_3_2_mm = self.is_llama_3_2_mm
+
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
-        use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
-
+        use_other_tokenizer = not (use_tiktoken or use_hf_tokenizer)
         if (
             (is_tiktoken and not use_tiktoken) or
             (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_sentencepiece and not use_other_tokenizer) or
+            (is_llama_3_2_mm and not use_other_tokenizer)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(

torchchat/model_params/Llama-3.2-11B-Vision.json

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 {
     "model_type": "flamingo",
-    "use_tiktoken": true,
+    "use_tiktoken": false,
     "encoder": {
         "patch_size": 14,
         "num_heads": 16,

torchchat/usages/eval.py

Lines changed: 32 additions & 28 deletions
@@ -378,16 +378,13 @@ def _model_multimodal_generate(
 
         # 2. Setup KV cache and masks for bsz 1
         with self.device:
-            if self.model.caches_are_enabled():
-                self.model.reset_caches()
-            else:
-                self.model.setup_caches(
-                    batch_size=1,
-                    dtype=self._dtype,
-                    encoder_max_seq_len=self.model_transform.image_seq_len
-                    * self._max_images_per_sample,
-                    decoder_max_seq_len=self.max_length,
-                )
+            self.model.setup_caches(
+                batch_size=1,
+                dtype=self._dtype,
+                encoder_max_seq_len=self.model_transform.image_seq_len
+                * self._max_images_per_sample,
+                decoder_max_seq_len=self.max_length,
+            )
             causal_mask = torch.tril(
                 torch.ones(
                     size=(self.max_length, self.max_length),
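
The cache setup is simplified here: setup_caches now runs unconditionally inside the device context instead of resetting caches when they already exist. The encoder budget is still sized as image_seq_len times the per-sample image cap; a quick arithmetic check of that formula, with purely hypothetical numbers since the real values come from the vision transform and the eval wrapper:

    # Worked example of the sizing formula in the setup_caches call above.
    image_seq_len = 1601         # hypothetical vision-transform sequence length
    max_images_per_sample = 7    # hypothetical cap used by the eval wrapper
    encoder_max_seq_len = image_seq_len * max_images_per_sample  # 11207
    decoder_max_seq_len = 4096   # self.max_length, e.g. the new default max_seq_length
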
@@ -506,6 +503,8 @@ def multi_model_eval(
     """
     if tasks is None:
         tasks = ["wikitext"]
+    max_seq_length = 4096 if max_seq_length is None else max_seq_length
+    device = utils.get_device(device) if isinstance(device, str) else device
 
     model_eval_wrapper = _VLMEvalWrapper(
         model,
@@ -578,25 +577,30 @@ def main(args) -> None:
     )
     torch._inductor.config.coordinate_descent_tuning = False if device == "cpu" else True
 
-    evaluator = None
-    if modality == "text":
-        evaluator = eval
-    elif modality == "text-image":
-        evaluator = multi_model_eval
-    else:
-        raise ValueError(f"Unsupported modality: {modality}")
-
     with measure_time("Time to run eval: {time:.02f}s."):
-        result = evaluator(
-            model.to(device),
-            model_forward,
-            tokenizer,
-            tasks,
-            limit,
-            max_seq_length,
-            device=builder_args.device,
-            is_pte_model=builder_args.pte_path is not None,
-        )
+        if modality == "text":
+            result = eval(
+                model.to(device),
+                model_forward,
+                tokenizer,
+                tasks,
+                limit,
+                max_seq_length,
+                device=builder_args.device,
+                is_pte_model=builder_args.pte_path is not None,
+            )
+        elif modality == "text-image":
+            result = multi_model_eval(
+                model.to(device),
+                model_forward,
+                tokenizer,
+                tasks,
+                limit,
+                max_seq_length,
+                device=builder_args.device,
+            )
+        else:
+            raise ValueError(f"Unsupported modality: {modality}")
 
     times = torch.tensor(result["times"])
     print(
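
Two smaller behavioral notes fall out of this hunk: multi_model_eval now supplies its own defaults (max_seq_length falls back to 4096 and a device string is normalized via utils.get_device), and the text-image branch in main() does not pass is_pte_model, which only the text eval() call receives. A tiny standalone sketch of the default handling (illustrative only; it mirrors the two added lines rather than the real function):

    from typing import Optional

    # Mirrors the max_seq_length default added to multi_model_eval above; the real
    # function also normalizes a device string via utils.get_device(device).
    def resolve_eval_max_seq_length(max_seq_length: Optional[int]) -> int:
        return 4096 if max_seq_length is None else max_seq_length

    print(resolve_eval_max_seq_length(None))   # 4096
    print(resolve_eval_max_seq_length(2048))   # 2048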

Comments (0)