You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[triton_kernels][opt_flags] Add function to reset opt_flags (#8453)
Without resetting opt_flags, the following does not work and gives error
`AssertionError: opt_flags already set; please reset to None first`:
```
import torch
from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig
from triton_kernels.matmul_ogs_details.opt_flags import (
make_opt_flags,
set_opt_flags,
)
from triton_kernels.routing import RoutingData
m = 64
n = 128
k = 32
BATCH_SIZE = 1000
dtype = torch.float16
x = torch.randn((BATCH_SIZE, m, k), device="cuda", dtype=dtype)
w = torch.randn((BATCH_SIZE, k, n), device="cuda", dtype=dtype)
bias = None
opt_flags = make_opt_flags(
dtype,
dtype,
dtype,
PrecisionConfig(),
m,
n,
k,
RoutingData(None, None, BATCH_SIZE, 1),
True,
False,
False,
)
set_opt_flags(opt_flags)
tri_y = matmul_ogs(x, w, bias)
opt_flags.num_warps = 2
set_opt_flags(opt_flags)
tri_y = matmul_ogs(x, w, bias)
```
After adding `reset_opt_flags()` before the second call of
`set_opt_flags` everything works fine.
0 commit comments