Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 116c5c2

Browse files
committed
Only set up model caches during the first sample
1 parent 2fcc37c commit 116c5c2

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

torchchat/generate.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,7 @@ def generate(
591591
Dict[str, Any]
592592
] = None, # List of Image prompt tensors for multimodal models
593593
start_pos: int = 0,
594+
skip_cache_setup: bool = False,
594595
draft_model: Model,
595596
speculate_k: Optional[int] = 8,
596597
sequential_prefill=True,
@@ -613,7 +614,7 @@ def generate(
613614
prompt_length = prompt.size(0)
614615
max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - prompt_length)
615616
# set up caches only if first inference
616-
if start_pos == 0:
617+
if start_pos == 0 and not skip_cache_setup:
617618
model = model.to(device=device)
618619
with torch.device(device):
619620
if (
@@ -1020,6 +1021,7 @@ def chat(
10201021
)
10211022
for i in range(num_samples):
10221023
device_sync(device=self.builder_args.device)
1024+
is_first_sample: bool = i == 0
10231025
if generator_args.chat_mode:
10241026
prompt = input("User: ")
10251027
if prompt == "/bye":
@@ -1045,7 +1047,7 @@ def chat(
10451047
]
10461048
)
10471049
self.system_prompt = None
1048-
elif i == 0:
1050+
elif is_first_sample:
10491051
encoded = self.chat_formatter.encode_dialog_prompt(
10501052
[{"role": "user", "content": prompt}]
10511053
)
@@ -1116,6 +1118,7 @@ def callback(x, *, done_generating=False):
11161118
top_k=generator_args.top_k,
11171119
sequential_prefill=generator_args.sequential_prefill,
11181120
start_pos=start_pos,
1121+
skip_cache_setup=not is_first_sample,
11191122
max_seq_length=max_seq_length,
11201123
)
11211124
for token_tensor, metrics in generator_func:
@@ -1125,7 +1128,7 @@ def callback(x, *, done_generating=False):
11251128
if metrics is not None:
11261129
aggregate_metrics.update(metrics)
11271130
yield token_tensor, metrics
1128-
jit_compile = (i == 0) and (
1131+
jit_compile = is_first_sample and (
11291132
generator_args.compile or generator_args.compile_prefill
11301133
)
11311134
compilation_time = time.perf_counter() - t0

0 commit comments

Comments (0)