We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1989b7d commit d0893dd — Copy full SHA for d0893dd
bench/bench/bench_mlp.py
@@ -10,13 +10,14 @@
10
from triton_bench.numerics import InFlexData
11
from triton_bench.routing import routing_torch, simulate_expert_sharded_routing
12
13
+target = triton.runtime.driver.active.get_current_target()
14
+
15
16
def is_hip_cdna4():
- target = triton.runtime.driver.active.get_current_target()
17
return target.backend == 'hip' and target.arch == 'gfx950'
18
19
-if torch.cuda.is_available() and not is_hip_cdna4():
20
+if torch.cuda.is_available() and not target.backend == "hip":
21
from triton._C.libtriton import nvidia
22
cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
23
cublas = nvidia.cublas.CublasLt(cublas_workspace)
0 commit comments