Skip to content

Commit d0893dd

Browse files
committed
Update
1 parent 1989b7d commit d0893dd

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

bench/bench/bench_mlp.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
from triton_bench.numerics import InFlexData
1111
from triton_bench.routing import routing_torch, simulate_expert_sharded_routing
1212

13+
target = triton.runtime.driver.active.get_current_target()
14+
1315

1416
def is_hip_cdna4():
15-
target = triton.runtime.driver.active.get_current_target()
1617
return target.backend == 'hip' and target.arch == 'gfx950'
1718

1819

19-
if torch.cuda.is_available() and not is_hip_cdna4():
20+
if torch.cuda.is_available() and not target.backend == "hip":
2021
from triton._C.libtriton import nvidia
2122
cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
2223
cublas = nvidia.cublas.CublasLt(cublas_workspace)

0 commit comments

Comments
 (0)