1+ """Triton utility functions for offset calculation, masking, and load/store operations."""
2+ # pylint: disable=too-many-arguments,too-many-positional-arguments,redefined-builtin,unused-argument
3+
14import triton
25import triton .language as tl
36from triton .language import constexpr as const
47
5- def cdiv (a ,b ): return (a + b - 1 ) // b
8+ def cdiv (a , b ):
9+ """Ceiling division."""
10+ return (a + b - 1 ) // b
611
712# # offsets
813
914@triton .jit
10- def offset_1d (sz : const , n_prev_chunks = 0 ): return n_prev_chunks * sz + tl .arange (0 , sz )
15+ def offset_1d (sz : const , n_prev_chunks = 0 ):
16+ """Compute 1D offset based on chunk size and previous chunks."""
17+ return n_prev_chunks * sz + tl .arange (0 , sz )
1118
1219@triton .jit
13- def offset_2d (offs0 , offs1 , stride0 , stride1 = 1 ): return tl .expand_dims (offs0 , 1 )* stride0 + tl .expand_dims (offs1 , 0 )* stride1
20+ def offset_2d (offs0 , offs1 , stride0 , stride1 = 1 ):
21+ """Compute 2D offset using strides."""
22+ return tl .expand_dims (offs0 , 1 )* stride0 + tl .expand_dims (offs1 , 0 )* stride1
1423
1524# # masks
1625
1726@triton .jit
18- def mask_1d (offs , max ): return offs < max
27+ def mask_1d (offs , max ):
28+ """Create a 1D mask based on a max bound."""
29+ return offs < max
1930
2031@triton .jit
21- def mask_2d (offs0 , offs1 , max0 , max1 ): return (tl .expand_dims (offs0 , 1 ) < max0 ) & (tl .expand_dims (offs1 , 0 ) < max1 )
32+ def mask_2d (offs0 , offs1 , max0 , max1 ):
33+ """Create a 2D mask using upper bounds for each axis."""
34+ return (tl .expand_dims (offs0 , 1 ) < max0 ) & (tl .expand_dims (offs1 , 0 ) < max1 )
2235
2336# # load
2437
2538@triton .jit
2639def load_1d (ptr , sz : const , n , max , stride = 1 ):
27- '''Chunk 1d vector (defined by ptr) into 1d grid, where each chunk has size sz. Load the nth chunk. Ie, load [n*sz,...,(n+1)*sz-1].'''
40+ """
41+ Chunk 1d vector (defined by ptr) into 1d grid, where each chunk has size sz.
42+ Load the nth chunk. Ie, load [n*sz,...,(n+1)*sz-1].
43+ """
2844 offs = offset_1d (sz , n )
2945 mask = mask_1d (offs , max )
3046 return tl .load (ptr + offs , mask )
@@ -38,7 +54,10 @@ def load_full_1d(ptr, sz: const, stride=1):
3854
3955@triton .jit
4056def load_2d (ptr , sz0 : const , sz1 : const , n0 , n1 , max0 , max1 , stride0 = None , stride1 = 1 ):
41- '''Chunk 2d matrix (defined by ptr) into 2d grid, where each chunk has size (sz0,sz1). Load the (n0,n1)th chunk. Ie, load [n0*sz0,...,(n0+1)*sz0-1] x [n1*sz1,...,(n1+1)*sz1-1].'''
57+ """
58+ Chunk 2d matrix (defined by ptr) into 2d grid, where each chunk has size (sz0,sz1).
59+ Load the (n0,n1)th chunk. Ie, load [n0*sz0,...,(n0+1)*sz0-1] x [n1*sz1,...,(n1+1)*sz1-1].
60+ """
4261 stride0 = stride0 or sz1
4362 offs0 = offset_1d (sz0 , n0 )
4463 offs1 = offset_1d (sz1 , n1 )
@@ -72,7 +91,10 @@ def store_full_1d(vals, ptr, sz: const, stride=1):
7291
7392@triton .jit
7493def store_2d (vals , ptr , sz0 : const , sz1 : const , n0 , n1 , max0 , max1 , stride0 = None , stride1 = 1 ):
75- '''Store 2d block into (n0,n1)th chunk of matrix (defined by ptr), where each chunk has size (sz0, sz1)'''
94+ """
95+ Store 2d block into (n0,n1)th chunk of matrix (defined by ptr), where each chunk has size
96+ (sz0, sz1)
97+ """
7698 stride0 = stride0 or sz1
7799 offs0 = offset_1d (sz0 , n0 )
78100 offs1 = offset_1d (sz1 , n1 )
0 commit comments