Skip to content

Commit 11d2077

Browse files
authored
[BACKEND] Add one more pass plugin readme example (#8815)
1 parent 78b1d25 commit 11d2077

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

lib/Plugins/README.md

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,33 @@ This functionality can be toggled on and off by just commenting out this line in
199199
knobs.runtime.add_stages_inspection_hook = inspect_stages_hook
200200
without needing any core compiler changes or rebuilding Triton.
201201

202-
## Example 3: Fully customizing the compiler pipeline with pass and op insertions at abitrary locations
202+
## Example 3: Inserting a new pass into the compiler pipeline at an arbitary point.
203+
204+
Example 2 added a new pass to the end of the ttgir "stage". However the plugin pass's location is arbitary and can be dynamically inserted anywhere in the pipeline. Replacing the inspect_stages_hook function from example 2 instead with:
205+
206+
```python
207+
def inspect_stages_hook(self=None, stages=None, options=None, language=None, capability=None):
208+
if all(arg is None for arg in (stages, options, language, capability)):
209+
return get_key(), get_hash()
210+
module_name = 'dynamic_module'
211+
spec = importlib.util.spec_from_loader(module_name, loader=None)
212+
module = importlib.util.module_from_spec(spec)
213+
sys.modules[module_name] = module
214+
stage_src = textwrap.dedent(inspect.getsource(self.make_ttir))
215+
stage_src = 'from triton._C.libtriton import ir, passes, llvm, amd, nvidia\n' + stage_src
216+
# Inject plugin pass right after loop unroll in the dynamically loaded stage source
217+
stage_src = stage_src.replace(
218+
"passes.ttir.add_loop_unroll(pm)",
219+
"passes.ttir.add_loop_unroll(pm)\n passes.plugin.add_plugin(pm)"
220+
)
221+
exec(stage_src, module.__dict__)
222+
make_lambda = lambda f: lambda src, metadata: f(src, metadata, options, capability)
223+
stages["ttir"] = make_lambda(module.make_ttir)
224+
return get_key(), get_hash()
225+
```
226+
directs the new pass's placement based on other surrounding passes. Knowing which passes are in the pipeline a priori can challenging, therefore in the next example we show how to dump and inspect the entire pipeline that is run for a particlar kernel to allow for precise placement of specialized out of tree passes even if the upstream pass pipeline structure changes.
227+
228+
## Example 4: Fully customizing the compiler pipeline with pass and op insertions at abitrary locations
203229

204230
Here we now run two kernels one with the full standard Triton pipeline and one with fully customized pipeline entirely from within
205231
kernel code with modifying any core Triton compiler code or recompiling. We run the kernel with a hook to output the standard pipeline, modify

0 commit comments

Comments
 (0)