Skip to content

Commit 7bdcc6b

Browse files
authored
[triton_kernels][opt_flags] Add function to reset opt_flags (#8453)
Without resetting opt_flags, the following does not work and gives error `AssertionError: opt_flags already set; please reset to None first`: ``` import torch from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig from triton_kernels.matmul_ogs_details.opt_flags import ( make_opt_flags, set_opt_flags, ) from triton_kernels.routing import RoutingData m = 64 n = 128 k = 32 BATCH_SIZE = 1000 dtype = torch.float16 x = torch.randn((BATCH_SIZE, m, k), device="cuda", dtype=dtype) w = torch.randn((BATCH_SIZE, k, n), device="cuda", dtype=dtype) bias = None opt_flags = make_opt_flags( dtype, dtype, dtype, PrecisionConfig(), m, n, k, RoutingData(None, None, BATCH_SIZE, 1), True, False, False, ) set_opt_flags(opt_flags) tri_y = matmul_ogs(x, w, bias) opt_flags.num_warps = 2 set_opt_flags(opt_flags) tri_y = matmul_ogs(x, w, bias) ``` After adding `reset_opt_flags()` before the second call of `set_opt_flags` everything works fine.
1 parent 7c59c1d commit 7bdcc6b

File tree

1 file changed

+4
-0
lines changed
  • python/triton_kernels/triton_kernels/matmul_ogs_details

1 file changed

+4
-0
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,10 @@ def reset_opt_flags_constraints():
321321
global _opt_flags_constraints
322322
_opt_flags_constraints = dict()
323323

324+
def reset_opt_flags():
325+
global _opt_flags
326+
_opt_flags = None
327+
324328
def set_opt_flags(opt_flags: OptFlags):
325329
global _opt_flags
326330
assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override"

0 commit comments

Comments
 (0)