
Conversation

@Da1sypetals

@Da1sypetals Da1sypetals commented Jan 29, 2026

I’d like to add an example for the mHC backward pass, but I’m running into some bugs that I haven’t been able to resolve on my own. I would really appreciate it if someone could take a look.

The algorithm is described here in my blog. More importantly, there is a Triton implementation available here.

I attempted to translate the Triton code into TileLang more or less word-for-word, but I couldn’t get correct results. The current implementation always produces NaNs. Interestingly, if I add a print inside the T.serial loop, the results become correct (or at least very close).

I’d really appreciate it if someone could review the code and point out whether there is a bug in my implementation or if something else is going wrong 🌹

Summary by CodeRabbit

  • New Features
    • Added a Sinkhorn example demonstrating a forward Sinkhorn solver with configurable iterations.
    • Added an implicit-differentiation backward routine using a conjugate-gradient based solver with tunable tiling/threads parameters.
    • Included a runnable comparison script that computes and prints gradient discrepancies and diagnostic statistics between autodiff and implicit methods.

✏️ Tip: You can customize this high-level summary in your review settings.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Jan 29, 2026

📝 Walkthrough


Adds a new example module implementing a Sinkhorn forward pass and an implicit-differentiation backward pass using TileLang JIT kernels, including tiled matvec/dot helpers, a conjugate-gradient solver with shared GPU buffers, and a script comparing autograd vs implicit gradients.

Changes

Cohort / File(s): Sinkhorn example module (examples/deepseek_mhc/example_mhc_res.py)
Summary: New file (+302 lines) adding sinkhorn_forward(M, iters), which computes P = exp(M) with iterative row/column normalizations and returns (R, P), and sinkhorn_bwd_implicit_cg(n_stream, tilesize, threads), which returns a TileLang JIT main kernel implementing a tiled matvec_A, dot reductions, and a conjugate-gradient loop using shared buffers and EPS safeguards for implicit differentiation. Also includes a runnable script that allocates CUDA tensors, computes autograd gradients, runs the implicit kernel, and prints MAE/relative-difference statistics.

Sequence Diagram

sequenceDiagram
    participant Script as "Python Script"
    participant Forward as "sinkhorn_forward()"
    participant Autograd as "PyTorch Autograd"
    participant Backward as "sinkhorn_bwd_implicit_cg()"
    participant CG as "Conjugate Gradient Kernel"
    participant GPU as "GPU Memory"

    Script->>GPU: allocate random M
    Script->>Forward: call sinkhorn_forward(M, iters)
    Forward->>Forward: compute P = exp(M) and iteratively normalize rows/cols -> R
    Forward-->>Script: return (R, P)

    Script->>Autograd: compute grad_M_autograd via backward()
    Autograd-->>Script: grad_M_autograd

    Script->>Backward: compile/configure implicit kernel (tile sizes, streams)
    Backward->>GPU: setup tiled/shared buffers and launch params
    Backward->>CG: invoke matvec_A and dot ops
    CG->>CG: run CG iterations (matvec, reductions, updates, EPS guards)
    CG-->>Backward: res / grad_M_implicit
    Backward-->>Script: grad_M_implicit

    Script->>Script: compare grad_M_autograd vs grad_M_implicit (MAE, rel-diff)
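For orientation, here is a hedged sketch of this flow as plain PyTorch pseudocode. The tensor shapes, the scalar objective, and the kernel call convention are assumptions for illustration, not the exact code in example_mhc_res.py:

import torch

seqlen, n_stream, iters = 4096, 4, 20        # assumed sizes
M = torch.randn(seqlen, n_stream, n_stream, device="cuda", requires_grad=True)

# Forward pass and reference gradient via autograd
R, P = sinkhorn_forward(M, iters=iters)
dout = torch.randn_like(R)                   # upstream gradient w.r.t. R
(R * dout).sum().backward()                  # hypothetical scalar objective
grad_autograd = M.grad

# Implicit-differentiation gradient via the TileLang CG kernel
kernel = sinkhorn_bwd_implicit_cg(n_stream=n_stream, tilesize=32, threads=128)
grad_implicit = kernel(R.detach(), dout)     # assumed (out, dout) -> res convention

mae = (grad_autograd - grad_implicit).abs().mean()
rel = mae / grad_autograd.abs().mean().clamp_min(1e-12)
print(f"MAE: {mae.item():.3e}  rel: {rel.item():.3e}")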

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐇 I hop through tiles and exponentials bright,

I nudge rows and columns until they’re right,
CG hums in shared-memory night,
Implicit gradients glimmer in sight,
I count the diffs and bounce away light.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 12.50%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): The title 'Add an example: mHC residual projection backward' directly and specifically describes the main change in the PR: adding a new example module implementing the mHC (Sinkhorn) backward pass with implicit differentiation using TileLang.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@examples/deepseek_mhc/example_mhc_res.py`:
- Around line 18-25: The Sinkhorn implementation in sinkhorn_forward computes P
= torch.exp(M) and repeatedly normalizes R by dividing by row/column sums, which
can be zero due to underflow; modify sinkhorn_forward to clamp the row and
column denominator tensors with a small epsilon (e.g., 1e-8) before division
(use R.sum(-2, keepdim=True).clamp_min(eps) and R.sum(-1,
keepdim=True).clamp_min(eps)) so divisions R = R / denom never produce NaNs;
keep use of variables P and R and the existing loop/iters logic unchanged.
- Around line 83-85: The R buffer allocation is missing the dtype argument
causing a type mismatch with the macro signature; update the T.alloc_shared call
that creates R (symbol: R) to pass dtype=dtype just like dR and RdR (symbols:
dR, RdR) so all three use T.alloc_shared([tilesize, n_stream, n_stream],
dtype=dtype) and match the expected T.SharedBuffer([tilesize, n_stream,
n_stream], dtype) signature.
- Around line 82-107: The kernel launches ceildiv(seqlen, tilesize) tiles but
copies full tiles into shared buffers R and dR via T.copy (in the T.Kernel with
i_seq) which leaves the tail portion uninitialized when seqlen % tilesize != 0;
fix by adding bounds handling: either pad the input tensors (out and dout) to a
multiple of tilesize before invoking the kernel, or add index masks inside the
kernel (use the tile/global index i_seq and per-tile index i_tile) to guard
accesses and copies with a condition like i_seq * tilesize + i_tile < seqlen so
T.copy and subsequent Parallel(tilesize, ...) loops only read/write valid
positions; update all uses of R, dR and downstream loops (e.g., Parallel over
tiles that consume R/dR and the final write-back) to respect the same bounds
check.
🧹 Nitpick comments (1)
examples/deepseek_mhc/example_mhc_res.py (1)

167-234: Consider guarding the example driver under if __name__ == "__main__":.

This prevents GPU‑heavy work from running on import and makes it easier to reuse sinkhorn_forward/sinkhorn_bwd_implicit_cg as a library module.
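A minimal sketch of the suggested structure (main is a hypothetical wrapper name; the driver body is elided):

def main():
    # allocate M on CUDA, run sinkhorn_forward, compute the autograd and
    # implicit gradients via the TileLang kernel, then print MAE / rel-diff stats
    ...


if __name__ == "__main__":
    main()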

Comment on lines +18 to +28
def sinkhorn_forward(M, iters=20):
    P = torch.exp(M)
    R = P

    for _ in range(iters):
        R = R / R.sum(-2, keepdim=True)
        R = R / R.sum(-1, keepdim=True)

Contributor


⚠️ Potential issue | 🟡 Minor

Guard Sinkhorn normalization against zero/underflow to prevent NaNs.

Row/column sums can hit zero for extreme inputs (large negative cost matrices underflow via torch.exp), yielding NaN on division. Add epsilon clamping to the denominators.

Proposed fix
+EPS = 1e-10
 def sinkhorn_forward(M, iters=20):
     P = torch.exp(M)
     R = P
 
     for _ in range(iters):
-        R = R / R.sum(-2, keepdim=True)
-        R = R / R.sum(-1, keepdim=True)
+        R = R / (R.sum(-2, keepdim=True) + EPS)
+        R = R / (R.sum(-1, keepdim=True) + EPS)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

-def sinkhorn_forward(M, iters=20):
-    P = torch.exp(M)
-    R = P
-
-    for _ in range(iters):
-        R = R / R.sum(-2, keepdim=True)
-        R = R / R.sum(-1, keepdim=True)
+EPS = 1e-10
+def sinkhorn_forward(M, iters=20):
+    P = torch.exp(M)
+    R = P
+
+    for _ in range(iters):
+        R = R / (R.sum(-2, keepdim=True) + EPS)
+        R = R / (R.sum(-1, keepdim=True) + EPS)
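A quick sanity check of the guarded version (this assumes the patched sinkhorn_forward above; the value -200 is an illustrative extreme chosen so torch.exp underflows to zero in float32):

import torch

M = torch.full((4, 8, 8), -200.0)      # exp(-200) underflows to 0.0 in float32
R, P = sinkhorn_forward(M, iters=20)   # guarded denominators: 0 / (0 + EPS) stays 0
print(torch.isnan(R).any())            # expected: tensor(False); the unguarded version yields NaN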

Comment on lines 82 to 110
with T.Kernel(T.ceildiv(seqlen, tilesize), threads=threads) as i_seq:
    R = T.alloc_shared([tilesize, n_stream, n_stream])
    dR = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
    RdR = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
    res_tile = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
    b1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    b2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    x1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    x2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    r1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    r2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    p1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    p2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    r_normsq = T.alloc_shared([tilesize], dtype=dtype)
    r_new_normsq = T.alloc_shared([tilesize], dtype=dtype)
    Ap1 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    Ap2 = T.alloc_shared([tilesize, n_stream], dtype=dtype)
    pAp = T.alloc_shared([tilesize], dtype=dtype)

    # Buffers for intermediate results
    buf1 = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
    buf2 = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype)
    buf3 = T.alloc_shared([tilesize, n_stream], dtype=dtype)

    T.copy(out[i_seq * tilesize : (i_seq + 1) * tilesize, :, :], R)
    T.copy(dout[i_seq * tilesize : (i_seq + 1) * tilesize, :, :], dR)
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n examples/deepseek_mhc/example_mhc_res.py | head -130

Repository: tile-ai/tilelang

Length of output: 5806


🏁 Script executed:

cat -n examples/deepseek_mhc/example_mhc_res.py | sed -n '82,180p'

Repository: tile-ai/tilelang

Length of output: 5308


🏁 Script executed:

rg -A 5 "T.ceildiv" --type py | head -60

Repository: tile-ai/tilelang

Length of output: 5728


🏁 Script executed:

rg -B 5 -A 5 "i_tile.*<.*seqlen\|i_bx.*<\|bounds\|tail" --type py examples/ | head -80

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

rg -A 10 "def copy" tilelang/language/ --type py | head -50

Repository: tile-ai/tilelang

Length of output: 1294


🏁 Script executed:

cat -n tilelang/language/copy_op.py | head -100

Repository: tile-ai/tilelang

Length of output: 5515


🏁 Script executed:

rg -B 3 -A 10 "T.Parallel.*tilesize\|mask\|if.*tilesize" --type py testing/ | head -100

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

find . -name "*mhc*" -type f

Repository: tile-ai/tilelang

Length of output: 233


🏁 Script executed:

cat examples/deepseek_mhc/test_example_mhc.py

Repository: tile-ai/tilelang

Length of output: 373


🏁 Script executed:

python3 << 'EOF'
seqlen = 4096
tilesize = 32
print(f"seqlen = {seqlen}, tilesize = {tilesize}")
print(f"seqlen % tilesize = {seqlen % tilesize}")
print(f"Num blocks = {(seqlen + tilesize - 1) // tilesize}")
print(f"Has tail tile: {seqlen % tilesize != 0}")
EOF

Repository: tile-ai/tilelang

Length of output: 148


🏁 Script executed:

cat -n tilelang/language/copy_op.py | sed -n '45,150p'

Repository: tile-ai/tilelang

Length of output: 5590


🏁 Script executed:

rg -A 15 "class Parallel\|def Parallel" tilelang/language/ --type py | head -60

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

rg "T\.Parallel" tilelang/language/ --type py -l

Repository: tile-ai/tilelang

Length of output: 144


🏁 Script executed:

rg -B 5 -A 20 "def Parallel" tilelang/language/loop.py

Repository: tile-ai/tilelang

Length of output: 934


🏁 Script executed:

rg -A 40 "def Parallel" tilelang/language/loop.py | head -80

Repository: tile-ai/tilelang

Length of output: 1785


🏁 Script executed:

rg -B 2 -A 10 "legalize_pairwise_extents" tilelang/utils/language.py

Repository: tile-ai/tilelang

Length of output: 628


Add bounds handling for tail tiles when seqlen % tilesize != 0.

When seqlen is not divisible by tilesize, T.ceildiv(seqlen, tilesize) launches a partial tail tile. The T.copy at lines 106–107 will only fill partial data into the full-sized R buffer, leaving the remainder uninitialized. The T.Parallel(tilesize, ...) loops at lines 109+ then iterate over all elements—including garbage from uninitialized memory—propagating incorrect values through subsequent operations. The result written back at line 162 will contain NaNs or garbage on the tail region.

Mask loop iterations on i_seq * tilesize + i_tile < seqlen or pad input tensors to a multiple of tilesize before the kernel.
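If padding is the preferred route, a minimal host-side sketch is shown below (it assumes out and dout are [seqlen, n_stream, n_stream] CUDA tensors, that kernel is the compiled JIT kernel taking (out, dout), and that zero-padded rows are benign for the EPS-guarded CG; the names are illustrative):

import torch
import torch.nn.functional as F

def pad_seq_to_multiple(x, tilesize):
    # Zero-pad the leading (sequence) dimension up to the next multiple of tilesize.
    pad = (-x.shape[0]) % tilesize
    if pad == 0:
        return x
    return F.pad(x, (0, 0, 0, 0, 0, pad))  # pad spec runs from the last dim to the first

seqlen = out.shape[0]
out_p = pad_seq_to_multiple(out, tilesize)
dout_p = pad_seq_to_multiple(dout, tilesize)
res = kernel(out_p, dout_p)[:seqlen]       # drop the padded tail after the call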


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@examples/deepseek_mhc/example_mhc_res.py`:
- Around line 121-157: The reduction kernels matvec_A(...) and dot(...) perform
T.reduce_sum() into shared/result buffers (e.g., Ap1/Ap2, pAp, r_normsq,
r_new_normsq) but the subsequent T.Parallel loops read those results
immediately, causing read-after-write races; fix by inserting T.sync_threads()
immediately after each call to matvec_A(...) and dot(...) wherever their results
are consumed (for example after matvec_A(R, x1, x2, ...) before the first
T.Parallel that reads r1/r2, after dot(r1, r2, ...) before using r_normsq, after
matvec_A(...) inside the CG loop before reading Ap1/Ap2 and pAp, and after
dot(...) that computes r_new_normsq before the T.Parallel that computes beta and
updates p1/p2) so all threads see the completed reduction results.
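A minimal sketch of the suggested placement inside the kernel's CG loop (the matvec_A/dot argument lists below are placeholders, not the exact macro signatures in the example):

# Schematic: each reduction writes its result to shared memory from a subset of
# threads, so a barrier is needed before any thread consumes that result.
matvec_A(R, p1, p2, Ap1, Ap2)   # tiled matvec; T.reduce_sum fills Ap1/Ap2
T.sync_threads()                # make Ap1/Ap2 visible to all threads
dot(p1, p2, Ap1, Ap2, pAp)      # reduction fills pAp
T.sync_threads()                # sync before every thread reads pAp

for i_tile, i in T.Parallel(tilesize, n_stream):
    # alpha = r_normsq / (pAp + EPS), applied elementwise to the CG iterates
    x1[i_tile, i] = x1[i_tile, i] + (r_normsq[i_tile] / (pAp[i_tile] + EPS)) * p1[i_tile, i]
    x2[i_tile, i] = x2[i_tile, i] + (r_normsq[i_tile] / (pAp[i_tile] + EPS)) * p2[i_tile, i]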

@LeiWang1999 LeiWang1999 self-requested a review January 29, 2026 10:32
@LeiWang1999
Member

I’ll help and take a look. One helpful trick for debugging is to replace T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype) with T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype, scope="shared") to use static shared memory, which gives more readable codegen.

@Da1sypetals
Author

Da1sypetals commented Jan 29, 2026

Use T.sync_threads() after reductions to prevent read-after-write races.

It seems the LLM's suggestion is correct; I got correct results after inserting T.sync_threads().

Probably Triton inserts implicit synchronization in its reduce operation triton.language.sum, and I was not aware of the need to insert synchronizations explicitly in tile-level programming.

Is this a design choice to leave this to the user, or did I misunderstand something? @LeiWang1999

@LeiWang1999
Member

@Da1sypetals Yes, and I also found the sync_threads issue when inspecting the generated CUDA code via print(kernel.get_kernel_source()).

// 1. Perform reduction sum to calculate pAp_frag[0]
pAp_frag[0] = tl::AllReduce<tl::SumOp, 4, 1, 0, 128>::run_hopper(pAp_frag[0]);

// 2. Write reduction results to Shared Memory
// Note: Only a subset of threads (tid % 4 == 0) are responsible for writing
if ((((int)threadIdx.x) % 4) == 0) { 
  ((float*)buf_dyn_shmem)[((((int)threadIdx.x) >> 2) + 20640)] = pAp_frag[0]; 
} 

// === CRITICAL FLAW: Missing __syncthreads(); here ===
// There is no guarantee that the write operation is complete 
// or visible to all Warps at this point.

// 3. All threads immediately read from Shared Memory to calculate alpha
// This read operation is highly likely to retrieve:
// a) Stale data from the previous iteration
// b) Garbage data (incomplete writes), leading to division by zero or NaNs
for (int i_28 = 0; i_28 < 4; ++i_28) { 
  float alpha = (((float*)buf_dyn_shmem)[...] / (((float*)buf_dyn_shmem)[... + 20640] + epsilon)); 
  // ... Update x1, r1, etc.
}

This is indeed a bug. The previous algorithm invoked __syncthreads() inside the if block, which caused program hangs. We refactored the code in PR #1631 to remove that synchronization and fix the deadlock. However, removing it entirely was incorrect for this case. We need to improve the pass to inject __syncthreads() outside the if block instead.

if ((((int)threadIdx.x) % 4) == 0) { 
  ((float*)buf_dyn_shmem)[((((int)threadIdx.x) >> 2) + 20640)] = pAp_frag[0]; 
  // previously, a __syncthreads() here would lead to a hang
} 

Two solutions:

1. Use a fragment as the reduce output instead of shared memory:

r_new_normsq = T.alloc_shared([tilesize], dtype=dtype)
->
r_new_normsq = T.alloc_fragment([tilesize], dtype=dtype)

pAp = T.alloc_shared([tilesize], dtype=dtype)
->
pAp = T.alloc_fragment([tilesize], dtype=dtype)

2. Directly invoke T.sync_threads() after dot.

We will have a fix asap, and sorry for the trouble.

@LeiWang1999
Member

I made a fix at #1760, thanks for pointing out the issue!

@Da1sypetals Da1sypetals reopened this Jan 30, 2026
@Da1sypetals
Author

@LeiWang1999 Thanks! Is it correct that after the fix, it will no longer be required to manually call T.sync_threads() in the kernel? Shall I remove those then?
