Add an example: mHC residual projection backward #1758
base: main
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

Adds a new example module implementing a Sinkhorn forward pass and an implicit-differentiation backward pass using TileLang JIT kernels, including tiled matvec/dot helpers, a conjugate-gradient solver with shared GPU buffers, and a script comparing autograd vs implicit gradients.
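As background on the approach (stated generally here, not extracted from the example's code): if the converged plan $r^\*$ satisfies the fixed-point condition $r^\* = F(r^\*, M)$ of the Sinkhorn iteration, the implicit function theorem gives

$$
\frac{\partial r^\*}{\partial M} \;=\; \Bigl(I - \frac{\partial F}{\partial r}\Bigr)^{-1} \frac{\partial F}{\partial M},
$$

so the vector-Jacobian product needed for backprop can be obtained by solving the linear system $\bigl(I - \partial_r F\bigr)^{\top} x = \bar r$ instead of unrolling the iterations. That solve is what the conjugate-gradient kernel performs; CG assumes a symmetric positive-definite operator, which the example's `matvec_A` is presumably constructed to be.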
Sequence Diagram

```mermaid
sequenceDiagram
    participant Script as "Python Script"
    participant Forward as "sinkhorn_forward()"
    participant Autograd as "PyTorch Autograd"
    participant Backward as "sinkhorn_bwd_implicit_cg()"
    participant CG as "Conjugate Gradient Kernel"
    participant GPU as "GPU Memory"
    Script->>GPU: allocate random M
    Script->>Forward: call sinkhorn_forward(M, iters)
    Forward->>Forward: compute P = exp(M) and iteratively normalize rows/cols -> R
    Forward-->>Script: return (R, P)
    Script->>Autograd: compute grad_M_autograd via backward()
    Autograd-->>Script: grad_M_autograd
    Script->>Backward: compile/configure implicit kernel (tile sizes, streams)
    Backward->>GPU: setup tiled/shared buffers and launch params
    Backward->>CG: invoke matvec_A and dot ops
    CG->>CG: run CG iterations (matvec, reductions, updates, EPS guards)
    CG-->>Backward: res / grad_M_implicit
    Backward-->>Script: grad_M_implicit
    Script->>Script: compare grad_M_autograd vs grad_M_implicit (MAE, rel-diff)
```
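For orientation, here is a minimal plain-PyTorch sketch of the conjugate-gradient structure the kernel implements on-chip. `cg_solve` and its `matvec` argument are illustrative names, not APIs from the example; the variable names (`r_normsq`, `pAp`) merely mirror the kernel's shared buffers.

```python
import torch

def cg_solve(matvec, b, iters=20, eps=1e-10):
    """Solve A x = b for symmetric positive-definite A, given only x -> A x."""
    x = torch.zeros_like(b)
    r = b - matvec(x)                      # initial residual (r = b since x = 0)
    p = r.clone()                          # initial search direction
    r_normsq = (r * r).sum()
    for _ in range(iters):
        Ap = matvec(p)
        pAp = (p * Ap).sum()
        alpha = r_normsq / (pAp + eps)     # step size; eps guards against 0/0
        x = x + alpha * p
        r = r - alpha * Ap
        r_new_normsq = (r * r).sum()
        beta = r_new_normsq / (r_normsq + eps)
        p = r + beta * p                   # conjugate direction update
        r_normsq = r_new_normsq
    return x

# Quick check against a random SPD system:
A = torch.randn(8, 8)
A = A @ A.T + 8 * torch.eye(8)             # make A symmetric positive-definite
b = torch.randn(8)
x = cg_solve(lambda v: A @ v, b, iters=64)
print(torch.allclose(A @ x, b, atol=1e-4))
```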
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@examples/deepseek_mhc/example_mhc_res.py`:
- Around line 18-25: The Sinkhorn implementation in sinkhorn_forward computes P
= torch.exp(M) and repeatedly normalizes R by dividing by row/column sums, which
can be zero due to underflow; modify sinkhorn_forward to clamp the row and
column denominator tensors with a small epsilon (e.g., 1e-8) before division
(use R.sum(-2, keepdim=True).clamp_min(eps) and R.sum(-1,
keepdim=True).clamp_min(eps)) so divisions R = R / denom never produce NaNs;
keep use of variables P and R and the existing loop/iters logic unchanged.
- Around line 83-85: The R buffer allocation is missing the dtype argument
causing a type mismatch with the macro signature; update the T.alloc_shared call
that creates R (symbol: R) to pass dtype=dtype just like dR and RdR (symbols:
dR, RdR) so all three use T.alloc_shared([tilesize, n_stream, n_stream],
dtype=dtype) and match the expected T.SharedBuffer([tilesize, n_stream,
n_stream], dtype) signature.
- Around line 82-107: The kernel launches ceildiv(seqlen, tilesize) tiles but
copies full tiles into shared buffers R and dR via T.copy (in the T.Kernel with
i_seq) which leaves the tail portion uninitialized when seqlen % tilesize != 0;
fix by adding bounds handling: either pad the input tensors (out and dout) to a
multiple of tilesize before invoking the kernel, or add index masks inside the
kernel (use the tile/global index i_seq and per-tile index i_tile) to guard
accesses and copies with a condition like i_seq * tilesize + i_tile < seqlen so
T.copy and subsequent Parallel(tilesize, ...) loops only read/write valid
positions; update all uses of R, dR and downstream loops (e.g., Parallel over
tiles that consume R/dR and the final write-back) to respect the same bounds
check.
🧹 Nitpick comments (1)
examples/deepseek_mhc/example_mhc_res.py (1)
167-234: Consider guarding the example driver under `if __name__ == "__main__":`. This prevents GPU-heavy work from running on import and makes it easier to reuse `sinkhorn_forward`/`sinkhorn_bwd_implicit_cg` as a library module.
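A minimal form of that guard (assuming the driver logic were collected into a `main()` function, which is hypothetical here):

```python
def main():
    ...  # build M, run sinkhorn_forward, compare autograd vs implicit gradients

if __name__ == "__main__":
    main()  # skipped on import, so the kernels can be reused as a library
```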
```python
def sinkhorn_forward(M, iters=20):
    P = torch.exp(M)
    R = P

    for _ in range(iters):
        R = R / R.sum(-2, keepdim=True)
        R = R / R.sum(-1, keepdim=True)
```
Guard Sinkhorn normalization against zero/underflow to prevent NaNs.
Row/column sums can hit zero for extreme inputs (large negative cost matrices underflow via torch.exp), yielding NaN on division. Add epsilon clamping to the denominators.
Proposed fix:

```diff
+EPS = 1e-10
 def sinkhorn_forward(M, iters=20):
     P = torch.exp(M)
     R = P
     for _ in range(iters):
-        R = R / R.sum(-2, keepdim=True)
-        R = R / R.sum(-1, keepdim=True)
+        R = R / (R.sum(-2, keepdim=True) + EPS)
+        R = R / (R.sum(-1, keepdim=True) + EPS)
```
```python
with T.Kernel(T.ceildiv(seqlen, tilesize), threads=threads) as i_seq:
    R = T.alloc_shared([tilesize, n_stream, n_stream])
    dR = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
    RdR = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
    res_tile = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
    b1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    b2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    x1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    x2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    r1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    r2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    p1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    p2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    r_normsq = T.alloc_shared([tilesize], dtype=dtype)
    r_new_normsq = T.alloc_shared([tilesize], dtype=dtype)
    Ap1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    Ap2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    pAp = T.alloc_shared([tilesize], dtype=dtype)

    # Buffers for intermediate results
    buf1 = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
    buf2 = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
    buf3 = T.alloc_shared([tilesize, n_stream], dtype=dtype)

    T.copy(out[i_seq * tilesize : (i_seq + 1) * tilesize, :, :], R)
    T.copy(dout[i_seq * tilesize : (i_seq + 1) * tilesize, :, :], dR)
```
🧩 Analysis chain

🏁 Scripts executed against the repository (output lengths omitted):

```shell
cat -n examples/deepseek_mhc/example_mhc_res.py | head -130
cat -n examples/deepseek_mhc/example_mhc_res.py | sed -n '82,180p'
rg -A 5 "T.ceildiv" --type py | head -60
rg -B 5 -A 5 "i_tile.*<.*seqlen\|i_bx.*<\|bounds\|tail" --type py examples/ | head -80
rg -A 10 "def copy" tilelang/language/ --type py | head -50
cat -n tilelang/language/copy_op.py | head -100
rg -B 3 -A 10 "T.Parallel.*tilesize\|mask\|if.*tilesize" --type py testing/ | head -100
find . -name "*mhc*" -type f
cat examples/deepseek_mhc/test_example_mhc.py
python3 << 'EOF'
seqlen = 4096
tilesize = 32
print(f"seqlen = {seqlen}, tilesize = {tilesize}")
print(f"seqlen % tilesize = {seqlen % tilesize}")
print(f"Num blocks = {(seqlen + tilesize - 1) // tilesize}")
print(f"Has tail tile: {seqlen % tilesize != 0}")
EOF
cat -n tilelang/language/copy_op.py | sed -n '45,150p'
rg -A 15 "class Parallel\|def Parallel" tilelang/language/ --type py | head -60
rg "T\.Parallel" tilelang/language/ --type py -l
rg -B 5 -A 20 "def Parallel" tilelang/language/loop.py
rg -A 40 "def Parallel" tilelang/language/loop.py | head -80
rg -B 2 -A 10 "legalize_pairwise_extents" tilelang/utils/language.py
```
Add bounds handling for tail tiles when seqlen % tilesize != 0.
When seqlen is not divisible by tilesize, T.ceildiv(seqlen, tilesize) launches a partial tail tile. The T.copy at lines 106-107 fills only partial data into the full-sized R buffer, leaving the remainder uninitialized. The T.Parallel(tilesize, ...) loops at lines 109+ then iterate over all elements, including garbage from uninitialized memory, and propagate incorrect values through subsequent operations. The result written back at line 162 will contain NaNs or garbage in the tail region.
Mask loop iterations on `i_seq * tilesize + i_tile < seqlen`, or pad the input tensors to a multiple of tilesize before the kernel, as sketched below.
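A minimal sketch of the padding route (plain PyTorch; `pad_seq_to_multiple` is a hypothetical helper, and the trailing kernel call is schematic):

```python
import torch
import torch.nn.functional as F

def pad_seq_to_multiple(t: torch.Tensor, tilesize: int) -> torch.Tensor:
    """Zero-pad dim 0 of a (seqlen, n_stream, n_stream) tensor to a multiple of tilesize."""
    pad = (-t.shape[0]) % tilesize
    if pad:
        # F.pad's width tuple runs from the last dim backwards,
        # so the final (0, pad) pair applies to dim 0.
        t = F.pad(t, (0, 0, 0, 0, 0, pad))
    return t

out = torch.randn(100, 4, 4)               # seqlen=100 is not a multiple of 32
out_padded = pad_seq_to_multiple(out, 32)
assert out_padded.shape[0] == 128
# Launch the kernel on the padded tensors, then slice the result back:
# res = kernel(out_padded, dout_padded)[:100]
```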
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@examples/deepseek_mhc/example_mhc_res.py`:
- Around line 121-157: The reduction kernels matvec_A(...) and dot(...) perform
T.reduce_sum() into shared/result buffers (e.g., Ap1/Ap2, pAp, r_normsq,
r_new_normsq) but the subsequent T.Parallel loops read those results
immediately, causing read-after-write races; fix by inserting T.sync_threads()
immediately after each call to matvec_A(...) and dot(...) wherever their results
are consumed (for example after matvec_A(R, x1, x2, ...) before the first
T.Parallel that reads r1/r2, after dot(r1, r2, ...) before using r_normsq, after
matvec_A(...) inside the CG loop before reading Ap1/Ap2 and pAp, and after
dot(...) that computes r_new_normsq before the T.Parallel that computes beta and
updates p1/p2) so all threads see the completed reduction results.
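A schematic fragment of that pattern (not runnable on its own: the argument lists for `matvec_A`/`dot` are illustrative, and `EPS` is assumed to be defined as in the example):

```python
# Inside the CG loop of the kernel body (sketch, not the exact code):
matvec_A(R, p1, p2, Ap1, Ap2, buf1, buf2)  # T.reduce_sum fills Ap1/Ap2
T.sync_threads()                           # make Ap1/Ap2 visible to all threads

dot(p1, p2, Ap1, Ap2, pAp, buf3)           # T.reduce_sum fills pAp
T.sync_threads()                           # make pAp visible before it is read

for i, j in T.Parallel(tilesize, n_stream):
    # Safe now: every thread sees the completed reductions.
    x1[i, j] += (r_normsq[i] / (pAp[i] + EPS)) * p1[i, j]
```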
I'll help and take a look. One helpful trick for debugging is to replace …

It seems the LLM's suggestion is correct: I got correct results after inserting `T.sync_threads()`. Probably Triton inserts implicit synchronization in its reduce operations. Is it a design choice to leave this to the user, or did I misunderstand something? @LeiWang1999
@Da1sypetals Yes, and I also found the same issue. This is indeed a bug: the previous algorithm invoked reductions into shared buffers and read the results without synchronization. Two solutions: insert `T.sync_threads()` manually after each reduction, or allocate the reduction targets as fragments instead of shared memory:
```diff
-r_new_normsq = T.alloc_shared([tilesize], dtype=dtype)
+r_new_normsq = T.alloc_fragment([tilesize], dtype=dtype)

-pAp = T.alloc_shared([tilesize], dtype=dtype)
+pAp = T.alloc_fragment([tilesize], dtype=dtype)
```
We will have a fix ASAP, and sorry for the trouble.
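(Presumably the fragment variant sidesteps the race because `T.alloc_fragment` places the accumulator in per-thread registers rather than shared memory, so no cross-thread visibility, and hence no barrier, is involved.)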
I made a fix at #1760. Thanks for pointing out the issue!
@LeiWang1999 Thanks! Is it correct that, after the fix, it will no longer be required to manually call `T.sync_threads()`?
I’d like to add an example for the mHC backward pass, but I’m running into some bugs that I haven’t been able to resolve on my own. I would really appreciate it if someone could take a look.
The algorithm is described here in my blog. More importantly, there is a Triton implementation available here.
I attempted to translate the Triton code into TileLang more or less word-for-word, but I couldn't get correct results. The current implementation always produces NaNs. Interestingly, if I add a `print` inside the `T.serial` loop, the results become correct (or at least very close). I'd really appreciate it if someone could review the code and point out whether there is a bug in my implementation or if something else is going wrong 🌹