@@ -634,18 +634,42 @@ <h2>Submodules<a class="headerlink" href="#submodules" title="Permalink to this
634
634
635
635
< dl class ="py function ">
636
636
< dt class ="sig sig-object py " id ="tilelang.autotuner.autotune ">
637
- < span class ="sig-prename descclassname "> < span class ="pre "> tilelang.autotuner.</ span > </ span > < span class ="sig-name descname "> < span class ="pre "> autotune</ span > </ span > < span class ="sig-paren "> (</ span > < em class ="sig-param "> < span class ="n "> < span class ="pre "> func</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> Optional</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> Union</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> Callable</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> _P</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > < span class ="p "> < span class ="pre "> ,</ span > </ span > < span class ="w "> </ span > < span class ="pre "> _RProg</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > < span class ="p "> < span class ="pre "> ,</ span > </ span > < span class ="w "> </ span > < span class ="pre "> PrimFunc</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> None</ span > </ span > </ em > , < em class ="sig-param "> < span class ="o "> < span class ="pre "> *</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> configs</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> Any</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class 
="pre "> warmup</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> int</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 25</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> rep</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> int</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 100</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> timeout</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> int</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 100</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> supply_type</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < a class ="reference internal " href ="tilelang.utils.tensor.html#tilelang.utils.tensor.TensorSupplyType " title ="tilelang.utils.tensor.TensorSupplyType "> < span class ="pre "> TensorSupplyType</ span > </ a > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> TensorSupplyType.Auto</ span > </ span > </ em > , < em class 
="sig-param "> < span class ="n "> < span class ="pre "> ref_prog</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> Optional</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> Callable</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> None</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> supply_prog</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> Optional</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> Callable</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> None</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> rtol</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> float</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 0.01</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> atol</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> float</ span > </ span > < span class ="w "> 
</ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 0.01</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> max_mismatched_ratio</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> float</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 0.01</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> skip_check</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> bool</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> False</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> manual_check_prog</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> Optional</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> Callable</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> None</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> cache_input_tensors</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class 
="n "> < span class ="pre "> bool</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> False</ span > </ span > </ em > < span class ="sig-paren "> )</ span > < a class ="headerlink " href ="#tilelang.autotuner.autotune " title ="Permalink to this definition "> #</ a > </ dt >
637
+ < span class ="sig-prename descclassname "> < span class ="pre "> tilelang.autotuner.</ span > </ span > < span class ="sig-name descname "> < span class ="pre "> autotune</ span > </ span > < span class ="sig-paren "> (</ span > < em class ="sig-param "> < span class ="n "> < span class ="pre "> func</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> Optional</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> Union</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> Callable</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> _P</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > < span class ="p "> < span class ="pre "> ,</ span > </ span > < span class ="w "> </ span > < span class ="pre "> _RProg</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > < span class ="p "> < span class ="pre "> ,</ span > </ span > < span class ="w "> </ span > < span class ="pre "> PrimFunc</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> None</ span > </ span > </ em > , < em class ="sig-param "> < span class ="o "> < span class ="pre "> *</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> configs</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> Union</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class 
="pre "> Dict</ span > < span class ="p "> < span class ="pre "> ,</ span > </ span > < span class ="w "> </ span > < span class ="pre "> Callable</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> warmup</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> int</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 25</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> rep</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> int</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 100</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> timeout</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> int</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 100</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> supply_type</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < a class ="reference internal " href ="tilelang.utils.tensor.html#tilelang.utils.tensor.TensorSupplyType " title ="tilelang.utils.tensor.TensorSupplyType "> < span 
class ="pre "> TensorSupplyType</ span > </ a > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> TensorSupplyType.Auto</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> ref_prog</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> Optional</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> Callable</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> None</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> supply_prog</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> Optional</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> Callable</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> None</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> rtol</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> float</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 0.01</ 
span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> atol</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> float</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 0.01</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> max_mismatched_ratio</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> float</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> 0.01</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> skip_check</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> bool</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> False</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> manual_check_prog</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> Optional</ span > < span class ="p "> < span class ="pre "> [</ span > </ span > < span class ="pre "> Callable</ span > < span class ="p "> < span class ="pre "> ]</ span > </ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < 
span class ="default_value "> < span class ="pre "> None</ span > </ span > </ em > , < em class ="sig-param "> < span class ="n "> < span class ="pre "> cache_input_tensors</ span > </ span > < span class ="p "> < span class ="pre "> :</ span > </ span > < span class ="w "> </ span > < span class ="n "> < span class ="pre "> bool</ span > </ span > < span class ="w "> </ span > < span class ="o "> < span class ="pre "> =</ span > </ span > < span class ="w "> </ span > < span class ="default_value "> < span class ="pre "> False</ span > </ span > </ em > < span class ="sig-paren "> )</ span > < a class ="headerlink " href ="#tilelang.autotuner.autotune " title ="Permalink to this definition "> #</ a > </ dt >
638
638
< dd > < p > Just-In-Time (JIT) compiler decorator for TileLang functions.</ p >
639
- < dl class =" simple " >
639
+ < dl >
640
640
< dt > This decorator can be used without arguments (e.g., < cite > @tilelang.jit</ cite > ):</ dt > < dd > < p > Applies JIT compilation with default settings.</ p >
641
641
</ dd >
642
+ < dt > Tips:</ dt > < dd > < ul >
643
+ < li > < dl >
644
+ < dt > If you want to skip the auto-tuning process, you can override the tunable parameters in the function signature.</ dt > < dd > < dl class ="simple ">
645
+ < dt > < a href ="#id1 "> < span class ="problematic " id ="id2 "> ``</ span > </ a > < a href ="#id3 "> < span class ="problematic " id ="id4 "> `</ span > </ a > python</ dt > < dd > < dl class ="simple ">
646
+ < dt > if enable_autotune:</ dt > < dd > < p > kernel = flashattn(batch, heads, seq_len, dim, is_causal)</ p >
647
+ </ dd >
648
+ < dt > else:</ dt > < dd > < dl class ="simple ">
649
+ < dt > kernel = flashattn(</ dt > < dd > < p > batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256)</ p >
650
+ </ dd >
651
+ </ dl >
652
+ </ dd >
653
+ </ dl >
654
+ </ dd >
655
+ </ dl >
656
+ < p > < a href ="#id5 "> < span class ="problematic " id ="id6 "> ``</ span > </ a > < a href ="#id7 "> < span class ="problematic " id ="id8 "> `</ span > </ a > </ p >
657
+ </ dd >
658
+ </ dl >
659
+ </ li >
660
+ </ ul >
661
+ </ dd >
642
662
</ dl >
643
663
< dl class ="field-list simple ">
644
664
< dt class ="field-odd "> Parameters< span class ="colon "> :</ span > </ dt >
645
665
< dd class ="field-odd "> < ul class ="simple ">
646
666
< li > < p > < strong > func_or_out_idx</ strong > (< em > Any</ em > < em > , </ em > < em > optional</ em > ) – If using < cite > @tilelang.jit(…)</ cite > to configure, this is the < cite > out_idx</ cite > parameter.
647
667
If using < cite > @tilelang.jit</ cite > directly on a function, this argument is implicitly
648
668
the function to be decorated (and < cite > out_idx</ cite > will be < cite > None</ cite > ).</ p > </ li >
669
+ < li > < p > < strong > configs</ strong > (< em > Dict</ em > < em > or </ em > < em > Callable</ em > ) – Configuration space to explore during auto-tuning.</ p > </ li >
670
+ < li > < p > < strong > warmup</ strong > (< em > int</ em > < em > , </ em > < em > optional</ em > ) – Number of warmup iterations before timing.</ p > </ li >
671
+ < li > < p > < strong > rep</ strong > (< em > int</ em > < em > , </ em > < em > optional</ em > ) – Number of repetitions for timing measurements.</ p > </ li >
672
+ < li > < p > < strong > timeout</ strong > (< em > int</ em > < em > , </ em > < em > optional</ em > ) – </ p > </ li >
649
673
< li > < p > < strong > target</ strong > (< em > Union</ em > < em > [</ em > < em > str</ em > < em > , </ em > < em > Target</ em > < em > ]</ em > < em > , </ em > < em > optional</ em > ) – Compilation target for TVM (e.g., “cuda”, “llvm”). Defaults to “auto”.</ p > </ li >
650
674
< li > < p > < strong > target_host</ strong > (< em > Union</ em > < em > [</ em > < em > str</ em > < em > , </ em > < em > Target</ em > < em > ]</ em > < em > , </ em > < em > optional</ em > ) – Target host for cross-compilation. Defaults to None.</ p > </ li >
651
675
< li > < p > < strong > execution_backend</ strong > (< em > Literal</ em > < em > [</ em > < em > "dlpack"</ em > < em > , </ em > < em > "ctypes"</ em > < em > , </ em > < em > "cython"</ em > < em > ]</ em > < em > , </ em > < em > optional</ em > ) – Backend for kernel execution and argument passing. Defaults to “cython”.</ p > </ li >
0 commit comments