Skip to content

Commit 5ee013c

Browse files
authored
[python][compiler] Memoize device max shared memory per device (#6503)
Similar to #6000 this patch is an upstreamed internal patch at Meta with the goal of reducing our internal patches, cc @jamesjwu the original author. When running various benchmarks with small kernels we see a non-trivial amount of time spent fetching this property, and memoizing helped. It might be worth looking into memoizing `get_device_properties`, but I think that'd need a more careful treatment in the driver package in order to properly handle arbitrary backends.
1 parent 18ef770 commit 5ee013c

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

python/triton/compiler/compiler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ def triton_key():
163163
return f'{__version__}' + '-'.join(contents)
164164

165165

166+
@functools.lru_cache()
167+
def max_shared_mem(device):
168+
return driver.active.utils.get_device_properties(device)["max_shared_mem"]
169+
170+
166171
def parse(full_name, ext, context):
167172
if ext == "ttir" or ext == "ttgir":
168173
module = ir.parse_mlir_module(full_name, context)
@@ -397,7 +402,7 @@ def _init_handles(self):
397402
# create launcher
398403
self.run = driver.active.launcher_cls(self.src, self.metadata)
399404
# not enough shared memory to run the kernel
400-
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
405+
max_shared = max_shared_mem(device)
401406
if self.metadata.shared > max_shared:
402407
raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
403408
if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:

0 commit comments

Comments
 (0)