@@ -457,7 +457,12 @@ def time_per_output_token_ms(self) -> Optional[float]: # type: ignore[override]
457
457
This includes the time to generate the first token and all other tokens.
458
458
None if the output_tokens is None or 0.
459
459
"""
460
- if self .output_tokens is None or self .output_tokens == 0 :
460
+ if (
461
+ self .output_tokens is None
462
+ or self .output_tokens == 0
463
+ or self .first_token_time is None
464
+ or self .last_token_time is None
465
+ ):
461
466
return None
462
467
463
468
return super ().time_per_output_token_ms
@@ -614,41 +619,46 @@ def duration(self) -> float:
614
619
),
615
620
)
616
621
617
- def create_sampled (self , sample_size : int ) -> "GenerativeBenchmark" :
622
+ def set_sample_size (self , sample_size : Optional [ int ] ) -> "GenerativeBenchmark" :
618
623
"""
619
- Create a new benchmark instance with a random sample of the completed and
620
- errored requests based on the given sample sizes. If the sample sizes are
621
- larger than the total number of requests, the sample sizes are capped at
622
- the total number of requests.
624
+ Set the sample size for the benchmark. This will randomly sample the
625
+ requests for each status type to the given sample size or the maximum
626
+ number of requests for that status type, whichever is smaller.
627
+ This is applied to requests.successful, requests.errored, and
628
+ requests.incomplete.
629
+ If None, no sampling is applied and the state is kept.
623
630
624
631
:param sample_size: The number of requests to sample for each status type.
625
- :return: A new benchmark instance with the sampled requests.
626
- :raises ValueError: If the sample sizes are negative .
632
+ :return: The benchmark with the sampled requests.
633
+ :raises ValueError: If the sample size is invalid .
627
634
"""
628
- if sample_size < 0 :
629
- raise ValueError (f"Sample size must be non-negative, given { sample_size } " )
630
635
631
- sample_size = min (sample_size , len (self .requests .successful ))
632
- error_sample_size = min (sample_size , len (self .requests .errored ))
633
- incomplete_sample_size = min (sample_size , len (self .requests .incomplete ))
636
+ if sample_size is not None :
637
+ if sample_size < 0 or not isinstance (sample_size , int ):
638
+ raise ValueError (
639
+ f"Sample size must be non-negative integer, given { sample_size } "
640
+ )
634
641
635
- sampled_instance = self .model_copy ()
636
- sampled_instance .requests .successful = random .sample (
637
- self .requests .successful , sample_size
638
- )
639
- sampled_instance .requests .errored = random .sample (
640
- self .requests .errored , error_sample_size
641
- )
642
- sampled_instance .requests .incomplete = random .sample (
643
- self .requests .incomplete , incomplete_sample_size
644
- )
645
- sampled_instance .request_samples = StatusBreakdown (
646
- successful = len (sampled_instance .requests .successful ),
647
- incomplete = len (sampled_instance .requests .incomplete ),
648
- errored = len (sampled_instance .requests .errored ),
649
- )
642
+ sample_size = min (sample_size , len (self .requests .successful ))
643
+ error_sample_size = min (sample_size , len (self .requests .errored ))
644
+ incomplete_sample_size = min (sample_size , len (self .requests .incomplete ))
645
+
646
+ self .requests .successful = random .sample (
647
+ self .requests .successful , sample_size
648
+ )
649
+ self .requests .errored = random .sample (
650
+ self .requests .errored , error_sample_size
651
+ )
652
+ self .requests .incomplete = random .sample (
653
+ self .requests .incomplete , incomplete_sample_size
654
+ )
655
+ self .request_samples = StatusBreakdown (
656
+ successful = len (self .requests .successful ),
657
+ incomplete = len (self .requests .incomplete ),
658
+ errored = len (self .requests .errored ),
659
+ )
650
660
651
- return sampled_instance
661
+ return self
652
662
653
663
@staticmethod
654
664
def from_stats (
0 commit comments