This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 815966c
Author: anirudh
Message: instantiate transform in eval()
Parent: 14502bf

File tree: 2 files changed (+37, -26 lines)


torchchat/model_params/Llama-3.2-11B-Vision.json

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 {
   "model_type": "flamingo",
-  "use_tiktoken": false,
+  "use_tiktoken": true,
   "encoder": {
     "patch_size": 14,
     "num_heads": 16,

torchchat/usages/eval.py

Lines changed: 36 additions & 25 deletions
@@ -197,24 +197,11 @@ class VLMEvalWrapper(HFMultimodalLM):
     the max number of images in MMMU.
     """
 
-    # Having the imports here allow running other evals without installing torchtune
-    from torchtune import utils
-    from torchtune.data import (
-        format_content_with_images,
-        left_pad_sequence,
-        Message,
-        padded_collate_tiled_images_and_mask,
-    )
-    from torchtune.generation import generate, sample
-
-    from torchtune.modules.common_utils import local_kv_cache
-    from torchtune.modules.model_fusion import DeepFusionModel
-    from torchtune.modules.transforms import Transform
 
     def __init__(
         self,
-        model: DeepFusionModel,
-        transform: Transform,
+        model: Model,
+        transform,
         *,
         device: torch.device,
         max_seq_length: int = 4096,
@@ -226,6 +213,25 @@ def __init__(
         image_tag: str = "<image>",
         max_images_per_sample: int = 7,
     ):
+        # Having the imports here allow running other evals without installing torchtune
+        from torchtune.utils import batch_to_device
+        from torchtune.data import (
+            format_content_with_images,
+            left_pad_sequence,
+            Message,
+            padded_collate_tiled_images_and_mask,
+        )
+        from torchtune.generation import generate, sample
+        from torchtune.modules.common_utils import local_kv_cache
+        self.batch_to_device = batch_to_device
+        self.format_content_with_images = format_content_with_images
+        self.left_pad_sequence = left_pad_sequence
+        self.Message = Message
+        self.padded_collate_tiled_images_and_mask = padded_collate_tiled_images_and_mask
+        self.generate = generate
+        self.sample = sample
+        self.local_kv_cache = local_kv_cache
+
         self._model = model
         self._transform = transform
         self._device = device
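The effect of these two hunks: the torchtune imports move from class scope into __init__ and are bound to instance attributes, so importing eval.py no longer requires torchtune unless a VLMEvalWrapper is actually constructed, which lets text-only evals run without it. A minimal sketch of the same lazy-import pattern; the class below is illustrative, not torchchat's:

class LazyTorchtuneUser:
    def __init__(self):
        # Import at construction time, not module-import time, so the
        # optional dependency is only needed when this class is used.
        from torchtune.generation import sample

        # Bind the symbol to the instance so other methods can reach it
        # without importing again.
        self.sample = sample

    def greedy_token(self, logits):
        # temperature=0.0 makes torchtune's sample() act as greedy argmax.
        return self.sample(logits, temperature=0.0, top_k=None)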
@@ -326,24 +332,24 @@ def tok_batch_multimodal_encode(
 
             # Construct the messages
             messages = []
-            content = format_content_with_images(
+            content = self.format_content_with_images(
                 text, image_tag=self._image_tag, images=proper_images
             )
-            messages.append(Message(role="user", content=content))
-            messages.append(Message(role="assistant", content=""))
+            messages.append(self.Message(role="user", content=content))
+            messages.append(self.Message(role="assistant", content=""))
 
             # Transform the messages
             tok_batch = self.model_transform({"messages": messages}, inference=True)
             all_encoded_messages.append(tok_batch)
 
         # Pad the encoded messages
-        tok_batch = padded_collate_tiled_images_and_mask(
+        tok_batch = self.padded_collate_tiled_images_and_mask(
             all_encoded_messages,
             pad_direction="left",
             pad_max_images=self._max_images_per_sample,
             pad_max_tiles=self._transform.max_num_tiles,
         )
-        utils.batch_to_device(tok_batch, self.device)
+        self.batch_to_device(tok_batch, self.device)
 
         # Convert the batch to the format expected by the HF
         tok_batch["input_ids"] = tok_batch.pop("tokens")
@@ -398,7 +404,7 @@ def _model_multimodal_generate(
 
         with measure_time(message=None) as measure:
             # 2. Setup KV cache
-            with local_kv_cache(
+            with self.local_kv_cache(
                 self.model,
                 batch_size=self.batch_size,
                 device=self.device,
@@ -409,7 +415,7 @@ def _model_multimodal_generate(
                 # 3. Prefill step
                 generated_tokens = []
                 logits = self.model(prompt, **batch)[:, -1]
-                token = sample(logits, temperature=0.0, top_k=None)
+                token = self.sample(logits, temperature=0.0, top_k=None)
                 generated_tokens.append(token.item())
 
                 cache_mask = batch["encoder_mask"][:, -1:]
@@ -425,7 +431,7 @@ def _model_multimodal_generate(
                         encoder_mask=cache_mask,
                         input_pos=input_pos[None, seq_len],
                     )[:, -1]
-                    token = sample(logits, temperature=0.0, top_k=None)
+                    token = self.sample(logits, temperature=0.0, top_k=None)
                     generated_tokens.append(token.item())
                     seq_len += 1
             self.times.append(measure.get_time())
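Across these three hunks the generation loop is unchanged except that sample and local_kv_cache are now reached through self. For context: one prefill forward pass yields logits for the whole prompt, the last position is sampled greedily (temperature=0.0 reduces to argmax), and each new token is fed back through the cached model. A compressed sketch of that decode pattern, where step_fn stands in for the cached model call (an assumption for illustration, not torchchat's API):

import torch

def greedy_decode(first_logits, step_fn, max_new_tokens):
    # Prefill: pick the first token from the prompt's last-position logits.
    token = torch.argmax(first_logits, dim=-1)
    tokens = [int(token)]
    # Decode: feed each token back; step_fn returns next-position logits.
    for _ in range(max_new_tokens - 1):
        logits = step_fn(token)
        token = torch.argmax(logits, dim=-1)
        tokens.append(int(token))
    return tokens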
@@ -460,6 +466,7 @@ def eval(
     Returns:
         eval_results (dict): A dictionary of evaluation results for the specified task(s).
     """
+
     if tasks is None:
         if modality == "text":
             tasks = ["wikitext"]
@@ -478,11 +485,14 @@ def eval(
         # use eot_token_id as prefix_token_id.
         model_eval_wrapper.custom_prefix_token_id = model_eval_wrapper.eot_token_id
     elif modality == "text-image":
+        from torchtune.utils import get_device
+        from torchtune.models.llama3_2_vision import llama3_2_vision_transform
+
         model_eval_wrapper = VLMEvalWrapper(
             model,
-            transform=tokenizer,
+            transform=llama3_2_vision_transform(path=str(tokenizer.tokenizer_path)),
             max_seq_length = 4096 if max_seq_length is None else max_seq_length,
-            device = utils.get_device(device) if isinstance(device, str) else device,
+            device = get_device(device) if isinstance(device, str) else device,
         )
 
     try:
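This is the change the commit message names: eval() now builds a full llama3_2_vision_transform from the tokenizer's on-disk path instead of passing the bare tokenizer through as the transform. The vision transform tokenizes messages and preprocesses images (e.g. tiling), which is what tok_batch_multimodal_encode above expects. Roughly, with a placeholder path:

from torchtune.models.llama3_2_vision import llama3_2_vision_transform

# Placeholder path; in the code above it comes from tokenizer.tokenizer_path,
# which main() attaches to the tokenizer (see the final hunk below).
transform = llama3_2_vision_transform(path="/path/to/tokenizer.model")

# The wrapper then invokes it as seen earlier in tok_batch_multimodal_encode:
# tok_batch = transform({"messages": messages}, inference=True)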
@@ -531,6 +541,7 @@ def main(args) -> None:
     set_precision(builder_args.precision)
 
     tokenizer = _initialize_tokenizer(tokenizer_args)
+    tokenizer.tokenizer_path = tokenizer_args.tokenizer_path
     builder_args.setup_caches = False
     model = _initialize_model(
         builder_args,