@@ -135,26 +135,43 @@ def __str__(self) -> str:
 class BenchmarkResult:
     "Class for holding results of benchmark runs"
     short_name: str
-    elapsed_time: torch.Tensor  # milliseconds
+    gpu_elapsed_time: torch.Tensor  # milliseconds
+    cpu_elapsed_time: torch.Tensor  # milliseconds
     mem_stats: List[MemoryStats]  # memory stats per rank
     rank: int = -1

     def __str__(self) -> str:
-        runtime = f"Runtime (P90): {self.runtime_percentile(90):.2f} ms"
+        gpu_runtime = (
+            f"GPU Runtime (P90): {self.runtime_percentile(90, device='gpu'):.2f} ms"
+        )
+        cpu_runtime = (
+            f"CPU Runtime (P90): {self.runtime_percentile(90, device='cpu'):.2f} ms"
+        )
         if len(self.mem_stats) == 0:
-            return f"{self.short_name: <{35}} | {runtime}"
+            return f"{self.short_name: <{35}} | {gpu_runtime} | {cpu_runtime}"
         mem_alloc = (
             f"Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2f} GB"
         )
         mem_reserved = f"Peak Memory reserved (P90): {self.max_mem_reserved_percentile(90)/1000:.2f} GB"
-        malloc_retries = f"Malloc retries (P50/P90/P100): {self.mem_retries(50) } / {self.mem_retries(90)} / {self.mem_retries(100)}"
-        return f"{self.short_name: <{35}} | {malloc_retries} | {runtime} | {mem_alloc} | {mem_reserved}"
+        malloc_retries = f"Malloc retries (P50/P90/P100): {self.mem_retries(50)} / {self.mem_retries(90)} / {self.mem_retries(100)}"
+        return f"{self.short_name: <{35}} | {malloc_retries} | {gpu_runtime} | {cpu_runtime} | {mem_alloc} | {mem_reserved}"

     def runtime_percentile(
-        self, percentile: int = 50, interpolation: str = "nearest"
+        self,
+        percentile: int = 50,
+        interpolation: str = "nearest",
+        device: str = "gpu",
     ) -> torch.Tensor:
+        """Return the runtime percentile for the requested timer.
+
+        Args:
+            percentile: Percentile to compute.
+            interpolation: See ``torch.quantile``.
+            device: 'gpu' for CUDA event timings, 'cpu' for active CPU timings.
+        """
+        timings = self.gpu_elapsed_time if device == "gpu" else self.cpu_elapsed_time
         return torch.quantile(
-            self.elapsed_time,
+            timings,
             percentile / 100.0,
             interpolation=interpolation,
         )
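
For context (not part of the patch): a minimal sketch of how the split timers might be queried once this change lands, assuming the BenchmarkResult dataclass from the hunk above is in scope. The tensor values are invented.

import torch

# Hypothetical result object; timings are made up for illustration.
result = BenchmarkResult(
    short_name="toy_model",
    gpu_elapsed_time=torch.tensor([12.1, 11.8, 12.4]),  # ms from CUDA events
    cpu_elapsed_time=torch.tensor([3.2, 3.1, 3.3]),  # ms of active CPU time
    mem_stats=[],
)

p90_gpu = result.runtime_percentile(90, device="gpu")  # tensor(12.4000)
p90_cpu = result.runtime_percentile(90, device="cpu")  # tensor(3.3000)
print(result)  # "toy_model ... | GPU Runtime (P90): 12.40 ms | CPU Runtime (P90): 3.30 ms"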
@@ -409,17 +426,26 @@ def write_report(
     num_requests: int,
 ) -> None:
     for benchmark_res in benchmark_results:
-        avg_dur_s = benchmark_res.elapsed_time.mean().item() * 1e-3  # time in seconds
-        std_dur_s = benchmark_res.elapsed_time.std().item() * 1e-3  # time in seconds
+        # GPU statistics
+        avg_dur_s_gpu = benchmark_res.gpu_elapsed_time.mean().item() * 1e-3  # sec
+        std_dur_s_gpu = benchmark_res.gpu_elapsed_time.std().item() * 1e-3  # sec
+
+        # CPU statistics
+        avg_dur_s_cpu = benchmark_res.cpu_elapsed_time.mean().item() * 1e-3  # sec
+        std_dur_s_cpu = benchmark_res.cpu_elapsed_time.std().item() * 1e-3  # sec

-        qps = int(num_requests / avg_dur_s)
+        qps_gpu = int(num_requests / avg_dur_s_gpu)

         mem_str = ""
         for memory_stats in benchmark_res.mem_stats:
             mem_str += f"{memory_stats}\n"

-        report_str += f"{benchmark_res.short_name:40} Avg QPS:{qps:10} Avg Duration: {int(1000 * avg_dur_s):5}"
-        report_str += f"ms Standard Dev Duration: {(1000 * std_dur_s):.2f} ms\n"
+        report_str += (
+            f"{benchmark_res.short_name:40} "
+            f"Avg QPS(GPU):{qps_gpu:10} "
+            f"GPU Avg: {int(1000 * avg_dur_s_gpu):5} ms ±{(1000 * std_dur_s_gpu):.2f} ms "
+            f"CPU Avg: {int(1000 * avg_dur_s_cpu):5} ms ±{(1000 * std_dur_s_cpu):.2f} ms\n"
+        )
         report_str += f"\tMemory Allocated Per Rank:\n\t{mem_str}\n"

     with open(report_file, "w") as f:
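
A worked example of the QPS arithmetic above (values invented). Only the GPU timer feeds the QPS figure, as in the patch.

import torch

num_requests = 8192
gpu_elapsed_time = torch.tensor([101.0, 99.0, 100.0])  # ms per iteration

avg_dur_s_gpu = gpu_elapsed_time.mean().item() * 1e-3  # 0.1 s
qps_gpu = int(num_requests / avg_dur_s_gpu)  # 81920 requests per second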
@@ -731,44 +757,63 @@ def _run_benchmark_core(
         if reset_accumulated_memory_stats:
             torch.cuda.reset_accumulated_memory_stats(rank)

-    # Optional allocator warm-up to create fragmentation similar to production
-    if pre_gpu_load and device_type == "cuda":
-        _tmp = torch.rand(16384, 16384, device="cuda")
-        for _ in range(pre_gpu_load):
-            _tmp = _tmp * torch.rand(16384, 16384, device="cuda")
+        # Optional allocator warm-up to create fragmentation similar to production
+        if pre_gpu_load:
+            _tmp = torch.rand(16384, 16384, device="cuda")
+            for _ in range(pre_gpu_load):
+                _tmp = _tmp * torch.rand(16384, 16384, device="cuda")

-    # Timing loop
+    # Timings
     start_events, end_events, times = [], [], []
+
     if device_type == "cuda":
         start_events = [
             torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)
         ]
         end_events = [
             torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)
         ]
+        # Capture per-iteration active CPU cycles (excludes time the thread is truly idle/asleep) using `process_time_ns`.
+        cpu_times_active_ns: List[int] = []
+
         for i in range(num_benchmarks):
+            # Ensure that outstanding GPU work from the previous iteration has
+            # finished so that we do not attribute its wait time to the next
+            # CPU measurement.
+            if i > 0:
+                torch.cuda.synchronize(rank if rank >= 0 else 0)
+
             start_events[i].record()
+            cpu_start_active_ns = time.process_time_ns()
+
             run_iter_fn()
+
+            cpu_end_active_ns = time.process_time_ns()
             end_events[i].record()
-    else:
-        times = timeit.repeat(run_iter_fn, number=1, repeat=num_benchmarks)
+            cpu_times_active_ns.append(cpu_end_active_ns - cpu_start_active_ns)

-    # Make sure all kernels are finished before reading timers / stats
-    if device_type == "cuda":
+        # Convert to milliseconds and drop the first iteration
+        cpu_elapsed_time = torch.tensor(
+            [t / 1e6 for t in cpu_times_active_ns[1:]], dtype=torch.float
+        )
+
+        # Make sure all kernels are finished before reading timers / stats
         if rank == -1:
             for di in range(world_size):
                 torch.cuda.synchronize(di)
         else:
             torch.cuda.synchronize(rank)

-    # First Benchmark Run for Eager Mode produces outlier
-    # Start counting after first as workaround for standard deviation
-    if device_type == "cuda":
-        elapsed_time = torch.tensor(
+        gpu_elapsed_time = torch.tensor(
             [s.elapsed_time(e) for s, e in zip(start_events[1:], end_events[1:])]
         )
     else:
-        elapsed_time = torch.tensor(times) * 1e3  # convert seconds ➜ milliseconds
+        # For CPU-only benchmarks we fall back to wall-clock timing via ``timeit``.
+        times = timeit.repeat(run_iter_fn, number=1, repeat=num_benchmarks)
+        cpu_elapsed_time = torch.tensor(times) * 1e3  # convert to ms
+
+        # mirror CPU timings for overall consistency
+        gpu_elapsed_time = cpu_elapsed_time.clone()

     # Memory statistics collection
     mem_stats: List[MemoryStats] = []
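
A standalone sketch (not part of the patch; workload and iteration count are placeholders) of the measurement pattern this hunk introduces: CUDA events bracket the device work, while time.process_time_ns() samples CPU time that only advances while the process is actually on a core, so time spent blocked in a synchronize is not charged to the CPU figure. As in the patch, the first iteration is dropped as warm-up.

import time
from typing import List

import torch

def run_iter_fn() -> None:
    # Placeholder workload; the real benchmark runs the module under test.
    a = torch.rand(4096, 4096, device="cuda")
    _ = a @ a

num_benchmarks = 5
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
cpu_times_active_ns: List[int] = []

for i in range(num_benchmarks):
    if i > 0:
        # Drain the previous iteration so its GPU wait is not charged to this CPU sample.
        torch.cuda.synchronize()
    start_events[i].record()
    cpu_start_ns = time.process_time_ns()
    run_iter_fn()
    cpu_end_ns = time.process_time_ns()
    end_events[i].record()
    cpu_times_active_ns.append(cpu_end_ns - cpu_start_ns)

torch.cuda.synchronize()  # CUDA event timings are only valid once the device is drained

gpu_ms = torch.tensor([s.elapsed_time(e) for s, e in zip(start_events[1:], end_events[1:])])
cpu_ms = torch.tensor([t / 1e6 for t in cpu_times_active_ns[1:]], dtype=torch.float)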
@@ -820,7 +865,11 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
         torch.cuda.synchronize(rank)

     return BenchmarkResult(
-        short_name=name, elapsed_time=elapsed_time, mem_stats=mem_stats, rank=rank
+        short_name=name,
+        gpu_elapsed_time=gpu_elapsed_time,
+        cpu_elapsed_time=cpu_elapsed_time,
+        mem_stats=mem_stats,
+        rank=rank,
     )
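
For completeness, a sketch of the CPU-only fallback from the timing hunk above (placeholder workload): wall-clock timeit measurements fill cpu_elapsed_time and are mirrored into gpu_elapsed_time, so both fields of the returned BenchmarkResult are always populated.

import timeit

import torch

def run_iter_fn() -> None:
    # Placeholder CPU workload.
    a = torch.rand(1024, 1024)
    _ = a @ a

num_benchmarks = 5
times = timeit.repeat(run_iter_fn, number=1, repeat=num_benchmarks)  # seconds per run
cpu_elapsed_time = torch.tensor(times) * 1e3  # milliseconds
gpu_elapsed_time = cpu_elapsed_time.clone()  # mirrored so both fields are always set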
@@ -1095,10 +1144,11 @@ def setUp() -> None:
         assert 0 == p.exitcode

     total_benchmark_res = BenchmarkResult(
-        benchmark_res_per_rank[0].short_name,
-        benchmark_res_per_rank[0].elapsed_time,
-        [MemoryStats(rank, 0, 0, 0) for rank in range(world_size)],
-        0,
+        short_name=benchmark_res_per_rank[0].short_name,
+        gpu_elapsed_time=benchmark_res_per_rank[0].gpu_elapsed_time,
+        cpu_elapsed_time=benchmark_res_per_rank[0].cpu_elapsed_time,
+        mem_stats=[MemoryStats(rank, 0, 0, 0) for rank in range(world_size)],
+        rank=0,
     )

     for res in benchmark_res_per_rank: