Skip to content

Commit c508755

Browse files
Update docs
1 parent d9960a6 commit c508755

File tree

23 files changed

+309
-147
lines changed

23 files changed

+309
-147
lines changed

_sources/autoapi/tilelang/carver/arch/cdna/index.rst.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ Module Contents
3030
Bases: :py:obj:`tilelang.carver.arch.arch_base.TileDevice`
3131

3232

33+
Represents the architecture of a computing device, capturing various hardware specifications.
34+
35+
3336
.. py:attribute:: target
3437
3538

_sources/autoapi/tilelang/carver/arch/cpu/index.rst.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ Module Contents
3030
Bases: :py:obj:`tilelang.carver.arch.arch_base.TileDevice`
3131

3232

33+
Represents the architecture of a computing device, capturing various hardware specifications.
34+
35+
3336
.. py:attribute:: target
3437
3538

_sources/autoapi/tilelang/carver/arch/cuda/index.rst.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ Module Contents
9292
Bases: :py:obj:`tilelang.carver.arch.arch_base.TileDevice`
9393

9494

95+
Represents the architecture of a computing device, capturing various hardware specifications.
96+
97+
9598
.. py:attribute:: target
9699
97100

_sources/autoapi/tilelang/carver/roller/hint/index.rst.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,12 +307,12 @@ Module Contents
307307
308308
309309
.. py:property:: raxis_order
310-
:type: List[int]
310+
:type: tilelang.carver.roller.rasterization.List[int]
311311

312312

313313

314314
.. py:property:: step
315-
:type: List[int]
315+
:type: tilelang.carver.roller.rasterization.List[int]
316316

317317

318318

_sources/autoapi/tilelang/carver/roller/policy/tensorcore/index.rst.txt

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,15 @@ Module Contents
3030

3131
.. py:data:: logger
3232
33-
.. py:class:: TensorCorePolicy
33+
.. py:class:: TensorCorePolicy(arch, tags = None)
3434
3535
Bases: :py:obj:`tilelang.carver.roller.policy.default.DefaultPolicy`
3636

3737

38+
Default Policy for fastdlight, a heuristic plan that tries to
39+
minimize memory traffic and maximize parallelism.for BitBLAS Schedule.
40+
41+
3842
.. py:attribute:: wmma_k
3943
:type: int
4044
:value: 16
@@ -61,16 +65,67 @@ Module Contents
6165

6266
.. py:method:: infer_node_smem_usage(td, node)
6367
68+
Infers the shared memory usage of a node given a TileDict configuration.
69+
70+
:param td: The TileDict object containing the tile configuration.
71+
:type td: TileDict
72+
:param node: The node for which to infer the shared memory usage.
73+
:type node: PrimFuncNode
74+
75+
:returns: The estimated amount of shared memory used by the node.
76+
:rtype: int
77+
78+
6479

6580
.. py:method:: get_node_reduce_step_candidates(node)
6681
82+
Calculates reduction step candidates for each reduction axis in a PrimFuncNode. General idea : use factor first, since it does not require extra boundary check. for large prime number, which is rare case, use power of 2.
83+
84+
:param node: The node for which to calculate reduction step candidates. It contains reduction axes (raxis)
85+
with their domains (dom.extent).
86+
:type node: PrimFuncNode
87+
88+
:returns: A dictionary mapping axis variable names to lists of step candidates. For each axis in the node,
89+
this function calculates possible step sizes. For axes with a large prime domain, it uses powers of 2
90+
as step candidates; for others, it uses all factors of the domain.
91+
:rtype: Dict[str, List[int]]
92+
93+
6794

6895
.. py:method:: check_tile_shape_isvalid(td)
6996
97+
Checks if the tile shapes in the TileDict are valid for the nodes in this context.
98+
99+
Parameters:
100+
- td (TileDict): The TileDict object containing tile shapes and other configurations.
101+
102+
Returns:
103+
- bool: True if all tile shapes are valid, False otherwise.
104+
105+
70106

71107
.. py:method:: compute_node_stride_map(node, td)
72108
109+
Computes the stride map for a given node based on the TileDict configuration.
110+
111+
:param node: The node for which to compute the stride map.
112+
:type node: PrimFuncNode
113+
:param td: The TileDict object containing the tile configuration.
114+
:type td: TileDict
115+
116+
:returns: A tuple of dictionaries containing the output strides and tensor strides.
117+
:rtype: Tuple[Dict, Dict]
118+
119+
73120

74121
.. py:method:: plan_rasterization(td)
75122
123+
Plans the rasterization for the given TileDict. This function is not implemented yet.
124+
125+
:param td: The TileDict object to plan rasterization for.
126+
:type td: TileDict
127+
128+
:raises RasterRationPlan: This function is not implemented yet.
129+
130+
76131

_sources/autoapi/tilelang/carver/template/flashattention/index.rst.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@ Module Contents
2020
Bases: :py:obj:`tilelang.carver.template.base.BaseTemplate`
2121

2222

23-
Base class template for hardware-aware configurations.
24-
This serves as an abstract base class (ABC) that defines the structure
25-
for subclasses implementing hardware-specific optimizations.
26-
27-
2823
.. py:attribute:: batch_size
2924
:type: int
3025
:value: 1

_sources/autoapi/tilelang/carver/template/general_reduce/index.rst.txt

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@ Module Contents
2020
Bases: :py:obj:`tilelang.carver.template.base.BaseTemplate`
2121

2222

23-
Base class template for hardware-aware configurations.
24-
This serves as an abstract base class (ABC) that defines the structure
25-
for subclasses implementing hardware-specific optimizations.
26-
27-
2823
.. py:attribute:: structure
2924
:type: Union[str, List[str]]
3025
:value: None
@@ -45,19 +40,6 @@ Module Contents
4540

4641
.. py:method:: get_hardware_aware_configs(arch = None, topk = 10)
4742
48-
Abstract method that must be implemented by subclasses.
49-
It should return a list of hardware-aware configurations (hints)
50-
based on the specified architecture.
51-
52-
:param arch: The target architecture. Defaults to None.
53-
:type arch: TileDevice, optional
54-
:param topk: Number of top configurations to return. Defaults to 10.
55-
:type topk: int, optional
56-
57-
:returns: A list of recommended hardware-aware configurations.
58-
:rtype: List[Hint]
59-
60-
6143
6244
.. py:method:: initialize_function()
6345

_sources/autoapi/tilelang/intrinsics/utils/index.rst.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,16 @@ Module Contents
3535
3636
.. py:function:: get_mma_micro_size(dtype)
3737
38+
Return the MMA (Tensor Core) micro-tile dimensions for a given data type.
39+
40+
This function returns the micro tile sizes (x, y, k) used by MMA/Tensor Core operations.
41+
- x: tile width in the output/result dimension
42+
- y: tile height in the output/result dimension
43+
- k: tile depth in the reduction/K dimension
44+
45+
Accepted dtype strings include "float16", "int8" and some FP8 identifiers ("float8_e4m3", "float8_e5m2"). For FP8 and int8 types the reduction depth (`k`) is 32; for float16 it is 16.
46+
47+
:returns: (micro_size_x, micro_size_y, micro_size_k)
48+
:rtype: tuple[int, int, int]
49+
50+

_sources/autoapi/tilelang/language/index.rst.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,17 @@ Package Contents
5555

5656
.. py:function:: symbolic(name, dtype = 'int32')
5757
58+
Create a TIR symbolic variable.
59+
60+
:param name: Identifier for the variable in generated TIR.
61+
:type name: str
62+
:param dtype: Data type string for the variable (e.g., "int32"). Defaults to "int32".
63+
:type dtype: str
64+
65+
:returns: A TIR variable with the given name and dtype for use in TIR/TensorIR kernels.
66+
:rtype: tir.Var
67+
68+
5869
.. py:function:: use_swizzle(panel_size, order = 'row', enable = True)
5970
6071
.. py:function:: annotate_layout(layout_map)

_sources/autoapi/tilelang/language/utils/index.rst.txt

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,41 +18,43 @@ Module Contents
1818

1919
.. py:function:: index_to_coordinates(index, shape)
2020
21-
Convert a flat (linear) index to multi-dimensional coordinates for a given shape.
21+
Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
2222

23-
.. rubric:: Example
24-
25-
shape = (4, 5, 6)
26-
index = 53
27-
index_to_coordinates(53, (4, 5, 6)) -> [1, 3, 5]
28-
# Explanation:
29-
# 53 // (5*6) = 1 (1st coordinate)
30-
# 53 % (5*6) = 23
31-
# 23 // 6 = 3 (2nd coordinate)
32-
# 23 % 6 = 5 (3rd coordinate)
23+
Given a linear index and a shape (sequence of dimension extents), returns a list of coordinates (one per dimension) such that converting those coordinates back to a linear index using the usual row-major / C-order formula yields the original index. The computation iterates from the last dimension to the first using modulo and integer division, then reverses the collected coordinates.
3324

3425
:param index: The flat index to convert.
35-
:type index: int
36-
:param shape: The shape of the multi-dimensional array.
37-
:type shape: tuple or list of int
26+
:type index: int or PrimExpr
27+
:param shape: The extents of each dimension (length >= 1).
28+
:type shape: Sequence[int]
3829

39-
:returns: A list of coordinates corresponding to each dimension.
40-
:rtype: list
30+
:returns: Coordinates for each dimension in the same order as `shape`.
31+
:rtype: list[PrimExpr]
4132

4233

4334
.. py:function:: linear_index(*args)
4435
45-
Convert a list of coordinates to a flat (linear) index using strides.
36+
Compute a flat (linear) index from multi-dimensional coordinates and strides.
37+
38+
The function accepts a sequence of PrimExpr arguments where the first portion are coordinates
39+
and the trailing portion are the corresponding strides. The number of strides must equal
40+
(number of coordinates - 1). The linear index is computed as:
41+
42+
linear = coords[0]
43+
for each (coord, stride) in zip(coords[1:], strides):
44+
linear = linear * stride + coord
45+
46+
.. rubric:: Examples
47+
48+
- linear_index(i) -> i
49+
- linear_index(i, j) -> i * j_stride + j (requires j_stride provided as stride when needed)
50+
- linear_index(i, j, stride_j) -> i * stride_j + j
51+
- linear_index(i, j, k, stride_j, stride_k) -> i*stride_j*stride_k + j*stride_k + k
52+
- linear_index(i, tx, v, threads, local_size) -> i*threads*local_size + tx*local_size + v
4653

47-
Usage examples:
48-
linear_index(i) -> i
49-
linear_index(i, j) -> i * stride + j
50-
linear_index(i, j, stride_j) -> i * stride_j + j
51-
linear_index(i, j, k, stride_j, stride_k)
52-
-> i * stride_j * stride_k + j * stride_k + k
54+
:raises ValueError: If called with no arguments, or if the number of strides is not one less than
55+
the number of coordinates.
5356

54-
Example for index = i * threads * local_size + tx * local_size + v:
55-
Suppose you have i, tx, v as coordinates, and threads, local_size as strides:
56-
linear_index(i, tx, v, threads, local_size) == i * threads * local_size + tx * local_size + v
57+
:returns: The computed linear index expression.
58+
:rtype: PrimExpr
5759

5860

0 commit comments

Comments
 (0)