This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 8900f8a

Author: anirudh
use kv caching and other minor fixes
1 parent bfc62dc commit 8900f8a

File tree

3 files changed (+49, -38 lines)

install/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -34,4 +34,4 @@ streamlit
 flask
 
 # eval
-lm_eval==0.4.5
+lm_eval==0.4.7

torchchat/model.py

Lines changed: 6 additions & 0 deletions
@@ -608,6 +608,12 @@ def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_l
             decoder_max_seq_len=decoder_max_seq_len,
         )
 
+    def caches_are_setup(self) -> bool:
+        return self.model.caches_are_setup()
+
+    def caches_are_enabled(self) -> bool:
+        return self.model.caches_are_enabled()
+
     def reset_caches(self):
         self.model.reset_caches()
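These two additions simply forward cache-state queries from torchchat's model wrapper to the underlying torchtune model, so helpers such as torchtune's local_kv_cache (used in eval.py below) can check whether KV caches already exist before allocating temporary ones. A minimal sketch of the delegation pattern, using hypothetical InnerDecoder/ModelWrapper classes rather than the actual torchchat types:

# Sketch only: hypothetical stand-ins for the wrapped torchtune decoder and
# torchchat's Model wrapper; the real classes carry much more state.
class InnerDecoder:
    def __init__(self) -> None:
        self._kv_caches = None

    def setup_caches(self, batch_size: int, decoder_max_seq_len: int) -> None:
        # Real torchtune modules allocate per-layer KV cache tensors here.
        self._kv_caches = {"batch_size": batch_size, "max_seq_len": decoder_max_seq_len}

    def caches_are_setup(self) -> bool:
        return self._kv_caches is not None

    def caches_are_enabled(self) -> bool:
        return self._kv_caches is not None


class ModelWrapper:
    def __init__(self, model: InnerDecoder) -> None:
        self.model = model

    # Forward cache-state queries, mirroring the methods added above.
    def caches_are_setup(self) -> bool:
        return self.model.caches_are_setup()

    def caches_are_enabled(self) -> bool:
        return self.model.caches_are_enabled()


wrapper = ModelWrapper(InnerDecoder())
assert not wrapper.caches_are_setup()
wrapper.model.setup_caches(batch_size=1, decoder_max_seq_len=2048)
assert wrapper.caches_are_enabled()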

torchchat/usages/eval.py

Lines changed: 42 additions & 37 deletions
@@ -36,6 +36,7 @@
 from lm_eval.models.hf_vlms import HFMultimodalLM
 from lm_eval.evaluator import evaluate
 
+from torchtune.modules.common_utils import local_kv_cache
 from torchtune.modules.model_fusion import DeepFusionModel
 from torchtune.modules.transforms import Transform
 from torchtune.data import (
@@ -183,12 +184,7 @@ def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")
 
 
-# Dummy class which _VLMEvalWrapper can inherit from when the imports don't work
-# class HFMultimodalLM():
-#     def __init__(self):
-#         return
-
-class _VLMEvalWrapper(HFMultimodalLM):
+class VLMEvalWrapper(HFMultimodalLM):
     """An EvalWrapper for EleutherAI's eval harness based on gpt-fast's
     EvalWrapper: https://github.com/pytorch-labs/gpt-fast/blob/main/eval.py.
@@ -234,6 +230,7 @@ def __init__(
         self._enable_kv_cache = True
         self._image_tag = image_tag
         self._max_images_per_sample = max_images_per_sample
+        self.times = []
 
     @property
     def model(self):
@@ -338,6 +335,7 @@ def tok_batch_multimodal_encode(
             all_encoded_messages,
             pad_direction="left",
             pad_max_images=self._max_images_per_sample,
+            pad_max_tiles=self._transform.max_num_tiles,
         )
         utils.batch_to_device(tok_batch, self.device)
@@ -376,15 +374,11 @@ def _model_multimodal_generate(
                 "multimodal generation."
             )
 
-        # 2. Setup KV cache and masks for bsz 1
+        encoder_max_seq_len = (
+            self.model_transform.image_seq_len * self._max_images_per_sample
+        )
+        # Setup masks for bsz 1
         with self.device:
-            self.model.setup_caches(
-                batch_size=1,
-                dtype=self._dtype,
-                encoder_max_seq_len=self.model_transform.image_seq_len
-                * self._max_images_per_sample,
-                decoder_max_seq_len=self.max_length,
-            )
             causal_mask = torch.tril(
                 torch.ones(
                     size=(self.max_length, self.max_length),
@@ -396,28 +390,39 @@ def _model_multimodal_generate(
         batch["input_pos"] = input_pos[None, :seq_len]
         batch["mask"] = causal_mask[None, :seq_len]
 
-        # 3. Prefill step
-        generated_tokens = []
-        logits = self.model(prompt, **batch)[:, -1]
-        token = sample(logits, temperature=0.0, top_k=None)
-        generated_tokens.append(token.item())
-
-        cache_mask = batch["encoder_mask"][:, -1:]
-
-        # 4. Continue generating
-        for _ in range(max_length):
-            if token.item() in self.model_transform.stop_tokens:
-                break
-            logits = self.model(
-                token,
-                mask=causal_mask[None, seq_len, None, :],
-                encoder_input=None,
-                encoder_mask=cache_mask,
-                input_pos=input_pos[None, seq_len],
-            )[:, -1]
-            token = sample(logits, temperature=0.0, top_k=None)
-            generated_tokens.append(token.item())
-            seq_len += 1
+        with measure_time(message=None) as measure:
+            # 2. Setup KV cache
+            with local_kv_cache(
+                self.model,
+                batch_size=self.batch_size,
+                device=self.device,
+                dtype=self._dtype,
+                encoder_max_seq_len=encoder_max_seq_len,
+                decoder_max_seq_len=self.max_length,
+            ):
+                # 3. Prefill step
+                generated_tokens = []
+                logits = self.model(prompt, **batch)[:, -1]
+                token = sample(logits, temperature=0.0, top_k=None)
+                generated_tokens.append(token.item())
+
+                cache_mask = batch["encoder_mask"][:, -1:]
+
+                # 4. Continue generating
+                for _ in range(max_length):
+                    if token.item() in self.model_transform.stop_tokens:
+                        break
+                    logits = self.model(
+                        token,
+                        mask=causal_mask[None, seq_len, None, :],
+                        encoder_input=None,
+                        encoder_mask=cache_mask,
+                        input_pos=input_pos[None, seq_len],
+                    )[:, -1]
+                    token = sample(logits, temperature=0.0, top_k=None)
+                    generated_tokens.append(token.item())
+                    seq_len += 1
+        self.times.append(measure.get_time())
 
         # 5. Return generated tokens
         return torch.tensor(generated_tokens, dtype=torch.int32).unsqueeze(0)
@@ -506,7 +511,7 @@ def multi_model_eval(
     max_seq_length = 4096 if max_seq_length is None else max_seq_length
     device = utils.get_device(device) if isinstance(device, str) else device
 
-    model_eval_wrapper = _VLMEvalWrapper(
+    model_eval_wrapper = VLMEvalWrapper(
         model,
         transform=tokenizer,  # tranform is the tokenizer for multimodal models
         max_seq_length=max_seq_length,
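The substantive change in _model_multimodal_generate is that KV caches are no longer set up once via self.model.setup_caches(...); torchtune's local_kv_cache context manager now allocates them only for the current generation and tears them down on exit, and the whole prefill/decode loop is timed with measure_time, with the elapsed time appended to self.times. Below is a hedged sketch of that same pattern in isolation, assuming a torchtune-style model object and substituting time.perf_counter for torchchat's measure_time helper; run_decode is a hypothetical stand-in for the prefill/decode loop in the diff above.

# Sketch only: the temporary-KV-cache + timing pattern from the diff above.
# Assumes `model` is a torchtune module (e.g. a DeepFusionModel) that supports
# setup_caches(); `run_decode` is a stand-in returning a list of token ids.
import time

import torch
from torchtune.modules.common_utils import local_kv_cache


def generate_with_local_cache(model, run_decode, *, device, dtype,
                              encoder_max_seq_len, decoder_max_seq_len):
    times = []
    start = time.perf_counter()  # stand-in for torchchat's measure_time
    with local_kv_cache(
        model,
        batch_size=1,
        device=device,
        dtype=dtype,
        encoder_max_seq_len=encoder_max_seq_len,
        decoder_max_seq_len=decoder_max_seq_len,
    ):
        # Caches exist only inside this block and are deleted on exit, so each
        # evaluation sample starts from a clean cache and no memory is retained.
        tokens = run_decode(model)
    times.append(time.perf_counter() - start)
    return torch.tensor(tokens, dtype=torch.int32).unsqueeze(0), times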
