Commit ce76595

Update docs

1 parent 87dbe67 commit ce76595

161 files changed: +9471, -9593 lines changed
Lines changed: 323 additions & 0 deletions
@@ -0,0 +1,323 @@
tilelang.language.atomic
========================

.. py:module:: tilelang.language.atomic

.. autoapi-nested-parse::

   Atomic operations for tilelang.


Functions
---------

.. autoapisummary::

   tilelang.language.atomic.atomic_max
   tilelang.language.atomic.atomic_min
   tilelang.language.atomic.atomic_add
   tilelang.language.atomic.atomic_addx2
   tilelang.language.atomic.atomic_addx4
   tilelang.language.atomic.atomic_load
   tilelang.language.atomic.atomic_store


Module Contents
---------------
.. py:function:: atomic_max(dst, value, memory_order = None, return_prev = False)

   Perform an atomic maximum on the value stored at `dst`, with an optional memory order.

   If `memory_order` is None, the runtime extern "AtomicMax" is called without an
   explicit memory-order ID; otherwise the given memory-order name is mapped to a
   numeric ID via the module's memory-order map and passed to the extern.

   :param dst: Destination buffer/address to apply the atomic max.
   :type dst: Buffer
   :param value: Value to compare/store atomically.
   :type value: PrimExpr
   :param memory_order: Optional memory-order name (e.g. "relaxed", "acquire", "seq_cst").
                        If provided, it is translated to the corresponding numeric
                        memory-order ID before the call.
   :type memory_order: Optional[str]
   :param return_prev: If True, return the previous value; if False, return a handle
                       (default False).
   :type return_prev: bool

   :returns: A handle/expression representing the issued atomic-max operation, or the
             previous value if `return_prev` is True.
   :rtype: PrimExpr

   .. rubric:: Examples

   >>> # Basic atomic max operation
   >>> counter = T.Tensor([1], "float32", name="counter")
   >>> atomic_max(counter, 42.0)

   >>> # With memory ordering
   >>> atomic_max(counter, 100.0, memory_order="acquire")

   >>> # Get the previous value
   >>> prev_value = atomic_max(counter, 50.0, return_prev=True)
   >>> # prev_value now holds the value stored in counter before the max operation

   >>> # Use in a parallel reduction to find the global maximum
   >>> @T.prim_func
   ... def find_max(data: T.Buffer, result: T.Buffer):
   ...     for i in T.thread_binding(128, "threadIdx.x"):
   ...         atomic_max(result, data[i])
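   The memory-order names are mapped to numeric IDs internally. A minimal sketch of
   the assumed mapping (the dict name is hypothetical, and the IDs are assumed to
   follow the C++ ``std::memory_order`` enumeration; neither is confirmed by this
   module's public API)::

      >>> # Hypothetical sketch; the actual map is internal to tilelang.language.atomic.
      >>> _MEMORY_ORDER_IDS = {
      ...     "relaxed": 0, "consume": 1, "acquire": 2,
      ...     "release": 3, "acq_rel": 4, "seq_cst": 5,
      ... }
      >>> _MEMORY_ORDER_IDS["seq_cst"]
      5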
.. py:function:: atomic_min(dst, value, memory_order = None, return_prev = False)

   Atomically update the value at `dst` to the minimum of its current value and `value`.

   If `memory_order` is provided, it selects the memory-order semantics used by the
   underlying extern call; the allowed names are "relaxed", "consume", "acquire",
   "release", "acq_rel", and "seq_cst" (mapped internally to integer IDs). If
   `memory_order` is None, the extern is invoked without an explicit memory-order
   argument.

   :param dst: Destination buffer/address to apply the atomic min.
   :type dst: Buffer
   :param value: Value to compare/store atomically.
   :type value: PrimExpr
   :param memory_order: Optional memory-order name controlling the atomic operation's ordering.
   :type memory_order: Optional[str]
   :param return_prev: If True, return the previous value; if False, return a handle
                       (default False).
   :type return_prev: bool

   :returns: A handle expression representing the atomic-min operation, or the previous
             value if `return_prev` is True.
   :rtype: PrimExpr

   .. rubric:: Examples

   >>> # Basic atomic min operation
   >>> min_val = T.Tensor([1], "int32", name="min_val")
   >>> atomic_min(min_val, 10)

   >>> # Find the minimum across threads
   >>> @T.prim_func
   ... def find_min(data: T.Buffer, result: T.Buffer):
   ...     for i in T.thread_binding(256, "threadIdx.x"):
   ...         atomic_min(result, data[i])

   >>> # Track the minimum while keeping the previous value
   >>> threshold = T.Tensor([1], "float32", name="threshold")
   >>> old_min = atomic_min(threshold, 3.14, return_prev=True)
   >>> # old_min contains the previous minimum value

   >>> # With relaxed memory ordering for performance
   >>> atomic_min(min_val, 5, memory_order="relaxed")
.. py:function:: atomic_add(dst, value, memory_order = None, return_prev = False)

   Atomically add `value` into `dst`, returning a handle to the operation.

   Supports a scalar/addressed extern atomic add when neither argument exposes
   extents, or a tile-region-based atomic add for Buffer/BufferRegion/BufferLoad
   inputs. If both arguments are plain Buffers, their shapes must be structurally
   equal. If at least one side exposes extents, the extents are aligned (missing
   dimensions are treated as size 1); an assertion is raised if the extents cannot
   be deduced. The optional `memory_order` (one of "relaxed", "consume", "acquire",
   "release", "acq_rel", "seq_cst") is used only on the direct extern `AtomicAdd`
   path taken when no extents are available; the tile-region path ignores
   `memory_order`.

   :param dst: Destination buffer/address to apply the atomic add.
   :type dst: Buffer
   :param value: Value to add atomically.
   :type value: PrimExpr
   :param memory_order: Optional memory-order name controlling the atomic operation's ordering.
   :type memory_order: Optional[str]
   :param return_prev: If True, return the previous value; if False, return a handle
                       (default False).
   :type return_prev: bool

   :returns: A handle representing the atomic addition operation, or the previous
             value if `return_prev` is True.
   :rtype: PrimExpr

   .. rubric:: Examples

   >>> # Basic atomic addition
   >>> counter = T.Tensor([1], "int32", name="counter")
   >>> atomic_add(counter, 1)  # Increment counter by 1

   >>> # Parallel sum reduction
   >>> @T.prim_func
   ... def parallel_sum(data: T.Buffer, result: T.Buffer):
   ...     for i in T.thread_binding(1024, "threadIdx.x"):
   ...         atomic_add(result, data[i])

   >>> # Get the previous value for debugging
   >>> old_value = atomic_add(counter, 5, return_prev=True)
   >>> # old_value contains the value before adding 5

   >>> # Tensor-to-tensor atomic add (tile-region based)
   >>> src_tensor = T.Tensor([128, 64], "float32", name="src")
   >>> dst_tensor = T.Tensor([128, 64], "float32", name="dst")
   >>> atomic_add(dst_tensor, src_tensor)  # Add entire tensors atomically

   >>> # With memory ordering for scalar operations
   >>> atomic_add(counter, 10, memory_order="acquire")

   >>> # Accumulate gradients in training
   >>> gradients = T.Tensor([1000], "float32", name="gradients")
   >>> global_grad = T.Tensor([1000], "float32", name="global_grad")
   >>> atomic_add(global_grad, gradients)
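   The extent-alignment rule on the tile-region path can be sketched as follows;
   ``_align_extents`` is a hypothetical helper written for illustration, not the
   actual tilelang internal::

      >>> def _align_extents(dst_shape, src_shape):
      ...     # Missing leading dimensions are treated as size 1.
      ...     rank = max(len(dst_shape), len(src_shape))
      ...     dst = (1,) * (rank - len(dst_shape)) + tuple(dst_shape)
      ...     src = (1,) * (rank - len(src_shape)) + tuple(src_shape)
      ...     # Assumed deducibility check: paired extents must match or be 1.
      ...     assert all(d == s or s == 1 or d == 1 for d, s in zip(dst, src)), \
      ...         "extents cannot be deduced"
      ...     return dst, src
      >>> _align_extents((128, 64), (64,))
      ((128, 64), (1, 64))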
.. py:function:: atomic_addx2(dst, value, return_prev = False)

   Perform an atomic addition with double-width (two-element) operands.

   :param dst: Destination buffer where the atomic addition is performed.
   :type dst: Buffer
   :param value: Value to be atomically added (double-width).
   :type value: PrimExpr
   :param return_prev: If True, return the previous value; if False, return a handle
                       (default False).
   :type return_prev: bool

   :returns: Handle to the double-width atomic addition operation, or the previous
             value if `return_prev` is True.
   :rtype: PrimExpr

   .. rubric:: Examples

   >>> # Atomic addition of an FP16 pair
   >>> half_dst = T.Tensor([2], "float16", name="half_dst")
   >>> half_val = T.Tensor([2], "float16", name="half_val")
   >>> atomic_addx2(half_dst, half_val)

   >>> # BF16 vectorized atomic add (requires CUDA arch > 750, i.e. newer than sm_75)
   >>> bf16_dst = T.Tensor([2], "bfloat16", name="bf16_dst")
   >>> bf16_val = T.Tensor([2], "bfloat16", name="bf16_val")
   >>> atomic_addx2(bf16_dst, bf16_val)

   >>> # Get the previous paired values
   >>> prev_values = atomic_addx2(half_dst, half_val, return_prev=True)
   >>> # prev_values is a half2 containing the two previous FP16 values

   >>> # Efficient gradient accumulation for mixed-precision training
   >>> @T.prim_func
   ... def accumulate_fp16_gradients(grads: T.Buffer, global_grads: T.Buffer):
   ...     for i in T.thread_binding(128, "threadIdx.x"):
   ...         for j in range(0, grads.shape[1], 2):  # Process in pairs
   ...             atomic_addx2(global_grads[i, j:j+2], grads[i, j:j+2])
.. py:function:: atomic_addx4(dst, value, return_prev = False)

   Perform an atomic addition with quad-width (four-element) operands.

   :param dst: Destination buffer where the atomic addition is performed.
   :type dst: Buffer
   :param value: Value to be atomically added (quad-width).
   :type value: PrimExpr
   :param return_prev: If True, return the previous value; if False, return a handle
                       (default False).
   :type return_prev: bool

   :returns: Handle to the quad-width atomic addition operation, or the previous
             value if `return_prev` is True.
   :rtype: PrimExpr

   .. rubric:: Examples

   >>> # Atomic addition of a float4 (requires CUDA arch >= 900, i.e. sm_90 or newer)
   >>> float4_dst = T.Tensor([4], "float32", name="float4_dst")
   >>> float4_val = T.Tensor([4], "float32", name="float4_val")
   >>> atomic_addx4(float4_dst, float4_val)

   >>> # Get the previous float4 values
   >>> prev_float4 = atomic_addx4(float4_dst, float4_val, return_prev=True)
   >>> # prev_float4 is a float4 containing the four previous float32 values

   >>> # High-throughput gradient accumulation for large models
   >>> @T.prim_func
   ... def accumulate_float4_gradients(grads: T.Buffer, global_grads: T.Buffer):
   ...     for i in T.thread_binding(256, "threadIdx.x"):
   ...         for j in range(0, grads.shape[1], 4):  # Process 4 floats at once
   ...             atomic_addx4(global_grads[i, j:j+4], grads[i, j:j+4])

   >>> # Efficient RGBA pixel blending
   >>> rgba_dst = T.Tensor([4], "float32", name="rgba_dst")  # R, G, B, A channels
   >>> rgba_add = T.Tensor([4], "float32", name="rgba_add")
   >>> atomic_addx4(rgba_dst, rgba_add)  # Atomically blend all 4 channels
.. py:function:: atomic_load(src, memory_order = 'seq_cst')

   Load a value from the given buffer using the specified atomic memory ordering.

   Performs an atomic load from `src` and returns a PrimExpr representing the
   loaded value. `memory_order` selects the ordering and must be one of "relaxed",
   "consume", "acquire", "release", "acq_rel", or "seq_cst" (the default).
   Raises KeyError if an unknown `memory_order` is provided.

   :param src: Source buffer to load from.
   :type src: Buffer
   :param memory_order: Memory-ordering name; defaults to "seq_cst".
   :type memory_order: str, optional

   :returns: The loaded value.
   :rtype: PrimExpr

   Note: atomic_load always returns the loaded value, so no `return_prev` parameter
   is needed.

   .. rubric:: Examples

   >>> # Basic atomic load
   >>> shared_var = T.Tensor([1], "int32", name="shared_var")
   >>> value = atomic_load(shared_var)

   >>> # Load with a specific memory ordering
   >>> value = atomic_load(shared_var, memory_order="acquire")
   >>> # Ensures all subsequent memory operations happen after this load

   >>> # Relaxed load for performance-critical code
   >>> value = atomic_load(shared_var, memory_order="relaxed")

   >>> # Producer-consumer pattern
   >>> @T.prim_func
   ... def consumer(flag: T.Buffer, data: T.Buffer, result: T.Buffer):
   ...     # Wait until the producer sets the flag
   ...     while atomic_load(flag, memory_order="acquire") == 0:
   ...         pass  # Spin wait
   ...     # Now safely read the data
   ...     result[0] = data[0]

   >>> # Load a counter for statistics
   >>> counter = T.Tensor([1], "int64", name="counter")
   >>> current_count = atomic_load(counter, memory_order="relaxed")
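   The KeyError behavior can be reproduced with the assumed name-to-ID mapping
   sketched under atomic_max (hypothetical dict; shown only to illustrate the
   error path)::

      >>> _MEMORY_ORDER_IDS["weak"]
      Traceback (most recent call last):
          ...
      KeyError: 'weak'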
.. py:function:: atomic_store(dst, src, memory_order = 'seq_cst')

   Perform an atomic store of `src` into `dst` with the given memory ordering.

   :param dst: Destination buffer to store into.
   :type dst: Buffer
   :param src: Value to store.
   :type src: PrimExpr
   :param memory_order: Memory-ordering name; one of "relaxed", "consume", "acquire",
                        "release", "acq_rel", or "seq_cst". Defaults to "seq_cst".
                        The name is mapped to an internal numeric ID used by the
                        underlying runtime.
   :type memory_order: str, optional

   :returns: A handle representing the issued atomic store operation.
   :rtype: PrimExpr

   :raises KeyError: If `memory_order` is not one of the supported names.

   Note: atomic_store does not return a previous value, so no `return_prev`
   parameter is needed.

   .. rubric:: Examples

   >>> # Basic atomic store
   >>> shared_var = T.Tensor([1], "int32", name="shared_var")
   >>> atomic_store(shared_var, 42)

   >>> # Store with release ordering to publish data
   >>> data = T.Tensor([1000], "float32", name="data")
   >>> ready_flag = T.Tensor([1], "int32", name="ready_flag")
   >>> # ... fill data ...
   >>> atomic_store(ready_flag, 1, memory_order="release")
   >>> # Ensures all previous writes are visible before the flag is set

   >>> # Relaxed store for performance
   >>> atomic_store(shared_var, 100, memory_order="relaxed")

   >>> # Producer-consumer synchronization
   >>> @T.prim_func
   ... def producer(data: T.Buffer, flag: T.Buffer):
   ...     data[0] = 3.14159  # Write the data first
   ...     atomic_store(flag, 1, memory_order="release")
   >>> # The consumer can now safely read data after seeing flag == 1

   >>> # Update a configuration word atomically
   >>> config = T.Tensor([1], "int32", name="config")
   >>> new_config = 0x12345678
   >>> atomic_store(config, new_config, memory_order="seq_cst")

   >>> # Thread-safe logging counter
   >>> log_counter = T.Tensor([1], "int64", name="log_counter")
   >>> atomic_store(log_counter, 0)  # Reset the counter atomically
_sources/autoapi/tilelang/language/builtin/index.rst.txt

Lines changed: 1 addition & 1 deletion
@@ -301,7 +301,7 @@ Module Contents
    The value to shuffle


-.. py:function:: sync_threads()
+.. py:function:: sync_threads(barrier_id = None, arrive_count = None)

   Synchronize all threads in a block.