diff --git a/torchchat/generate.py b/torchchat/generate.py index 4a67195fb..66f26ff9f 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1189,12 +1189,27 @@ def callback(x, *, done_generating=False): f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}" ) - print( - f"\n Average tokens/sec (total): {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f} \ - \nAverage tokens/sec (first token): {torch.mean(torch.tensor(aggregate_metrics['first_token_per_sec'])).item():.2f} \ - \nAverage tokens/sec (next tokens): {torch.mean(torch.tensor(aggregate_metrics['next_tokens_per_sec'])).item():.2f} \n\ + avg_tokens_sec = torch.mean( + torch.tensor(aggregate_metrics["tokens_per_sec"]) + ).item() + avg_first_token_sec = torch.mean( + torch.tensor(aggregate_metrics["first_token_per_sec"]) + ).item() + avg_next_tokens_sec = torch.mean( + torch.tensor(aggregate_metrics["next_tokens_per_sec"]) + ).item() + + if not ( + torch.isnan(torch.tensor(avg_tokens_sec)) + or torch.isnan(torch.tensor(avg_first_token_sec)) + or torch.isnan(torch.tensor(avg_next_tokens_sec)) + ): + print( + f"\n Average tokens/sec (total): {avg_tokens_sec:.2f} \ + \nAverage tokens/sec (first token): {avg_first_token_sec:.2f} \ + \nAverage tokens/sec (next tokens): {avg_next_tokens_sec:.2f} \n\ " - ) + ) if torch.cuda.is_available(): print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")