Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 52caa9c

Browse files
author
vmpuri
committed
Remove if statement preventing tps stats from being printed when running generate with compile
1 parent 7fe2c86 commit 52caa9c

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

torchchat/generate.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@
2424

2525
from PIL import Image
2626

27+
# torchtune model definition dependencies
28+
from torchtune.data import Message, padded_collate_tiled_images_and_mask
29+
30+
from torchtune.generation import sample as tune_sample
31+
from torchtune.models.llama3 import llama3_tokenizer
32+
33+
from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
34+
from torchtune.training import set_default_dtype
35+
2736
from torchchat.cli.builder import (
2837
_initialize_model,
2938
_initialize_tokenizer,
@@ -34,15 +43,6 @@
3443
from torchchat.utils.build_utils import device_sync, set_precision
3544
from torchchat.utils.device_info import get_device_info
3645

37-
# torchtune model definition dependencies
38-
from torchtune.data import Message, padded_collate_tiled_images_and_mask
39-
40-
from torchtune.generation import sample as tune_sample
41-
from torchtune.models.llama3 import llama3_tokenizer
42-
43-
from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
44-
from torchtune.training import set_default_dtype
45-
4646

4747
class _ChatFormatter(ABC):
4848
def __init__(self, tokenizer):
@@ -1164,13 +1164,9 @@ def callback(x, *, done_generating=False):
11641164
print(
11651165
f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
11661166
)
1167-
aggregate_metrics["tokens_per_sec_jit_compile"] = tokens_sec
1168-
# Don't continue here.... because we need to report and reset
1169-
# continue
1170-
else:
1171-
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
1172-
aggregate_metrics["first_token_per_sec"].append(first_token_sec)
1173-
aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)
1167+
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
1168+
aggregate_metrics["first_token_per_sec"].append(first_token_sec)
1169+
aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)
11741170

11751171
logging.info(
11761172
f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\

0 commit comments

Comments
 (0)