Summary
LaunchContextBuilder::argpack_ptrs stores a raw const ArgPack * pointer without any ownership semantics. On the CUDA/LLVM backend, Program::delete_argpack() always deletes immediately, because LlvmProgramImpl does not override used_in_kernel() and the base-class implementation returns false. If Python's garbage collector frees an ArgPack wrapper between set_arg_argpack() and the actual kernel launch, the CUDA kernel launcher dereferences a dangling pointer, causing wild writes and memory corruption.
This bug can cause random SEGVs and heap corruption in any workload that uses ArgPack with the CUDA backend, especially under GC pressure (many kernel launches, multi-threaded code).
Root Cause
1. Raw pointer storage without ownership
// launch_context_builder.h:137-140
std::unordered_map<std::vector<int>,
                   const ArgPack *,  // RAW POINTER: no ref counting
                   hashing::Hasher<std::vector<int>>>
    argpack_ptrs;
// launch_context_builder.cpp:252-254
void LaunchContextBuilder::set_arg_argpack(const std::vector<int> &arg_id,
                                           const ArgPack &argpack) {
  argpack_ptrs[arg_id] = &argpack;  // stores address, no ownership
}
2. Dangling pointer dereference during kernel launch
// cuda/kernel_launcher.cpp:127-138
auto *argpack = ctx.argpack_ptrs[key];                // dangling if GC'd
auto argpack_ptr = argpack->get_device_allocation();  // wild read
// ...
auto *argpack_parent = ctx.argpack_ptrs[key_parent];
argpack_parent->set_arg_nested_argpack_ptr(           // wild WRITE
    key.back(), (uint64)device_ptrs[data_ptr_idx]);
The same pattern appears in cpu/kernel_launcher.cpp:48-63, amdgpu/kernel_launcher.cpp:82-94, and gfx/runtime.cpp:486-490.
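To make the write-back step concrete, here is a minimal, self-contained model of what the launcher loop does with argpack_ptrs. All names here (ToyArgPack, ArgPackPtrMap, patch_nested_argpacks, the fake device addresses) are hypothetical stand-ins, not Taichi's real classes, and the key layout is an assumption. The point is only that the read of the child and the write into the parent both go through raw pointers that the builder does not own.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical stand-in for ArgPack: a fake device address plus slots
// that receive the device pointers of nested argpacks.
struct ToyArgPack {
  std::uint64_t device_addr = 0;
  std::map<int, std::uint64_t> nested_ptrs;

  void set_arg_nested_argpack_ptr(int slot, std::uint64_t ptr) {
    nested_ptrs[slot] = ptr;
  }
};

// Same shape as LaunchContextBuilder::argpack_ptrs: raw and non-owning.
using ArgPackPtrMap = std::map<std::vector<int>, ToyArgPack *>;

// Models the launcher loop: for every nested key, the parent key is the key
// without its last index, and the child's device address is written into
// the parent's slot key.back(). If either ToyArgPack has already been
// destroyed, the read or the write below is undefined behaviour.
void patch_nested_argpacks(ArgPackPtrMap &argpack_ptrs) {
  for (auto &[key, argpack] : argpack_ptrs) {
    if (key.size() < 2) continue;  // top-level argpack: no parent to patch
    std::vector<int> key_parent(key.begin(), key.end() - 1);
    ToyArgPack *parent = argpack_ptrs.at(key_parent);
    std::uint64_t child_addr = argpack->device_addr;          // read child
    parent->set_arg_nested_argpack_ptr(key.back(), child_addr);  // write parent
  }
}
```

With both objects alive the loop is harmless. The bug is that nothing in this data structure keeps them alive.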
3. used_in_kernel() guard is unimplemented on LLVM backends
// program_impl.h:103-105 (base class)
virtual bool used_in_kernel(DeviceAllocationId) {
  return false;  // ALWAYS FALSE on CUDA/CPU/AMDGPU
}
Only GfxProgramImpl overrides this (gfx_program.h:47-48). LlvmProgramImpl inherits the base-class implementation, so delete_argpack() (program.cpp:428-443) always deletes immediately on the CUDA/CPU/AMDGPU backends, regardless of pending kernel launches.
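The effect of the missing override can be shown with a small model of the guard. ProgramImplModel, TrackingProgramImpl, ProgramModel and the integer allocation ids are hypothetical. The sketch assumes only what the text states: the base class answers false, and delete_argpack() frees immediately unless used_in_kernel() reports the allocation as busy. The tracking override is loosely in the spirit of the GFX backend's in-use tracking, not a copy of it.

```cpp
#include <cassert>
#include <set>
#include <vector>

using AllocId = int;  // hypothetical stand-in for DeviceAllocationId

struct ProgramImplModel {
  virtual ~ProgramImplModel() = default;
  // Base-class behaviour described above: never reports an allocation as busy.
  virtual bool used_in_kernel(AllocId) { return false; }
};

// Hypothetical override that tracks allocations used by in-flight launches.
struct TrackingProgramImpl : ProgramImplModel {
  std::set<AllocId> in_flight;
  bool used_in_kernel(AllocId id) override { return in_flight.count(id) > 0; }
};

// Models Program::delete_argpack(): free now, or defer while still in use.
struct ProgramModel {
  ProgramImplModel &impl;
  std::vector<AllocId> freed;     // allocations actually released
  std::vector<AllocId> deferred;  // allocations waiting for kernels to finish

  void delete_argpack(AllocId id) {
    if (impl.used_in_kernel(id)) {
      deferred.push_back(id);
    } else {
      freed.push_back(id);
    }
  }

  // Called after kernels complete: release whatever is no longer in use.
  void flush_deferred() {
    std::vector<AllocId> still_busy;
    for (AllocId id : deferred) {
      if (impl.used_in_kernel(id)) {
        still_busy.push_back(id);
      } else {
        freed.push_back(id);
      }
    }
    deferred.swap(still_busy);
  }
};
```

With the base class, every call lands in freed immediately. This matches the CUDA/CPU/AMDGPU behaviour described above.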
4. Python GC triggers immediate C++ destruction
# argpack.py:76-78
def __del__(self):
    if impl is not None and impl.get_runtime() is not None and impl.get_runtime().prog is not None:
        impl.get_runtime().prog.delete_argpack(self.__argpack)
The race
- Python calls set_arg_argpack() → raw &argpack stored in argpack_ptrs
- Python GC runs (e.g., triggered by allocation pressure in another thread)
- GC collects the ArgPack Python wrapper → __del__ → delete_argpack() → used_in_kernel() returns false → C++ ArgPack freed
- Kernel launcher dereferences the dangling argpack_ptrs[key] → wild write / SEGV
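The sequence above can be replayed deterministically in one thread, because the underlying problem is lifetime, not timing. Anything that drops the last owner between set_arg_argpack() and the launch produces the same state.

In this sketch, std::shared_ptr stands in for the reference held through the Python wrapper. The hypothetical ToyLaunchContext keeps a std::weak_ptr where Taichi keeps a raw pointer. That swap is purely for illustration and is not one of the fixes proposed in this issue. It lets the model observe that the object is gone, which a raw const ArgPack * cannot do.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <vector>

// Hypothetical stand-in for the C++ ArgPack object.
struct ToyArgPack {
  int payload = 42;
};

// Hypothetical builder that keeps a std::weak_ptr per key instead of a raw
// pointer. It still does not own the ArgPack, but it can tell when it is gone.
struct ToyLaunchContext {
  std::map<std::vector<int>, std::weak_ptr<const ToyArgPack>> argpack_refs;

  void set_arg_argpack(const std::vector<int> &key,
                       const std::shared_ptr<const ToyArgPack> &argpack) {
    argpack_refs[key] = argpack;
  }

  // Models the launch: pin the ArgPack for the duration of the read, or
  // report failure if the last owner has already released it.
  bool launch(const std::vector<int> &key, int &out) const {
    auto it = argpack_refs.find(key);
    if (it == argpack_refs.end()) return false;
    std::shared_ptr<const ToyArgPack> pinned = it->second.lock();
    if (!pinned) return false;  // a raw pointer here would be dangling
    out = pinned->payload;
    return true;
  }
};
```

In the usage, calling reset() on the owning pointer plays the role of step 3 (GC runs __del__, which calls delete_argpack()). The second launch then fails cleanly instead of touching freed memory.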
Design inconsistency
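set_arg_ndarray() follows a safer pattern: it copies the pointer into the builder as an integer value instead of storing a pointer to the Ndarray object:
// launch_context_builder.cpp:246
intptr_t ptr = arr.get_device_allocation_ptr_as_int();  // copies a VALUE
set_arg_argpack() deviates from this by storing a pointer to the ArgPack object itself (&argpack), which the kernel launchers later dereference. That is unsafe. (As the note on array_ptrs below points out, the integer copied by set_arg_ndarray() is itself the address of a member inside the Ndarray. That path is therefore safer only in practice, not by construction.)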
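The difference between storing an object pointer and storing a handle by value can be illustrated with a toy builder. ToyDeviceAllocation, ToyArgPack and ValueBuilder are hypothetical names. The sketch assumes, as Option B below does, that the launcher only needs a small handle (allocation id plus device pointer) and not the ArgPack object itself. Because the handle is copied, destroying the source object afterwards leaves the builder's copy intact.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <vector>

// Hypothetical handle, loosely modelled on a device allocation record.
struct ToyDeviceAllocation {
  int alloc_id = -1;
  std::uint64_t device_ptr = 0;
};

// Hypothetical host-side object that carries the allocation handle.
struct ToyArgPack {
  ToyDeviceAllocation alloc;
};

// Stores handles by value, so the map never points into a ToyArgPack.
struct ValueBuilder {
  std::map<std::vector<int>, ToyDeviceAllocation> argpack_allocs;

  void set_arg_argpack(const std::vector<int> &key, const ToyArgPack &argpack) {
    argpack_allocs[key] = argpack.alloc;  // copies the handle, not an address
  }
};
```

One caveat, stated as an inference rather than something verified against Taichi. Copying the handle removes the dangling host-side pointer. However, if delete_argpack() also releases the device memory behind that handle, a pending kernel could still read freed device memory. That remaining gap is what Option C's tracking would cover.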
Evidence
Observed in a multi-threaded MARL training workload using the Genesis physics simulation with the Taichi CUDA backend (2048 parallel envs): 12 independent crashes over weeks of debugging:
- Random SEGVs in Python GC (tp_traverse NULL, object type confusion: range → code, dict → bool)
- ASan (LD_PRELOAD) reported zero heap use-after-free errors in 19 hours, because the free happens inside Taichi's device allocator, not via the system free()
- ASan + PYTHONMALLOC=malloc: ASan's own allocator metadata was corrupted by a wild write (CHECK failed: rz_size=0x0)
- Crash #12: faulthandler showed "Garbage-collecting" during taichi.kernel_impl.launch_kernel, direct evidence of GC during a kernel launch
- Valgrind (serializes all threads): no crash in 10+ hours, consistent with a threading race
- More envs = faster crash (2048 envs: ~5 min; 1024 envs: ~30-90 min)
Genesis issue Genesis-Embodied-AI/Genesis#492 appears to be the same bug (segfault during scene.step() after 1000+ iterations, closed without root cause).
Workaround
Add the ArgPack Python object to the tmps GC-prevention list in kernel_impl.py, following the existing pattern used for numpy arrays (line 731: tmps.append(tmp) # Purpose: DO NOT GC |tmp|!):
# kernel_impl.py, inside recursive_set_args(), after set_arg_argpack():
launch_ctx.set_arg_argpack(indices, v._ArgPack__argpack)
tmps.append(v)  # prevent GC of ArgPack while C++ holds raw pointer
Suggested Fix
Option A (minimal): Apply the Python-side workaround above.
Option B (proper): Change argpack_ptrs to store DeviceAllocation by value instead of a raw object pointer, matching the value-copy pattern used by set_arg_ndarray(). This requires updating all kernel launchers. The nested argpack write-back (set_arg_nested_argpack_ptr) currently writes into the parent ArgPack object, so it would also need refactoring.
Option C (defense-in-depth): Implement used_in_kernel() in LlvmProgramImpl to actually track in-flight allocations, matching the existing GFX backend implementation.
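For Option B, one possible shape for the nested write-back refactoring is for the builder to record by value which slot of which parent needs which child's device address. The launcher would then write those addresses into storage owned by the launch itself, rather than into the parent ArgPack object. Everything below (NestedPatch, LaunchOwnedArgs, the slot-vector layout) is a hypothetical sketch, not Taichi's data layout or API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical record of one pending nested-pointer write, captured by value
// when arguments are set.
struct NestedPatch {
  std::vector<int> parent_key;
  int slot;
  std::vector<int> child_key;
};

// Hypothetical per-launch storage: copies of each argpack's pointer slots and
// device addresses, owned by the launch rather than by any host ArgPack.
struct LaunchOwnedArgs {
  std::map<std::vector<int>, std::vector<std::uint64_t>> slots;
  std::map<std::vector<int>, std::uint64_t> device_addrs;
  std::vector<NestedPatch> patches;

  // Resolves every recorded patch using launch-owned data only.
  void apply_patches() {
    for (const NestedPatch &p : patches) {
      std::vector<std::uint64_t> &parent_slots = slots.at(p.parent_key);
      const std::size_t index = static_cast<std::size_t>(p.slot);
      if (parent_slots.size() <= index) parent_slots.resize(index + 1, 0);
      parent_slots[index] = device_addrs.at(p.child_key);
    }
  }
};
```

The design choice is that no pointer into a Python-owned object survives past set_arg time. Whatever the GC frees afterwards, the launch only touches its own copies.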
Additional note: array_ptrs has the same pattern
set_arg_ndarray() stores (void *)&ndarray_alloc_ (address of a member inside the Ndarray object) in array_ptrs. If the Ndarray is GC'd, the member address becomes invalid. The CUDA kernel launcher dereferences this at cuda/kernel_launcher.cpp:110-112:
DeviceAllocation *ptr = static_cast<DeviceAllocation *>(data_ptr);
device_ptrs[data_ptr_idx] = executor->get_device_alloc_info_ptr(*ptr);
This is the same vulnerability, though Ndarrays tend to be long-lived, so it is less likely to trigger in practice.
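The distinction between copying an address and copying the allocation record can be shown in a few lines. ToyNdarray, ToyAllocation and as_alloc_ptr are hypothetical stand-ins. alloc_address_as_int() only mimics what the text above says set_arg_ndarray() stores ((void *)&ndarray_alloc_); it is not Taichi's code. Turning the integer back into a pointer yields an address inside the object, so it tracks the object's lifetime. A copy of the member does not.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for DeviceAllocation.
struct ToyAllocation {
  int alloc_id = 0;
};

// Hypothetical stand-in for Ndarray, holding its allocation as a member.
struct ToyNdarray {
  ToyAllocation ndarray_alloc_;

  // Address of the member converted to an integer: still tied to *this.
  std::intptr_t alloc_address_as_int() const {
    return reinterpret_cast<std::intptr_t>(&ndarray_alloc_);
  }

  // A copy of the member: independent of *this once returned.
  ToyAllocation alloc_by_value() const { return ndarray_alloc_; }
};

// What a launcher does with the integer: turn it back into a pointer.
inline const ToyAllocation *as_alloc_ptr(std::intptr_t value) {
  return reinterpret_cast<const ToyAllocation *>(value);
}
```

In the usage, changing the member after the copy shows the difference. The integer follows the live object (and would dangle once it is destroyed), while the by-value copy keeps its original contents.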
Environment
- Taichi 1.7.4 (commit b4b956f)
- Python 3.12, CUDA 12.8, Linux 6.17
- NVIDIA RTX 5090
Affected code (introduced in July 2023)
- PR #8257 "[lang] Add is_argpack property to Argument and pass argpack ptrs to LaunchContextBuilder" (525682fc0): argpack_ptrs[arg_id] = &argpack
- PR #8241 "[lang] Instantiate a runtime ArgPack object when a python ArgPack is created" (29cfb5c72): delete_argpack() with used_in_kernel() guard
- PR #8263 "[spirv] [ir] Support argpack buffer load for spir-v backends" (cfad91fc8): argpacks_in_use_tracking added to GFX only
- PR #8267 "[llvm] [ir] Implemented argpack buffer load on llvm devices" (22a32e3a7): LLVM kernel launchers dereference argpack_ptrs