Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit b87d923

Browse files
committed
Remove tokens per sec in aggregate_metrics when jit_compile
1 parent 4697764 commit b87d923

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torchchat/generate.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,9 +1149,11 @@ def callback(x, *, done_generating=False):
11491149
print(
11501150
f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
11511151
)
1152-
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
1153-
aggregate_metrics["first_token_per_sec"].append(first_token_sec)
1154-
aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)
1152+
else:
1153+
# Only append to aggregate_metrics for non-jit_compile runs; including the jit_compile run would skew the averages.
1154+
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
1155+
aggregate_metrics["first_token_per_sec"].append(first_token_sec)
1156+
aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)
11551157

11561158
logging.info(
11571159
f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\

0 commit comments

Comments
 (0)