# Tensor Checks (Host-Side Auto-Validation)

This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass a `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind.

## Why Host-Side Checks
- ABI stability: the entry point is based on TVM FFI + DLPack and consistently accepts tensors and scalars.
- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches.
- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages.

## How To Inspect Host Source
You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging:

```python
print(matmul_relu_kernel.get_host_source())
```

---

## What The Host Checks

### 1) Argument count and pointer kind
- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message.
- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error.

### 2) Tensor checks (per tensor, after nullability decision)
- Nullability
  - If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; otherwise you get `<kernel>.<name>_handle is expected to have non-NULL pointer`.
  - If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`.
- Rank (`ndim`)
  - Runtime `ndim` must equal the compile-time rank.
- Data type (`dtype`)
  - Match the triple `(code, bits, lanes)` with tolerance:
    - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`.
    - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`.
    - `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool (code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match).
  - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped.
- Shape
  - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency.
  - Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints.
- Strides
  - If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind it to the compile-time constraints and check for equality.
  - Otherwise: check per dimension; if `strides == NULL`, derive strides from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`); see the sketch after this list.
- `byte_offset`
  - Must be 0 (non-zero raises an error) to keep addressing simple and aligned.
- Device info
  - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend.
  - When multiple tensors participate, assert that `device_id` matches across them.
- Data pointer
  - Must be non-NULL when the tensor is required to be non-null by the nullability rule.
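
To make the stride rule concrete, the snippet below mirrors the derivation the host performs when `strides == NULL`: it recomputes row-major (contiguous) strides from `shape` and compares them with what a tensor actually reports. This is an illustrative Python sketch, not the generated C code; the helper name `expected_contiguous_strides` is ours.

```python
import torch

def expected_contiguous_strides(shape):
    # Row-major strides derived from shape, matching what the host assumes
    # for a DLTensor whose strides field is NULL.
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))

A = torch.empty((1024, 512), dtype=torch.float16)
assert A.stride() == expected_contiguous_strides(A.shape)      # contiguous: passes

A_t = A.t()  # transposed view: strides become (1, 512)
assert A_t.stride() != expected_contiguous_strides(A_t.shape)  # would fail the host check
```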

### 3) Scalar checks
- `T.int*` family: require an integer; error: `Expect arg[i] to be int`.
- `T.bool`: require a boolean; error: `Expect arg[i] to be boolean`.

---

## Shapes and Symbolic Equations: Linear Solving
When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime, as long as there is only one unknown per check point. Example:

```python
@T.prim_func
def main(
    A: T.Tensor((m,), dtype),      # m, n, k are symbolic (dynamic) dimensions
    B: T.Tensor((m + n,), dtype),
    C: T.Tensor((n * k,), dtype),
):
    ...
```

This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime.
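
As a hedged usage sketch (assuming the kernel above has been compiled with `m`, `n`, `k` left dynamic; the concrete sizes and `float32` dtype are ours for illustration), the shape binding proceeds tensor by tensor:

```python
import torch

A = torch.empty((32,), dtype=torch.float32)      # binds m = 32
B = torch.empty((48,), dtype=torch.float32)      # 48 == m + n  ->  n is solved as 16
C = torch.empty((64,), dtype=torch.float32)      # 64 == n * k  ->  k is solved as 4
# main(A, B, C)      # all constraints satisfiable: passes the host checks

C_bad = torch.empty((65,), dtype=torch.float32)  # 65 is not a multiple of n = 16
# main(A, B, C_bad)  # would fail: shape[0] has an unsatisfied constraint
```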

---

## Nullability Rules and Examples
Which tensors may be NULL?

- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered nullable; otherwise it must be non-NULL.
- Examples:

1) Must be non-NULL (used)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    A[0] = 1
```
Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`.

2) Still must be non-NULL (constant-true branch)
```python
some_cond: bool = True
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    if some_cond:
        A[0] = 1
```

3) Nullable (constant-false branch, statically unreachable)
```python
some_cond: bool = False
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    if some_cond:
        A[0] = 1
```

4) Must be non-NULL (runtime condition)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype), some_cond: T.bool):
    if some_cond:
        A[0] = 1
```
Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable.
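
A short calling sketch for the cases above. `kernel_case1` and `kernel_case3` stand for compiled versions of examples 1 and 3; the names and the compilation step are assumptions for illustration, so the kernel calls are shown as comments:

```python
import torch

M, K = 128, 64
A = torch.empty((M, K), device='cuda', dtype=torch.float16)

# Example 1: A is written, so the handle must be non-NULL.
# kernel_case1(A)      # OK
# kernel_case1(None)   # raises: main.A_handle is expected to have non-NULL pointer

# Example 3: A is statically unused, so NULL is accepted.
# kernel_case3(None)   # OK
# kernel_case3(A)      # also OK; field checks still run because the handle is non-NULL
```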

---

## Device Type Codes (DLPack)
Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`.
Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors.
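
If you are unsure which DLPack code a given tensor will report, PyTorch exposes it through the standard DLPack protocol. The snippet below is a small sketch of checking it before a call; the comparison against `2` for CUDA is our own convenience, not a TileLang API:

```python
import torch

def dlpack_device(t: torch.Tensor):
    # Returns (device_type, device_id) per the DLPack protocol,
    # e.g. device_type 1 (kDLCPU) for CPU tensors, 2 (kDLCUDA) for CUDA tensors.
    return t.__dlpack_device__()

x = torch.empty((8,), dtype=torch.float16)
print(dlpack_device(x))  # CPU tensor -> device_type 1, device_id 0

if torch.cuda.is_available():
    y = x.to('cuda:0')
    dev_type, dev_id = dlpack_device(y)
    assert dev_type == 2 and dev_id == 0  # kDLCUDA on device 0
```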

---

## Common Error Examples (What you’ll see)
- Argument count mismatch (num_args)
  - Trigger: missing/extra argument
  - Error: `<kernel>: num_args should be N; expected: N, got: M`

- Pointer-typed argument expected
  - Trigger: scalar passed where a tensor is expected
  - Error: `<kernel>: Expect arg[i] to be pointer`

- Rank (ndim) mismatch
  - Trigger: runtime rank differs from compile-time rank
  - Error: `<kernel>.<name>.ndim is expected to equal R, but got mismatched ndim`

- Dtype mismatch
  - Trigger: dtype not equal to the compiled dtype and not within the tolerance set
  - Error: `<kernel>.<name>.dtype is expected to be <dtype>, but got incompatible dtype`

- Shape constraint violation
  - Trigger: a dimension doesn’t match a constant/symbol binding
  - Error: `Argument <kernel>.<name>.shape[i] has an unsatisfied constraint: ... == <expected>`

- Strides check failed (e.g., non-contiguous layout)
  - Trigger: transposed/sliced tensors that violate expected strides
  - Error: `Argument <kernel>.<name>.strides[j] has an unsatisfied constraint: ... == <expected>`

- Device type mismatch
  - Trigger: calling a CUDA kernel with CPU tensors, etc.
  - Error: `<kernel>.<name>.device_type mismatch [expected: <code> (<name>)] ...`

- Device id mismatch
  - Trigger: mixing tensors from different GPUs
  - Error: `Argument <kernel>.<name>.device_id has an unsatisfied constraint: ... == ...`

- NULL data pointer
  - Trigger: a tensor required to be non-null has a NULL data pointer
  - Error: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`

- Scalar type mismatch
  - Trigger: passing a float to `T.int32`, or a non-boolean to `T.bool`
  - Error: `<kernel>: Expect arg[i] to be int/boolean`

---

## Troubleshooting Tips
- Print the host source: `print(fn.get_host_source())` to see the exact assertion and the expected vs. actual fields.
- Fix strides: call `.contiguous()` on non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions.
- Align devices: ensure all participating tensors share the same `device_type` and `device_id`.
- Align dtype: use `.to(<dtype>)` or construct tensors with the correct dtype; pay attention to the `float8` and `bool` tolerance rules.
- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time). A combined pre-flight sketch follows this list.
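
The tips above can be folded into a small pre-flight helper that normalizes inputs before the call. This is a convenience sketch of ours (not a TileLang API); it assumes a CUDA-targeted kernel with float16 operands:

```python
import torch

def preflight(t: torch.Tensor, dtype=torch.float16, device='cuda:0') -> torch.Tensor:
    # Normalize dtype, device, and layout so the host-side checks pass.
    t = t.to(device=device, dtype=dtype)
    return t.contiguous()

A = torch.randn((1024, 1024)).t()  # float32, CPU, non-contiguous
A = preflight(A)                   # float16, cuda:0, contiguous
# fn(preflight(A), preflight(B), preflight(C))
```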

---

## FAQ
- Can I disable the checks?
  - Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early, close to the device call.
- Is the overhead noticeable?
  - Generally no. The checks are lightweight (branches and field reads); the dominating cost remains the Python→C boundary, so they are cheaper than performing equivalent checks in Python.

---

## Reference Example (Matmul + ReLU)

```python
# M, N, K, block_M, block_N, block_K, dtype, and accum_dtype are assumed to be
# defined in the enclosing scope (e.g., constants or symbolic variables).
@T.prim_func
def matmul_relu_kernel(
    A: T.Tensor((M, K), dtype),
    B: T.Tensor((K, N), dtype),
    C: T.Tensor((M, N), dtype),
):
    # Initialize Kernel Context
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), dtype)
        B_shared = T.alloc_shared((block_K, block_N), dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        T.clear(C_local)
        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
            T.copy(A[by * block_M, ko * block_K], A_shared)
            T.copy(B[ko * block_K, bx * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        # Apply ReLU to the accumulator before writing back
        for i, j in T.Parallel(block_M, block_N):
            C_local[i, j] = T.max(C_local[i, j], 0)
        T.copy(C_local, C[by * block_M, bx * block_N])

# For debugging, print the host source
print(matmul_relu_kernel.get_host_source())
```

The host will insert all checks described above for this example.

---

## Quick Error Reference (Short List)
- Argument count
  - Trigger: missing/extra args; Error: `num_args should be N; expected: N, got: M`.
- Pointer kind
  - Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`.
- Rank (ndim)
  - Trigger: runtime rank != compile-time rank; Error: `ndim ... expected to equal R`.
- Dtype
  - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be <dtype>`.
- Shape
  - Trigger: constant/symbol binding violated; Error: `shape[i] ... == <expected>`.
- Strides
  - Trigger: layout mismatch; Error: `strides[j] ... == <expected>`.
- Device type
  - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`.
- Device id
  - Trigger: tensors on different GPUs; Error: `device_id ... == ...`.
- Data pointer
  - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`.
- Scalar types
  - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`.

---

## Host Error Troubleshooting (Minimal Repros)

Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with:

```python
# Convention:
# A: float16 [M, K]
# B: float16 [K, N]
# C: float16 [M, N]
# Target: CUDA (device_type=2)
fn = matmul_relu_kernel  # your compiled function
M = N = K = 1024
```

Adjust dtype/device if your kernel differs.
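
Before walking through the failures, here is a known-good call under the convention above (a sketch; it assumes `fn` is the compiled CUDA kernel and that `C` is passed in as an output buffer rather than returned):

```python
import torch

A = torch.randn((M, K), device='cuda', dtype=torch.float16)
B = torch.randn((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)  # correct count, rank, dtype, shape, strides, and device: all checks pass
```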

### 0. Tip: print the host source
```python
print(fn.get_host_source())
```

### 1. num_args mismatch
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float16)
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
# Missing C
fn(A, B)
```
Expected: `<kernel>: num_args should be 3; expected: 3, got: 2`.

Fix: pass all arguments per the signature.

### 2. Expect pointer (tensor) but got scalar
```python
import torch

B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(1, B, C)
```
Expected: `<kernel>: Expect arg[0] to be pointer`.

Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor).

### 3. ndim mismatch
```python
import torch

A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16)  # rank=3
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.ndim is expected to equal 2, but got mismatched ndim`.

Fix: ensure runtime rank equals compiled rank.

### 4. dtype mismatch
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float32)  # should be float16
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.dtype is expected to be float16, but got incompatible dtype`.

Fix: `A = A.to(torch.float16)` or create with the correct dtype.

### 5. Shape constant/symbol mismatch
```python
import torch

A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16)  # K mismatched
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.A_handle.shape[i] has an unsatisfied constraint: ... == <expected>`.

Fix: satisfy linear constraints and constants across tensors.

### 6. Strides check failure (non-contiguous)
```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float16)
A_nc = A.t()  # transpose -> non-contiguous
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A_nc, B, C)
```
Expected: `Argument <kernel>.A_handle.strides[1] has an unsatisfied constraint: ... == 1`.

Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel.

### 7. device_type mismatch
```python
import torch

A = torch.empty((M, K), device='cpu', dtype=torch.float16)
B = torch.empty((K, N), device='cpu', dtype=torch.float16)
C = torch.empty((M, N), device='cpu', dtype=torch.float16)
fn(A, B, C)  # CUDA-targeted kernel
```
Expected: `<kernel>.A_handle.device_type mismatch [expected: 2 (cuda)] ...`.

Fix: move tensors to the CUDA device.

### 8. device_id mismatch (multi-GPU)
```python
import torch

A = torch.empty((M, K), device='cuda:0', dtype=torch.float16)
B = torch.empty((K, N), device='cuda:1', dtype=torch.float16)
C = torch.empty((M, N), device='cuda:0', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.B_handle.device_id has an unsatisfied constraint: ... == ...`.

Fix: place all tensors on the same GPU (e.g., `cuda:0`).

### 9. NULL data pointer (advanced)
This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this.

Expected: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`.

Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles.

### 10. Scalar type mismatch (int / bool)
```python
import tilelang.language as T

@T.prim_func
def scalar_check(x: T.int32, flag: T.bool()):
    T.evaluate(0)

scalar_check(1.0, True)  # x is float -> Expect arg[0] to be int
scalar_check(1, 2.5)     # flag is float -> Expect arg[1] to be boolean
```

Fix: pass correct scalar types, e.g., `scalar_check(1, True)`.

---

## Closing Notes
- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently.
- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly.