
Conversation


@cscyuge cscyuge commented Jan 12, 2026

This PR is related to #1633. It adds a CUDA Graph backend to the profiler and allows the AutoTuner to specify the profiling backend.

Changes

1. Profiler CUDA Graph Backend

  • tilelang/profiler/bench.py:

    • Added _bench_with_cudagraph() function implementing CUDA Graph-based benchmarking (a simplified sketch follows this list)
    • Updated do_bench() to support "cudagraph" backend option
    • Implementation follows triton.testing.do_bench_cudagraph pattern
  • tilelang/profiler/__init__.py:

    • Extended backend parameter to include "cudagraph" option
    • Fixed default values for n_warmup and n_repeat (changed from 1 to 0)

2. AutoTuner Backend Selection

  • tilelang/autotuner/param.py:

    • Added backend parameter to ProfileArgs class with type Literal["event", "cupti", "cudagraph"]
  • tilelang/autotuner/tuner.py:

    • Added backend parameter to set_profile_args() method
    • Updated _profile() to use specified backend for both kernel and reference program benchmarking

3. Example Updates

  • examples/gemm/example_gemm_autotune.py:
    • Added profile_backend parameter to get_best_config() and main()
    • Added --profile_backend CLI argument
    • Updated benchmark calls to use specified backend
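
Under the hood, the cudagraph path follows the capture-and-replay pattern of triton.testing.do_bench_cudagraph: the function is unrolled n_repeat times into a single CUDA graph, and the graph is replayed several times while the replays are timed with CUDA events. The sketch below is a simplified illustration under that assumption (names and buffer sizes are illustrative, not the exact tilelang/profiler/bench.py code):

import torch

def bench_with_cudagraph_sketch(fn, n_repeat: int = 20, n_retries: int = 10) -> float:
    # Large buffer whose writes evict L2 between graph replays (size illustrative).
    cache = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda")
    with torch.cuda.stream(torch.cuda.Stream()):
        fn()  # warm up once outside the graph before capture
        # Capture n_repeat unrolled calls into a single CUDA graph.
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            for _ in range(n_repeat):
                fn()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for _ in range(n_retries):
            cache.zero_()   # clear L2 once per replay, not once per call
            start.record()
            g.replay()      # replays all n_repeat captured calls
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end) / n_repeat)  # ms per call
    return sum(times) / len(times)

Because L2 is cleared once per replay rather than before every call, cache-sensitive kernels can report lower latencies than with the event backend (see the review discussion below).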

Usage

# AutoTuner with CUDA Graph backend
autotuner.set_profile_args(backend="cudagraph", warmup=3, rep=20)

# Profiler with CUDA Graph backend
profiler.do_bench(backend="cudagraph")

Bench Results

# python examples/gemm/example_gemm_autotune.py --m 1 --n 1024 --k 1024 --use_autotune --profile_backend cudagraph
{'block_M': 16, 'block_N': 32, 'block_K': 128, 'num_stages': 3, 'thread_num': 128, 'enable_rasteration': False}
TileLang latency: 0.0037138615734875202
Ref latency: 0.005500588566064835
TileLang TFlops: 0.5646823282189969
Ref TFlops: 0.38125956428337626

# python examples/gemm/example_gemm_autotune.py --m 1 --n 1024 --k 1024 --use_autotune --profile_backend cupti
{'block_M': 16, 'block_N': 32, 'block_K': 128, 'num_stages': 3, 'thread_num': 128, 'enable_rasteration': False}
TileLang latency: 0.0053819006309147
Ref latency: 0.006137294351630924
TileLang TFlops: 0.38966754383266483
Ref TFlops: 0.3417062763891556

# python examples/gemm/example_gemm_autotune.py --m 1 --n 1024 --k 1024 --use_autotune --profile_backend event
{'block_M': 32, 'block_N': 16, 'block_K': 128, 'num_stages': 3, 'thread_num': 128, 'enable_rasteration': False}
TileLang latency: 0.008737378753721714
Ref latency: 0.009419833309948444
TileLang TFlops: 0.24002072693789445
Ref TFlops: 0.22263154038884772

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Jan 12, 2026

📝 Walkthrough

This PR introduces a profile_backend parameter throughout the profiling infrastructure to support multiple benchmarking backends, including a newly added CUDA graph-based benchmarking option ("cudagraph"). The parameter is threaded through the autotuner, example utilities, and profiler components. Additionally, block size configurations in the GEMM example are made M-dependent, and the profiler's default warmup/repeat values are adjusted.

Changes

  • Example and CLI Updates (examples/gemm/example_gemm_autotune.py): Added profile_backend parameter to get_best_config() and main() APIs; threaded through autotuner and benchmarking calls; extended CLI with --profile_backend argument; made block size configurations M-dependent (larger block sizes when M > 32).
  • Profiling Parameter Storage (tilelang/autotuner/param.py): Added backend field to the ProfileArgs dataclass with type Literal["event", "cupti", "cudagraph"] and default value "event"; the field is included in the hash computation.
  • Autotuner Backend Propagation (tilelang/autotuner/tuner.py): Added backend parameter to the AutoTuner.set_profile_args() method; propagated the backend value into ProfileArgs construction and through to profiler.do_bench() calls for both profiled and reference latency benchmarking.
  • Profiler Public API Updates (tilelang/profiler/__init__.py): Updated the Profiler.do_bench() signature to accept "cudagraph" alongside the existing "event" and "cupti" backend options; changed the default values of n_warmup and n_repeat from 1 to 0; updated the docstring parameter documentation.
  • CUDA Graph Benchmarking Implementation (tilelang/profiler/bench.py): Added a new _bench_with_cudagraph() function to capture and replay CUDA graphs with L2 cache clearing between replays; integrated the cudagraph path into the do_bench() control flow with conditional dispatch on the backend parameter; extended the docstring to document the new backend option.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • [Profiler]Adds CUPTI profiler support #936: Both PRs modify profiling backend infrastructure and signatures to propagate alternative profiler options (CUPTI in referenced PR; CUPTI and CUDA graph backends in this PR), updating the do_bench signature and profiling parameter flow through autotuning components.

Suggested reviewers

  • LeiWang1999

Poem

🐰 Graphs we capture, replayed with care,
Backends now chosen with customized flair,
CUDA events dance with cudagraph's grace,
Profiling faster through this new space!

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 53.85%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title accurately summarizes the main change: adding cudagraph backend support to the profiler. This is the primary feature introduced across multiple files (bench.py, __init__.py, param.py, tuner.py, and the example).

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @tilelang/profiler/__init__.py:
- Around line 214-217: The parameter `input_tensors` in the profiler signature
is annotated as `list[torch.Tensor] = None` which implicitly uses None without
Optional; change the annotation to explicitly allow None (e.g.
`Optional[list[torch.Tensor]] = None` or `list[torch.Tensor] | None = None`) and
add the necessary import (`from typing import Optional`) if using Optional;
update the `__init__` (or the function) signature where `input_tensors` is
declared to use the explicit optional type.
🧹 Nitpick comments (2)
tilelang/profiler/bench.py (1)

209-259: CUDA graph benchmarking implementation - cache flushing semantics differ from event backend.

The implementation looks functional, but there's a semantic difference from the event backend that could affect timing accuracy:

  • Event backend: Clears L2 cache before each fn() call (line 154)
  • Cudagraph backend: Clears L2 cache once before replaying a graph containing n_repeat iterations (line 240)

This means cudagraph measurements may include cache hits for iterations 2 through n_repeat, potentially reporting faster times than the event backend for cache-sensitive kernels.

Additionally, consider:

  1. n_retries=10 is hardcoded; it could be exposed as a parameter, consistent with the other configurable options
  2. No error handling for graph capture failures (some operations aren't CUDA graph-capturable)
Optional: Add error handling for graph capture
 def _bench_with_cudagraph(
     fn: Callable,
     cache: torch.Tensor,
     n_repeat: int,
     quantiles: list[float] | None,
     return_mode: str,
 ) -> float | list[float]:
     """Benchmark using CUDA graph for minimal launch overhead.
     ...
     """
-    with torch.cuda.stream(torch.cuda.Stream()):
-        # Construct a CUDA graph with n_repeat unrolled function calls
-        g = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(g):
-            for _ in range(n_repeat):
-                fn()
+    try:
+        with torch.cuda.stream(torch.cuda.Stream()):
+            # Construct a CUDA graph with n_repeat unrolled function calls
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                for _ in range(n_repeat):
+                    fn()
+    except Exception as e:
+        raise RuntimeError(
+            f"CUDA graph capture failed. The function may contain operations "
+            f"that are not graph-capturable. Consider using 'event' backend. "
+            f"Original error: {e}"
+        ) from e
examples/gemm/example_gemm_autotune.py (1)

248-248: Consider validating the --profile_backend CLI argument.

The CLI argument accepts any string, but only "event", "cupti", and "cudagraph" are valid backends. Invalid values will cause a runtime error deep in the profiler.

Add choices validation
-    parser.add_argument("--profile_backend", type=str, default="event", help="Profiler backend")
+    parser.add_argument("--profile_backend", type=str, default="event", choices=["event", "cupti", "cudagraph"], help="Profiler backend")
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fd260e3 and ca4de0d.

📒 Files selected for processing (5)
  • examples/gemm/example_gemm_autotune.py
  • tilelang/autotuner/param.py
  • tilelang/autotuner/tuner.py
  • tilelang/profiler/__init__.py
  • tilelang/profiler/bench.py
🧰 Additional context used
🧬 Code graph analysis (2)
examples/gemm/example_gemm_autotune.py (2)
tilelang/profiler/__init__.py (1)
  • do_bench (209-269)
tilelang/profiler/bench.py (1)
  • do_bench (64-137)
tilelang/autotuner/tuner.py (2)
tilelang/profiler/__init__.py (1)
  • do_bench (209-269)
tilelang/profiler/bench.py (1)
  • do_bench (64-137)
🪛 Ruff (0.14.10)
tilelang/profiler/__init__.py

216-216: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

🔇 Additional comments (7)
tilelang/autotuner/param.py (1)

108-108: LGTM! Backend field properly added to ProfileArgs.

The backend field is correctly added with appropriate type annotation and default value, and properly included in the hash computation to ensure cache invalidation when the profiling backend changes.

Also applies to: 119-131
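
As an illustration of that point, a minimal sketch of how such a field can participate in the cache key (the surrounding fields and the method name are assumed for illustration, not taken from tilelang/autotuner/param.py):

from dataclasses import dataclass
from typing import Literal

@dataclass
class ProfileArgsSketch:
    warmup: int = 25
    rep: int = 100
    backend: Literal["event", "cupti", "cudagraph"] = "event"

    def cache_key(self) -> int:
        # Hashing `backend` alongside the other fields means results tuned
        # under one profiling backend are not reused for another.
        return hash((self.warmup, self.rep, self.backend))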

tilelang/profiler/__init__.py (1)

242-251: Backend parameter correctly propagated.

The backend parameter is properly threaded through to the underlying do_bench call from tilelang.profiler.bench.

tilelang/autotuner/tuner.py (2)

454-460: Potential inconsistency in benchmarking parameters between kernel and reference.

The kernel benchmarking at line 454 uses warmup=warmup, rep=rep (time-based auto-calculation of iterations), while the reference benchmarking at lines 458-459 uses n_warmup=warmup, n_repeat=rep (fixed iteration counts). This means with warmup=3, rep=20:

  • Kernel: Aims for ~3ms warmup time, ~20ms total benchmark time
  • Reference: Runs exactly 3 warmup iterations and 20 benchmark iterations

This asymmetry may produce inconsistent timing comparisons between the kernel and its reference implementation.

Was this intentional? If both should use the same benchmarking strategy, consider aligning the parameter usage:

Option: Use consistent time-based approach for both
             latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors, backend=backend)

             if self.ref_latency_cache is None and ref_prog is not None:
                 self.ref_input_tensors = ref_input_tensors_supply()
                 self.ref_latency_cache = profiler.do_bench(
-                    ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors, backend=backend
+                    ref_prog, warmup=warmup, rep=rep, input_tensors=self.ref_input_tensors, backend=backend
                 )

211-212: Backend parameter properly added to set_profile_args.

The backend parameter is correctly added to the method signature and propagated to ProfileArgs construction.

Also applies to: 252-252

tilelang/profiler/bench.py (1)

134-135: Backend dispatch correctly added for cudagraph.

The dispatch logic properly routes to _bench_with_cudagraph when backend == "cudagraph".

examples/gemm/example_gemm_autotune.py (2)

79-81: Verify: block_N depends on M instead of N.

Both block_M and block_N use the condition M > 32. While this might be intentional (using M as a proxy for "small matrix"), it could also be an oversight where block_N should depend on N:

block_M = [64, 128, 256] if M > 32 else [16, 32]
block_N = [64, 128, 256] if M > 32 else [16, 32]  # Should this check N instead?

If this is intentional (e.g., for small square-ish matrices), consider adding a comment to clarify the reasoning.


110-162: Backend parameter threading looks correct throughout the example.

The profile_backend parameter is properly threaded through get_best_config, main, and both do_bench calls for consistent profiling behavior.

Also applies to: 211-226

Comment on lines +214 to +217
n_warmup: int = 0,
n_repeat: int = 0,
input_tensors: list[torch.Tensor] = None,
- backend: Literal["event", "cupti"] = "event",
+ backend: Literal["event", "cupti", "cudagraph"] = "event",
coderabbitai bot (Contributor):

⚠️ Potential issue | 🟡 Minor

Type annotation issue: implicit Optional for input_tensors.

As flagged by static analysis, the input_tensors parameter has None as default but lacks Optional in the type hint. PEP 484 prohibits implicit Optional.

🔧 Suggested fix
     def do_bench(
         self,
         func: Callable | None = None,
         warmup: int = 25,
         rep: int = 100,
         n_warmup: int = 0,
         n_repeat: int = 0,
-        input_tensors: list[torch.Tensor] = None,
+        input_tensors: list[torch.Tensor] | None = None,
         backend: Literal["event", "cupti", "cudagraph"] = "event",
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
- n_warmup: int = 0,
- n_repeat: int = 0,
- input_tensors: list[torch.Tensor] = None,
- backend: Literal["event", "cupti"] = "event",
- backend: Literal["event", "cupti", "cudagraph"] = "event",
+ n_warmup: int = 0,
+ n_repeat: int = 0,
+ input_tensors: list[torch.Tensor] | None = None,
+ backend: Literal["event", "cupti", "cudagraph"] = "event",
🧰 Tools
🪛 Ruff (0.14.10)

216-216: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)



@LeiWang1999 LeiWang1999 left a comment


Thanks for your contribution! Left some comments :)

  rep: int = 100,
- n_warmup: int = 1,
- n_repeat: int = 1,
+ n_warmup: int = 0,
LeiWang1999 (Member):

Why do we need to change the default value from 1 to 0?

cscyuge (Author):

For the "cudagraph" backend, n_repeat = 0 avoids measuring a single-iteration graph where launch overhead is still present.
We also align n_repeat / n_warmup with bench.py::do_bench so the default behavior is consistent.
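
For context, the time-based behavior being aligned with works roughly like the sketch below, in the spirit of triton.testing.do_bench: when the explicit counts are 0, the iteration counts are derived from the warmup/rep time budgets instead (a hedged sketch under that assumption; the exact bench.py logic may differ):

import torch

def resolve_iterations(fn, warmup_ms: float, rep_ms: float,
                       n_warmup: int, n_repeat: int) -> tuple[int, int]:
    if n_warmup > 0 and n_repeat > 0:
        return n_warmup, n_repeat  # explicit counts take precedence
    # Estimate per-call runtime from a few timed calls, then size the loops
    # to roughly fill the requested warmup/rep time budgets (milliseconds).
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(5):
        fn()
    end.record()
    torch.cuda.synchronize()
    estimate_ms = start.elapsed_time(end) / 5
    return (max(1, int(warmup_ms / estimate_ms)),
            max(1, int(rep_ms / estimate_ms)))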

return kernel_time_us * 1e-3 # Convert microseconds to milliseconds


def _bench_with_cudagraph(
LeiWang1999 (Member):

Just want to better understand the advantages of using CUDA Graphs for benchmarking. While I understand it likely reduces CPU overhead and kernel launch times, could we not simply use the CUPTI profiler if we only need to measure the kernel execution time?

cscyuge (Author):

You’re right that, functionally, CUPTI could be used to measure kernel execution time here. The main reason we prefer CUDA Graphs is that they better reflect inference-style execution, where the same kernels are executed repeatedly in a steady-state loop rather than as isolated launches.

This distinction can also matter for autotuning, since different measurement contexts may bias the tuning process differently. While we haven’t observed divergent optimal configurations in our current tests, using CUDA Graphs helps ensure that the benchmarking setup is aligned with the intended inference execution pattern.
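
To make the steady-state point concrete, this is the standard PyTorch capture/replay pattern for inference-style execution (illustrative only, assuming static shapes; not code from this PR):

import torch

def make_graphed_step(model, example_input):
    static_in = example_input.clone()
    # Warm up on a side stream before capture, as graph capture requires.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = model(static_in)

    def step(x):
        static_in.copy_(x)  # write new inputs into the captured input buffer
        g.replay()          # relaunch the whole captured kernel sequence
        return static_out

    return step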
