diff --git a/torchchat/generate.py b/torchchat/generate.py index 66f26ff9f..9b4c6430a 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1149,9 +1149,11 @@ def callback(x, *, done_generating=False): print( f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds" ) - aggregate_metrics["tokens_per_sec"].append(tokens_sec) - aggregate_metrics["first_token_per_sec"].append(first_token_sec) - aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) + else: + # Do not append to aggregate_metrics for the jit_compile run; its timings would skew the averages. + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + aggregate_metrics["first_token_per_sec"].append(first_token_sec) + aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) logging.info( f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\ @@ -1205,7 +1207,8 @@ def callback(x, *, done_generating=False): or torch.isnan(torch.tensor(avg_next_tokens_sec)) ): print( - f"\n Average tokens/sec (total): {avg_tokens_sec:.2f} \ + f"\nWarning: Excluding compile in calculations \ + \n Average tokens/sec (total): {avg_tokens_sec:.2f} \ \nAverage tokens/sec (first token): {avg_first_token_sec:.2f} \ \nAverage tokens/sec (next tokens): {avg_next_tokens_sec:.2f} \n\ "