Commit 3334944

Update docs
1 parent 3166d74 commit 3334944

186 files changed: +1675 -7 lines changed

Lines changed: 387 additions & 0 deletions

@@ -0,0 +1,387 @@

# Tensor Checks (Host-Side Auto-Validation)

This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass a `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to performing equivalent checks in Python or via pybind.

## Why Host-Side Checks

- ABI stability: the entry point is based on TVM FFI + DLPack, consistently accepting tensors and scalars.
- Lower overhead: shifting checks from Python into C reduces interpreter and property-access costs; the call overhead is lower than pybind-based approaches.
- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages.

## How To Inspect Host Source

You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging:

```python
print(matmul_relu_kernel.get_host_source())
```

---

## What The Host Checks

### 1) Argument count and pointer kind
- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message.
- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error.

### 2) Tensor checks (per tensor, after nullability decision)
- Nullability
  - If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`.
  - If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`.
- Rank (`ndim`)
  - Runtime `ndim` must equal the compile-time rank.
- Data type (`dtype`)
  - Match the triple `(code, bits, lanes)` with tolerance:
    - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`.
    - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`.
    - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool (code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match).
  - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped.
- Shape
  - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency.
  - Linear equations among symbolic dims can be solved on the fly (when there is only one unknown at a given check point), enabling cross-tensor constraints.
- Strides
  - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` are present, bind them to compile-time constraints and check for equality.
  - Otherwise: check per dimension; if `strides == NULL`, derive strides from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`); see the sketch after this list.
- `byte_offset`
  - Must be 0 (non-zero raises an error) to keep addressing simple and aligned.
- Device info
  - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend.
  - When multiple tensors participate, assert that `device_id` matches across them.
- Data pointer
  - Must be non-NULL when the tensor is required to be non-null by the nullability rule.
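
As a minimal illustration of the stride derivation mentioned above, here is a Python sketch (not the actual host C code) of how contiguous strides are reconstructed from `shape` when `strides == NULL`:

```python
def contiguous_strides(shape):
    """Row-major strides the host derives when a tensor carries no explicit strides."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# Matches the contiguous rule above: strides[-1] == 1, strides[-2] == shape[-1].
assert contiguous_strides([4, 8, 16]) == [128, 16, 1]
```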

### 3) Scalar checks
- `T.int*` family: require an integer; error: `Expect arg[i] to be int`.
- `T.bool`: require a boolean; error: `Expect arg[i] to be boolean`.
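
To see the `bool` dtype tolerance from section 2 in action, a hedged sketch (assuming a hypothetical compiled kernel `mask_kernel` whose only parameter was declared as a 1-D `bool` tensor):

```python
import torch

n = 1024
mask_bool = torch.zeros((n,), device='cuda', dtype=torch.bool)
mask_u8 = torch.zeros((n,), device='cuda', dtype=torch.uint8)
mask_i32 = torch.zeros((n,), device='cuda', dtype=torch.int32)

mask_kernel(mask_bool)  # OK: exact dtype match
mask_kernel(mask_u8)    # OK: uint8 with bits=8 is tolerated for bool
mask_kernel(mask_i32)   # raises: dtype is expected to be bool, but got incompatible dtype
```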

---

## Shapes and Symbolic Equations: Linear Solving

When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime, as long as there is only one unknown per check point. Example:

```python
@T.prim_func
def main(
    A: T.Tensor((m,), dtype),
    B: T.Tensor((m + n,), dtype),
    C: T.Tensor((n * k,), dtype),
):
    ...
```

This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime.
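
For instance, a hedged usage sketch for the kernel above (assuming `dtype = "float16"`, a CUDA target, and a compiled `main`): the host reads `m` from `A`, solves `n` from `B`’s length, then solves `k` from `C`’s length.

```python
import torch

m, n, k = 64, 32, 4
A = torch.empty((m,), device='cuda', dtype=torch.float16)      # binds m = 64
B = torch.empty((m + n,), device='cuda', dtype=torch.float16)  # len(B) = 96; with m bound, solves n = 32
C = torch.empty((n * k,), device='cuda', dtype=torch.float16)  # len(C) = 128; with n bound, solves k = 4

main(A, B, C)  # passes: every symbol is bound consistently

C_bad = torch.empty((n * k + 2,), device='cuda', dtype=torch.float16)
main(A, B, C_bad)  # should raise: no integer k satisfies 32 * k == 130
```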

---

## Nullability Rules and Examples

Which tensors may be NULL?

- Rule: if an input tensor is not used by the function under static analysis (i.e., every access is statically unreachable), it is considered nullable; otherwise it must be non-NULL.
- Examples (a calling sketch follows example 4):

1) Must be non-NULL (used)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    A[0] = 1
```
Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`.

2) Still must be non-NULL (constant-true branch)
```python
some_cond: bool = True

@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    if some_cond:
        A[0] = 1
```

3) Nullable (constant-false branch, statically unreachable)
```python
some_cond: bool = False

@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    if some_cond:
        A[0] = 1
```

4) Must be non-NULL (runtime condition)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype), some_cond: T.bool):
    if some_cond:
        A[0] = 1
```
Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable.
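
Putting examples 1 and 3 together, a minimal calling sketch (with hypothetical names `fn_used` and `fn_unused` for the compiled variants of examples 1 and 3, e.g., via `tilelang.compile`):

```python
fn_unused(None)  # OK: A is statically unreachable, so a NULL handle is accepted
fn_used(None)    # raises: main.A_handle is expected to have non-NULL pointer
```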
---

## Device Type Codes (DLPack)

Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`.
Kernels assert that `device_type` matches the target backend and require `device_id` consistency across tensors.
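
For quick cross-checking in Python, a small lookup table of the codes quoted above (illustrative only, not a TileLang API):

```python
DLPACK_DEVICE_CODES = {
    1: "CPU", 2: "CUDA", 7: "Vulkan", 8: "Metal",
    10: "ROCM", 14: "OneAPI", 15: "WebGPU",
}

# e.g., decode the expected code from a message like
# "...device_type mismatch [expected: 2 (cuda)]"
print(DLPACK_DEVICE_CODES[2])  # CUDA
```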
---

## Common Error Examples (What you’ll see)

- Argument count mismatch (`num_args`)
  - Trigger: missing/extra argument
  - Error: `<kernel>: num_args should be N; expected: <num_args>, got: N`
- Pointer-typed argument expected
  - Trigger: scalar passed where a tensor is expected
  - Error: `<kernel>: Expect arg[i] to be pointer`
- Rank (`ndim`) mismatch
  - Trigger: runtime rank differs from compile-time rank
  - Error: `<kernel>.<name>.ndim is expected to equal R, but got mismatched ndim`
- Dtype mismatch
  - Trigger: dtype not equal to the compiled dtype and not within the tolerance set
  - Error: `<kernel>.<name>.dtype is expected to be <dtype>, but got incompatible dtype`
- Shape constraint violation
  - Trigger: a dimension doesn’t match a constant/symbol binding
  - Error: `Argument <kernel>.<name>.shape[i] has an unsatisfied constraint: ... == <expected>`
- Strides check failed (e.g., non-contiguous layout)
  - Trigger: transposed/sliced tensors that violate expected strides
  - Error: `Argument <kernel>.<name>.strides[j] has an unsatisfied constraint: ... == <expected>`
- Device type mismatch
  - Trigger: calling a CUDA kernel with CPU tensors, etc.
  - Error: `<kernel>.<name>.device_type mismatch [expected: <code> (<name>)] ...`
- Device id mismatch
  - Trigger: mixing tensors from different GPUs
  - Error: `Argument <kernel>.<name>.device_id has an unsatisfied constraint: ... == ...`
- NULL data pointer
  - Trigger: tensor required to be non-null has a NULL data pointer
  - Error: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`
- Scalar type mismatch
  - Trigger: passing a float to `T.int32`, or a non-boolean to `T.bool`
  - Error: `<kernel>: Expect arg[i] to be int/boolean`

Minimal repro snippets for each of these appear in “Host Error Troubleshooting (Minimal Repros)” below.
---

## Troubleshooting Tips

- Print the host source: `print(fn.get_host_source())` shows the exact assertion and the expected vs. actual fields.
- Fix strides: call `.contiguous()` on non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions; see the combined sketch below.
- Align devices: ensure all participating tensors share the same `device_type` and `device_id`.
- Align dtype: use `.to(<dtype>)` or construct tensors with the correct dtype; pay attention to the `float8` and `bool` tolerance rules.
- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time).
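
The first three tips combined, as a hedged sketch (assuming a CUDA float16 kernel `fn` and already-allocated tensors `A`, `B`, `C`):

```python
import torch

A = A.contiguous()                              # fix strides (non-contiguous layouts)
A = A.to(torch.float16)                         # fix dtype
A, B, C = (t.to('cuda:0') for t in (A, B, C))   # align device_type and device_id
fn(A, B, C)
```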
---
## FAQ

- Can I disable the checks?
  - Not recommended and usually not supported. Checks run on the host to preserve ABI stability and fail early, close to the device call.
- Is the overhead noticeable?
  - The checks themselves are lightweight (branches and field reads); the dominating cost remains the Python→C boundary. Overall, they are cheaper than equivalent checks in Python.
---

## Reference Example (Matmul + ReLU)

```python
@T.prim_func
def matmul_relu_kernel(
    A: T.Tensor((M, K), dtype),
    B: T.Tensor((K, N), dtype),
    C: T.Tensor((M, N), dtype),
):
    # Initialize the kernel context
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), dtype)
        B_shared = T.alloc_shared((block_K, block_N), dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        T.clear(C_local)
        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
            T.copy(A[by * block_M, ko * block_K], A_shared)
            T.copy(B[ko * block_K, bx * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        T.copy(C_local, C[by * block_M, bx * block_N])

# For debugging, print the host source
print(matmul_relu_kernel.get_host_source())
```

The host will insert all checks described above for this example.
---
## Quick Error Reference (Short List)

- Argument count
  - Trigger: missing/extra args; Error: `num_args should be N; expected: <num_args>, got: N`.
- Pointer kind
  - Trigger: scalar passed to a tensor arg; Error: `Expect arg[i] to be pointer`.
- Rank (`ndim`)
  - Trigger: runtime rank != compile-time rank; Error: `ndim ... expected to equal R`.
- Dtype
  - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be <dtype>`.
- Shape
  - Trigger: constant/symbol binding violated; Error: `shape[i] ... == <expected>`.
- Strides
  - Trigger: layout mismatch; Error: `strides[j] ... == <expected>`.
- Device type
  - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`.
- Device id
  - Trigger: tensors on different GPUs; Error: `device_id ... == ...`.
- Data pointer
  - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`.
- Scalar types
  - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`.
---
## Host Error Troubleshooting (Minimal Repros)

Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with:

```python
# Convention:
#   A: float16 [M, K]
#   B: float16 [K, N]
#   C: float16 [M, N]
#   Target: CUDA (device_type=2)
fn = matmul_relu_kernel  # your compiled function
M = N = K = 1024
```

Adjust dtype/device if your kernel differs.

### 0. Tip: print the host source
```python
print(fn.get_host_source())
```

### 1. num_args mismatch
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float16)
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
# Missing C
fn(A, B)
```
Expected: `<kernel>: num_args should be 3; expected: 3, got: 2`.

Fix: pass all arguments per the signature.

### 2. Expect pointer (tensor) but got scalar
```python
import torch

B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(1, B, C)
```
Expected: `<kernel>: Expect arg[0] to be pointer`.

Fix: pass a DLPack-compatible tensor (e.g., a `torch.Tensor`).

### 3. ndim mismatch
```python
import torch

A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16)  # rank=3
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.ndim is expected to equal 2, but got mismatched ndim`.

Fix: ensure the runtime rank equals the compiled rank.

### 4. dtype mismatch
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float32)  # should be float16
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.dtype is expected to be float16, but got incompatible dtype`.

Fix: `A = A.to(torch.float16)`, or create the tensor with the correct dtype.

### 5. Shape constant/symbol mismatch
```python
import torch

A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16)  # K mismatched
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.A_handle.shape[i] has an unsatisfied constraint: ... == <expected>`.

Fix: satisfy the constants and linear constraints across tensors.

### 6. Strides check failure (non-contiguous)
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float16)
A_nc = A.t()  # transpose -> non-contiguous; shape still matches since M == K == 1024
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A_nc, B, C)
```
Expected: `Argument <kernel>.A_handle.strides[1] has an unsatisfied constraint: ... == 1`.

Fix: pass `A_nc.contiguous()`, or align the kernel’s layout expectation.

### 7. device_type mismatch
```python
import torch

A = torch.empty((M, K), device='cpu', dtype=torch.float16)
B = torch.empty((K, N), device='cpu', dtype=torch.float16)
C = torch.empty((M, N), device='cpu', dtype=torch.float16)
fn(A, B, C)  # CUDA-targeted kernel
```
Expected: `<kernel>.A_handle.device_type mismatch [expected: 2 (cuda)] ...`.

Fix: move the tensors to the CUDA device.

### 8. device_id mismatch (multi-GPU)
```python
import torch

A = torch.empty((M, K), device='cuda:0', dtype=torch.float16)
B = torch.empty((K, N), device='cuda:1', dtype=torch.float16)
C = torch.empty((M, N), device='cuda:0', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.B_handle.device_id has an unsatisfied constraint: ... == ...`.

Fix: place all tensors on the same GPU (e.g., `cuda:0`).

### 9. NULL data pointer (advanced)
This usually comes from hand-constructed DLTensor/NDArray objects, or from external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this.

Expected: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`.

Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles.

### 10. Scalar type mismatch (int / bool)
```python
import tilelang.language as T

@T.prim_func
def scalar_check(x: T.int32, flag: T.bool):
    T.evaluate(0)

# Calls below assume scalar_check has been compiled into a host-callable kernel.
scalar_check(1.0, True)  # x is float -> Expect arg[0] to be int
scalar_check(1, 2.5)     # flag is float -> Expect arg[1] to be boolean
```

Fix: pass the correct scalar types, e.g., `scalar_check(1, True)`.
---

## Closing Notes
- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently.
- For complex symbolic relations, print the host source to confirm the binding/solving order, then adjust runtime shapes/layouts accordingly.

_sources/index.md.txt

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ deeplearning_operators/deepseek_mla

 compiler_internals/letstmt_inline
 compiler_internals/inject_fence_proxy
+compiler_internals/tensor_checks
 :::

 :::{toctree}
