You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Calculates reduction step candidates for each reduction axis in a PrimFuncNode. General idea : use factor first, since it does not require extra boundary check. for large prime number, which is rare case, use power of 2.
83
+
84
+
:param node: The node for which to calculate reduction step candidates. It contains reduction axes (raxis)
85
+
with their domains (dom.extent).
86
+
:type node: PrimFuncNode
87
+
88
+
:returns: A dictionary mapping axis variable names to lists of step candidates. For each axis in the node,
89
+
this function calculates possible step sizes. For axes with a large prime domain, it uses powers of 2
90
+
as step candidates; for others, it uses all factors of the domain.
91
+
:rtype: Dict[str, List[int]]
92
+
93
+
67
94
68
95
.. py:method:: check_tile_shape_isvalid(td)
69
96
97
+
Checks if the tile shapes in the TileDict are valid for the nodes in this context.
98
+
99
+
Parameters:
100
+
- td (TileDict): The TileDict object containing tile shapes and other configurations.
101
+
102
+
Returns:
103
+
- bool: True if all tile shapes are valid, False otherwise.
104
+
105
+
70
106
71
107
.. py:method:: compute_node_stride_map(node, td)
72
108
109
+
Computes the stride map for a given node based on the TileDict configuration.
110
+
111
+
:param node: The node for which to compute the stride map.
112
+
:type node: PrimFuncNode
113
+
:param td: The TileDict object containing the tile configuration.
114
+
:type td: TileDict
115
+
116
+
:returns: A tuple of dictionaries containing the output strides and tensor strides.
117
+
:rtype: Tuple[Dict, Dict]
118
+
119
+
73
120
74
121
.. py:method:: plan_rasterization(td)
75
122
123
+
Plans the rasterization for the given TileDict. This function is not implemented yet.
124
+
125
+
:param td: The TileDict object to plan rasterization for.
126
+
:type td: TileDict
127
+
128
+
:raises RasterRationPlan: This function is not implemented yet.
Copy file name to clipboardExpand all lines: _sources/autoapi/tilelang/intrinsics/utils/index.rst.txt
+13Lines changed: 13 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -35,3 +35,16 @@ Module Contents
35
35
36
36
.. py:function:: get_mma_micro_size(dtype)
37
37
38
+
Return the MMA (Tensor Core) micro-tile dimensions for a given data type.
39
+
40
+
This function returns the micro tile sizes (x, y, k) used by MMA/Tensor Core operations.
41
+
- x: tile width in the output/result dimension
42
+
- y: tile height in the output/result dimension
43
+
- k: tile depth in the reduction/K dimension
44
+
45
+
Accepted dtype strings include "float16", "int8" and some FP8 identifiers ("float8_e4m3", "float8_e5m2"). For FP8 and int8 types the reduction depth (`k`) is 32; for float16 it is 16.
Convert a flat (linear) index to multi-dimensional coordinates for a given shape.
21
+
Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
22
22
23
-
.. rubric:: Example
24
-
25
-
shape = (4, 5, 6)
26
-
index = 53
27
-
index_to_coordinates(53, (4, 5, 6)) -> [1, 3, 5]
28
-
# Explanation:
29
-
# 53 // (5*6) = 1 (1st coordinate)
30
-
# 53 % (5*6) = 23
31
-
# 23 // 6 = 3 (2nd coordinate)
32
-
# 23 % 6 = 5 (3rd coordinate)
23
+
Given a linear index and a shape (sequence of dimension extents), returns a list of coordinates (one per dimension) such that converting those coordinates back to a linear index using the usual row-major / C-order formula yields the original index. The computation iterates from the last dimension to the first using modulo and integer division, then reverses the collected coordinates.
33
24
34
25
:param index: The flat index to convert.
35
-
:type index: int
36
-
:param shape: The shape of the multi-dimensional array.
37
-
:type shape:tuple or list of int
26
+
:type index: int or PrimExpr
27
+
:param shape: The extents of each dimension (length >= 1).
28
+
:type shape:Sequence[int]
38
29
39
-
:returns:A list of coordinates corresponding to each dimension.
40
-
:rtype: list
30
+
:returns:Coordinates for each dimension in the same order as `shape`.
31
+
:rtype: list[PrimExpr]
41
32
42
33
43
34
.. py:function:: linear_index(*args)
44
35
45
-
Convert a list of coordinates to a flat (linear) index using strides.
36
+
Compute a flat (linear) index from multi-dimensional coordinates and strides.
37
+
38
+
The function accepts a sequence of PrimExpr arguments where the first portion are coordinates
39
+
and the trailing portion are the corresponding strides. The number of strides must equal
40
+
(number of coordinates - 1). The linear index is computed as:
41
+
42
+
linear = coords[0]
43
+
for each (coord, stride) in zip(coords[1:], strides):
44
+
linear = linear * stride + coord
45
+
46
+
.. rubric:: Examples
47
+
48
+
- linear_index(i) -> i
49
+
- linear_index(i, j) -> i * j_stride + j (requires j_stride provided as stride when needed)
50
+
- linear_index(i, j, stride_j) -> i * stride_j + j
0 commit comments