
Commit 3ac28bf

SSYernar authored and facebook-github-bot committed
Added CPU runtime benchmarking and expanded BenchmarkResult to store CPU runtimes (#3206)
Summary:
Pull Request resolved: #3206

* Added a feature to benchmark CPU runtimes alongside GPU measurements. (`time.process_time_ns()` measures the sum of system and user CPU time of the current process in nanoseconds, excluding time spent sleeping or waiting for I/O, which makes it well suited to measuring pure CPU computation time.)
* Expanded the `BenchmarkResult` class to store both device measurements.
* Adapted files that import `BenchmarkResult` to ensure compatibility.

Now CPU and GPU runtimes can be compared without running and analyzing the PyTorch Profiler. `BenchmarkResult` helps users detect whether a module/function/operator is CPU-bound or GPU-bound.

Example metrics for FBGEMM operators:

| Operator                           | CPU Runtime | GPU Runtime | GPU Memory |
|------------------------------------|-------------|-------------|------------|
| **[fallback] pytorch generic**     | 5.41 ms     | 2.66 ms     | 1.01 GB    |
| **[Prod] KeyedTensor.regroup_dup** | 2.13 ms     | 2.48 ms     | 1.01 GB    |
| **[Module] KTRegroupAsDict_dup**   | 0.14 ms     | 0.75 ms     | 1.01 GB    |
| **[2 Ops] permute_multi_embs_dup** | 0.88 ms     | 1.44 ms     | 1.01 GB    |
| **[1 Op] KT_regroup_dup**          | 0.99 ms     | 1.54 ms     | 1.01 GB    |

We can see that `[fallback] pytorch generic` is CPU-bound.

Reviewed By: aliafzal

Differential Revision: D78503319

fbshipit-source-id: d038bcb424880b258f6f765ca0c632c34d1ccc7b
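To make the `time.process_time_ns()` claim concrete, here is a minimal, self-contained sketch (not part of the commit) showing that the process-time clock advances only while the process is actually computing, not while it sleeps:

```python
import time

# Wall clock vs. active CPU clock: process_time_ns() counts user+system CPU time
# of the current process and excludes sleep / I/O waits.
wall_start_ns = time.perf_counter_ns()
cpu_start_ns = time.process_time_ns()

time.sleep(0.5)                         # idle: wall clock advances, CPU clock barely moves
_ = sum(i * i for i in range(10**6))    # busy loop: both clocks advance

wall_ms = (time.perf_counter_ns() - wall_start_ns) / 1e6
cpu_ms = (time.process_time_ns() - cpu_start_ns) / 1e6
print(f"wall: {wall_ms:.1f} ms | active CPU: {cpu_ms:.1f} ms")  # CPU time omits the ~500 ms sleep
```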
1 parent f288fa0 commit 3ac28bf

File tree

3 files changed (+87, -37 lines)

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 82 additions & 32 deletions
@@ -135,26 +135,43 @@ def __str__(self) -> str:
 class BenchmarkResult:
     "Class for holding results of benchmark runs"
     short_name: str
-    elapsed_time: torch.Tensor  # milliseconds
+    gpu_elapsed_time: torch.Tensor  # milliseconds
+    cpu_elapsed_time: torch.Tensor  # milliseconds
     mem_stats: List[MemoryStats]  # memory stats per rank
     rank: int = -1
 
     def __str__(self) -> str:
-        runtime = f"Runtime (P90): {self.runtime_percentile(90):.2f} ms"
+        gpu_runtime = (
+            f"GPU Runtime (P90): {self.runtime_percentile(90, device='gpu'):.2f} ms"
+        )
+        cpu_runtime = (
+            f"CPU Runtime (P90): {self.runtime_percentile(90, device='cpu'):.2f} ms"
+        )
         if len(self.mem_stats) == 0:
-            return f"{self.short_name: <{35}} | {runtime}"
+            return f"{self.short_name: <{35}} | {gpu_runtime} | {cpu_runtime}"
         mem_alloc = (
             f"Peak Memory alloc (P90): {self.max_mem_alloc_percentile(90)/1000:.2f} GB"
         )
         mem_reserved = f"Peak Memory reserved (P90): {self.max_mem_reserved_percentile(90)/1000:.2f} GB"
-        malloc_retries = f"Malloc retries (P50/P90/P100): {self.mem_retries(50) } / {self.mem_retries(90)} / {self.mem_retries(100)}"
-        return f"{self.short_name: <{35}} | {malloc_retries} | {runtime} | {mem_alloc} | {mem_reserved}"
+        malloc_retries = f"Malloc retries (P50/P90/P100): {self.mem_retries(50)} / {self.mem_retries(90)} / {self.mem_retries(100)}"
+        return f"{self.short_name: <{35}} | {malloc_retries} | {gpu_runtime} | {cpu_runtime} | {mem_alloc} | {mem_reserved}"
 
     def runtime_percentile(
-        self, percentile: int = 50, interpolation: str = "nearest"
+        self,
+        percentile: int = 50,
+        interpolation: str = "nearest",
+        device: str = "gpu",
     ) -> torch.Tensor:
+        """Return the runtime percentile for the requested timer.
+
+        Args:
+            percentile: Percentile to compute.
+            interpolation: See ``torch.quantile``.
+            device: 'gpu' for CUDA event timings, 'cpu' for active CPU timings.
+        """
+        timings = self.gpu_elapsed_time if device == "gpu" else self.cpu_elapsed_time
         return torch.quantile(
-            self.elapsed_time,
+            timings,
             percentile / 100.0,
             interpolation=interpolation,
         )
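The new `device` argument lets callers query either timer from the same result object. A hypothetical usage sketch (not part of this commit), assuming `BenchmarkResult` and `MemoryStats` can be imported from `torchrec.distributed.benchmark.benchmark_utils` and filled with made-up timings:

```python
import torch

# Assumed import path; both classes live in (or are re-exported by) benchmark_utils.
from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult, MemoryStats

result = BenchmarkResult(
    short_name="my_op",                               # hypothetical operator name
    gpu_elapsed_time=torch.tensor([2.7, 2.6, 2.65]),  # milliseconds per iteration
    cpu_elapsed_time=torch.tensor([5.5, 5.4, 5.3]),   # milliseconds per iteration
    mem_stats=[MemoryStats(0, 0, 0, 0)],
)

gpu_p90 = result.runtime_percentile(90, device="gpu")
cpu_p90 = result.runtime_percentile(90, device="cpu")
if cpu_p90 > gpu_p90:
    print(f"{result.short_name} looks CPU-bound ({cpu_p90:.2f} ms CPU vs {gpu_p90:.2f} ms GPU)")
print(result)  # __str__ now reports both GPU and CPU P90 runtimes
```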
@@ -409,17 +426,26 @@ def write_report(
     num_requests: int,
 ) -> None:
     for benchmark_res in benchmark_results:
-        avg_dur_s = benchmark_res.elapsed_time.mean().item() * 1e-3  # time in seconds
-        std_dur_s = benchmark_res.elapsed_time.std().item() * 1e-3  # time in seconds
+        # GPU statistics
+        avg_dur_s_gpu = benchmark_res.gpu_elapsed_time.mean().item() * 1e-3  # sec
+        std_dur_s_gpu = benchmark_res.gpu_elapsed_time.std().item() * 1e-3  # sec
+
+        # CPU statistics
+        avg_dur_s_cpu = benchmark_res.cpu_elapsed_time.mean().item() * 1e-3  # sec
+        std_dur_s_cpu = benchmark_res.cpu_elapsed_time.std().item() * 1e-3  # sec
 
-        qps = int(num_requests / avg_dur_s)
+        qps_gpu = int(num_requests / avg_dur_s_gpu)
 
         mem_str = ""
         for memory_stats in benchmark_res.mem_stats:
             mem_str += f"{memory_stats}\n"
 
-        report_str += f"{benchmark_res.short_name:40} Avg QPS:{qps:10} Avg Duration: {int(1000*avg_dur_s):5}"
-        report_str += f"ms Standard Dev Duration: {(1000*std_dur_s):.2f}ms\n"
+        report_str += (
+            f"{benchmark_res.short_name:40} "
+            f"Avg QPS(GPU):{qps_gpu:10} "
+            f"GPU Avg: {int(1000*avg_dur_s_gpu):5}ms ±{(1000*std_dur_s_gpu):.2f}ms "
+            f"CPU Avg: {int(1000*avg_dur_s_cpu):5}ms ±{(1000*std_dur_s_cpu):.2f}ms\n"
+        )
         report_str += f"\tMemory Allocated Per Rank:\n\t{mem_str}\n"
 
     with open(report_file, "w") as f:
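For reference, the report line above boils down to simple arithmetic over the per-iteration tensors. A small illustration with invented numbers (variable names mirror the diff; values and the operator name are made up):

```python
import torch

num_requests = 8192                          # hypothetical request count per iteration
gpu_ms = torch.tensor([2.70, 2.60, 2.65])    # per-iteration GPU time, milliseconds
cpu_ms = torch.tensor([5.50, 5.40, 5.30])    # per-iteration active CPU time, milliseconds

avg_dur_s_gpu = gpu_ms.mean().item() * 1e-3  # 0.00265 s
std_dur_s_gpu = gpu_ms.std().item() * 1e-3
avg_dur_s_cpu = cpu_ms.mean().item() * 1e-3
std_dur_s_cpu = cpu_ms.std().item() * 1e-3
qps_gpu = int(num_requests / avg_dur_s_gpu)  # ~3.09 million requests/s

print(
    f"{'my_op':40} Avg QPS(GPU):{qps_gpu:10} "
    f"GPU Avg: {int(1000 * avg_dur_s_gpu):5}ms ±{1000 * std_dur_s_gpu:.2f}ms "
    f"CPU Avg: {int(1000 * avg_dur_s_cpu):5}ms ±{1000 * std_dur_s_cpu:.2f}ms"
)
```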
@@ -731,44 +757,63 @@ def _run_benchmark_core(
         if reset_accumulated_memory_stats:
             torch.cuda.reset_accumulated_memory_stats(rank)
 
-    # Optional allocator warm-up to create fragmentation similar to production
-    if pre_gpu_load and device_type == "cuda":
-        _tmp = torch.rand(16384, 16384, device="cuda")
-        for _ in range(pre_gpu_load):
-            _tmp = _tmp * torch.rand(16384, 16384, device="cuda")
+        # Optional allocator warm-up to create fragmentation similar to production
+        if pre_gpu_load:
+            _tmp = torch.rand(16384, 16384, device="cuda")
+            for _ in range(pre_gpu_load):
+                _tmp = _tmp * torch.rand(16384, 16384, device="cuda")
 
-    # Timing loop
+    # Timings
     start_events, end_events, times = [], [], []
+
     if device_type == "cuda":
         start_events = [
             torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)
         ]
         end_events = [
             torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)
         ]
+        # Capture per-iteration active CPU cycles (excludes time the thread is truly idle/asleep) using `process_time_ns`.
+        cpu_times_active_ns: List[int] = []
+
         for i in range(num_benchmarks):
+            # Ensure that outstanding GPU work from the previous iteration has
+            # finished so that we do not attribute its wait time to the next
+            # CPU measurement.
+            if i > 0:
+                torch.cuda.synchronize(rank if rank >= 0 else 0)
+
             start_events[i].record()
+            cpu_start_active_ns = time.process_time_ns()
+
             run_iter_fn()
+
+            cpu_end_active_ns = time.process_time_ns()
             end_events[i].record()
-    else:
-        times = timeit.repeat(run_iter_fn, number=1, repeat=num_benchmarks)
+            cpu_times_active_ns.append(cpu_end_active_ns - cpu_start_active_ns)
 
-    # Make sure all kernels are finished before reading timers / stats
-    if device_type == "cuda":
+        # Convert to milliseconds and drop the first iteration
+        cpu_elapsed_time = torch.tensor(
+            [t / 1e6 for t in cpu_times_active_ns[1:]], dtype=torch.float
+        )
+
+        # Make sure all kernels are finished before reading timers / stats
         if rank == -1:
             for di in range(world_size):
                 torch.cuda.synchronize(di)
         else:
             torch.cuda.synchronize(rank)
 
-    # First Benchmark Run for Eager Mode produces outlier
-    # Start counting after first as workaround for standard deviation
-    if device_type == "cuda":
-        elapsed_time = torch.tensor(
+        gpu_elapsed_time = torch.tensor(
             [s.elapsed_time(e) for s, e in zip(start_events[1:], end_events[1:])]
         )
     else:
-        elapsed_time = torch.tensor(times) * 1e3  # convert seconds ➜ milliseconds
+        # For CPU-only benchmarks we fall back to wall-clock timing via ``timeit``.
+        times = timeit.repeat(run_iter_fn, number=1, repeat=num_benchmarks)
+        cpu_elapsed_time = torch.tensor(times) * 1e3  # convert to ms
+
+        # mirror CPU timings for overall consistency
+        gpu_elapsed_time = cpu_elapsed_time.clone()
 
     # Memory statistics collection
     mem_stats: List[MemoryStats] = []
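Distilled outside the torchrec harness, the measurement pattern added above looks roughly like the sketch below. `time_fn` is a hypothetical helper, not part of the commit, and it requires a CUDA device:

```python
import time
from typing import Callable, List, Tuple

import torch


def time_fn(fn: Callable[[], object], num_benchmarks: int = 10) -> Tuple[torch.Tensor, torch.Tensor]:
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
    cpu_times_ns: List[int] = []

    for i in range(num_benchmarks):
        if i > 0:
            # Drain the previous iteration so its GPU tail is not charged to the next CPU sample.
            torch.cuda.synchronize()
        start_events[i].record()
        cpu_start = time.process_time_ns()
        fn()
        cpu_end = time.process_time_ns()
        end_events[i].record()
        cpu_times_ns.append(cpu_end - cpu_start)

    torch.cuda.synchronize()
    # Drop the first (warm-up) iteration, as the diff does.
    gpu_ms = torch.tensor(
        [s.elapsed_time(e) for s, e in zip(start_events[1:], end_events[1:])]
    )
    cpu_ms = torch.tensor([t / 1e6 for t in cpu_times_ns[1:]], dtype=torch.float)
    return gpu_ms, cpu_ms


if torch.cuda.is_available():
    x = torch.rand(4096, 4096, device="cuda")
    gpu_ms, cpu_ms = time_fn(lambda: x @ x)
    print(f"GPU mean: {gpu_ms.mean():.2f} ms | active CPU mean: {cpu_ms.mean():.2f} ms")
```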
@@ -820,7 +865,11 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
         torch.cuda.synchronize(rank)
 
     return BenchmarkResult(
-        short_name=name, elapsed_time=elapsed_time, mem_stats=mem_stats, rank=rank
+        short_name=name,
+        gpu_elapsed_time=gpu_elapsed_time,
+        cpu_elapsed_time=cpu_elapsed_time,
+        mem_stats=mem_stats,
+        rank=rank,
     )
 
 
@@ -1095,10 +1144,11 @@ def setUp() -> None:
         assert 0 == p.exitcode
 
     total_benchmark_res = BenchmarkResult(
-        benchmark_res_per_rank[0].short_name,
-        benchmark_res_per_rank[0].elapsed_time,
-        [MemoryStats(rank, 0, 0, 0) for rank in range(world_size)],
-        0,
+        short_name=benchmark_res_per_rank[0].short_name,
+        gpu_elapsed_time=benchmark_res_per_rank[0].gpu_elapsed_time,
+        cpu_elapsed_time=benchmark_res_per_rank[0].cpu_elapsed_time,
+        mem_stats=[MemoryStats(rank, 0, 0, 0) for rank in range(world_size)],
+        rank=0,
     )
 
     for res in benchmark_res_per_rank:

torchrec/sparse/tests/jagged_tensor_benchmark.py

Lines changed: 3 additions & 4 deletions
@@ -107,14 +107,13 @@ def wrapped_func(
     )
     result = BenchmarkResult(
         short_name=name,
-        elapsed_time=torch.tensor(times) * 1e3,
+        gpu_elapsed_time=torch.tensor(times) * 1e3,
+        cpu_elapsed_time=torch.tensor(times) * 1e3,
         mem_stats=[MemoryStats(0, 0, 0, 0)],
     )
 
-    mem_alloc = f"Memory alloc (P90): {result.max_mem_alloc_percentile(90):5.1f}"
-    mem_reserved = f"Memory alloc (P90): {result.max_mem_reserved_percentile(90):5.1f}"
     print(
-        f" {name : <{30}} | B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | Runtime (P90): {result.runtime_percentile(90):5.2f} ms | {mem_alloc} | {mem_reserved}"
+        f"B: {batch_size : <{8}} | F: {feature_count : <{8}} | device: {device_type : <{8}} | {result}"
     )
 

torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py

Lines changed: 2 additions & 1 deletion
@@ -227,7 +227,8 @@ def benchmark_kjt(
 
     result = BenchmarkResult(
         short_name=f"{test_name}-{transform_type.name}",
-        elapsed_time=torch.tensor(times),
+        gpu_elapsed_time=torch.tensor(times),
+        cpu_elapsed_time=torch.tensor(times),
         mem_stats=[MemoryStats(0, 0, 0, 0)],
     )
