
Commit 2a2d5d2

Replace torch.cuda.Event with torch.Event for better hardware compatibility (#26985)
Signed-off-by: Kunshang Ji <[email protected]>
1 parent c3e2978 commit 2a2d5d2

15 files changed: +41 additions, -48 deletions
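The change swaps the CUDA-specific event class for PyTorch's accelerator-agnostic torch.Event, so the same timing and synchronization code can run on non-CUDA backends (e.g. XPU) without modification. Below is a minimal sketch of the timing pattern these benchmarks use after the change; the helper name, the wrapped fn, and the averaging are illustrative and not part of the commit (only the event construction changes, the torch.cuda.synchronize() calls are left as-is in the diff).

import torch

def time_kernel(fn, num_iters: int = 100) -> float:
    # Warm up and drain outstanding work before timing.
    fn()
    torch.cuda.synchronize()

    # Accelerator-agnostic events (previously torch.cuda.Event).
    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies = []
    for _ in range(num_iters):
        start_event.record()
        fn()
        end_event.record()
        torch.cuda.synchronize()
        # elapsed_time() returns milliseconds between the two recorded events.
        latencies.append(start_event.elapsed_time(end_event))
    return sum(latencies) / len(latencies)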

benchmarks/kernels/benchmark_cutlass_moe_fp8.py

Lines changed: 2 additions & 2 deletions
@@ -255,8 +255,8 @@ def bench_cuda_graph(graph, num_warmup=5, num_iters=100):
     torch.cuda.synchronize()

     # Timing
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    start_event = torch.Event(enable_timing=True)
+    end_event = torch.Event(enable_timing=True)

     latencies = []
     for _ in range(num_iters):

benchmarks/kernels/benchmark_moe.py

Lines changed: 2 additions & 2 deletions
@@ -185,8 +185,8 @@ def run():
     graph.replay()
     torch.cuda.synchronize()

-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    start_event = torch.Event(enable_timing=True)
+    end_event = torch.Event(enable_timing=True)

     latencies: list[float] = []
     for i in range(num_iters):

benchmarks/kernels/benchmark_moe_permute_unpermute.py

Lines changed: 4 additions & 4 deletions
@@ -105,8 +105,8 @@ def run():
     graph.replay()
     torch.cuda.synchronize()

-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    start_event = torch.Event(enable_timing=True)
+    end_event = torch.Event(enable_timing=True)

     latencies: list[float] = []
     for i in range(num_iters):

@@ -241,8 +241,8 @@ def run(input: tuple):
     graph.replay()
     torch.cuda.synchronize()

-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    start_event = torch.Event(enable_timing=True)
+    end_event = torch.Event(enable_timing=True)

     latencies: list[float] = []
     for i in range(num_iters):

benchmarks/kernels/benchmark_per_token_group_quant.py

Lines changed: 2 additions & 2 deletions
@@ -30,8 +30,8 @@ def _time_cuda(
     fn()
     torch.cuda.synchronize()

-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
+    start = torch.Event(enable_timing=True)
+    end = torch.Event(enable_timing=True)

     start.record()
     for _ in range(bench_iters):

benchmarks/kernels/benchmark_silu_mul_fp8_quant.py

Lines changed: 2 additions & 2 deletions
@@ -253,8 +253,8 @@ def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
     )
     torch.cuda.synchronize()

-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    start_event = torch.Event(enable_timing=True)
+    end_event = torch.Event(enable_timing=True)

     # Benchmark
     latencies: list[float] = []

benchmarks/kernels/benchmark_trtllm_decode_attention.py

Lines changed: 2 additions & 2 deletions
@@ -127,8 +127,8 @@ def benchmark_decode(

    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
+        start = torch.Event(enable_timing=True)
+        end = torch.Event(enable_timing=True)
        times = []
        for i in range(warmup):
            fn()

benchmarks/kernels/benchmark_trtllm_prefill_attention.py

Lines changed: 2 additions & 2 deletions
@@ -139,8 +139,8 @@ def benchmark_prefill(

    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
+        start = torch.Event(enable_timing=True)
+        end = torch.Event(enable_timing=True)
        times = []
        for i in range(warmup):
            fn()

benchmarks/kernels/benchmark_w8a8_block_fp8.py

Lines changed: 2 additions & 2 deletions
@@ -183,8 +183,8 @@ def run():
     run()
     torch.cuda.synchronize()

-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    start_event = torch.Event(enable_timing=True)
+    end_event = torch.Event(enable_timing=True)

     latencies: list[float] = []
     for i in range(num_iters):

tests/kernels/attention/test_merge_attn_states.py

Lines changed: 4 additions & 4 deletions
@@ -150,8 +150,8 @@ def test_merge_attn_states(
     output_torch = output.clone()
     output_lse_torch = output_lse.clone()
     total_time_torch_kernel = 0
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
+    start = torch.Event(enable_timing=True)
+    end = torch.Event(enable_timing=True)

     # 0. Run the Torch kernel
     prefix_lse_torch = prefix_lse.clone()

@@ -188,8 +188,8 @@ def test_merge_attn_states(
     output_lse_ref_triton = output_lse.clone()

     total_time_triton_kernel = 0
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
+    start = torch.Event(enable_timing=True)
+    end = torch.Event(enable_timing=True)

     for _ in range(warmup_times):
         merge_attn_states_triton(

vllm/v1/kv_offload/worker/cpu_gpu.py

Lines changed: 3 additions & 3 deletions
@@ -68,9 +68,9 @@ def __init__(
        self.h2d_stream = torch.cuda.Stream()

        # job_id -> transfer cuda event
-        self.transfer_events: dict[int, torch.cuda.Event] = {}
+        self.transfer_events: dict[int, torch.Event] = {}
        # list of cuda events available for re-use
-        self.events_pool: list[torch.cuda.Event] = []
+        self.events_pool: list[torch.Event] = []

        pin_memory = is_pin_memory_available()

@@ -153,7 +153,7 @@ def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        )
        src_to_dst_tensor = torch.from_numpy(src_to_dst)

-        event = self.events_pool.pop() if self.events_pool else torch.cuda.Event()
+        event = self.events_pool.pop() if self.events_pool else torch.Event()
        with torch.cuda.stream(stream):
            for src_tensor, dst_tensor, kv_dim in zip(
                src_tensors, dst_tensors, self.kv_dim_before_num_blocks
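The cpu_gpu.py hunks show the same substitution outside the benchmarks: events are pooled and reused to track asynchronous CPU-GPU transfers. A minimal sketch of that reuse pattern follows; the class, method names, and copy_fn callback are illustrative assumptions, not the vLLM API, and only the event type (torch.Event instead of torch.cuda.Event) comes from this commit.

import torch

class TransferEventPool:
    # Sketch of the pool-and-poll pattern: record an event on the copy
    # stream when a transfer is enqueued, poll it later, and recycle it.
    def __init__(self) -> None:
        self.transfer_events: dict[int, torch.Event] = {}
        self.events_pool: list[torch.Event] = []

    def start_transfer(self, job_id: int, stream: torch.cuda.Stream, copy_fn) -> None:
        # Reuse a pooled event if available, otherwise create a new
        # accelerator-agnostic torch.Event (previously torch.cuda.Event).
        event = self.events_pool.pop() if self.events_pool else torch.Event()
        with torch.cuda.stream(stream):
            copy_fn()        # enqueue the async copies on the transfer stream
            event.record()   # mark completion of everything enqueued so far
        self.transfer_events[job_id] = event

    def is_done(self, job_id: int) -> bool:
        # query() returns True once all work recorded before the event finished;
        # the event then goes back to the pool for the next transfer.
        event = self.transfer_events[job_id]
        if event.query():
            self.events_pool.append(event)
            del self.transfer_events[job_id]
            return True
        return False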
