diff --git a/.github/workflows/integration_tests_inference_experimental_gpu.yml b/.github/workflows/integration_tests_inference_experimental_gpu.yml index af092f41f8..328b583061 100644 --- a/.github/workflows/integration_tests_inference_experimental_gpu.yml +++ b/.github/workflows/integration_tests_inference_experimental_gpu.yml @@ -27,17 +27,18 @@ on: jobs: integration-tests-inference-models-gpu: - name: ${{ matrix.extras.marker }}:${{ matrix.python-version }} + name: ${{ matrix.extras.marker }}:${{ matrix.python-version }}:cuda-graphs:${{ matrix.extras.enable_auto_cuda_graphs_for_trt }} runs-on: Roboflow-GPU-VM-Runner timeout-minutes: 30 strategy: matrix: python-version: ["3.12"] extras: - - { install: "onnx-cu12,mediapipe", marker: "onnx_extras", workers: "auto" } - - { install: "trt10", marker: "trt_extras", workers: "auto" } - - { install: "torch-cu124,mediapipe", marker: "torch_models", workers: "1" } - - { install: "torch-cu124", marker: "hf_vlm_models", workers: "1" } + - { install: "onnx-cu12,mediapipe", marker: "onnx_extras", workers: "auto", enable_auto_cuda_graphs_for_trt: "false" } + - { install: "trt10", marker: "trt_extras", workers: "auto", enable_auto_cuda_graphs_for_trt: "false" } + - { install: "trt10", marker: "trt_extras", workers: "auto", enable_auto_cuda_graphs_for_trt: "true" } + - { install: "torch-cu124,mediapipe", marker: "torch_models", workers: "1", "enable_auto_cuda_graphs_for_trt": "false" } + - { install: "torch-cu124", marker: "hf_vlm_models", workers: "1", "enable_auto_cuda_graphs_for_trt": "false" } steps: - name: 🛎️ Checkout if: ${{ (github.event.inputs.extras == '' || github.event.inputs.extras == matrix.extras.marker) && (github.event.inputs.python_version == '' || github.event.inputs.python_version == matrix.python-version) }} @@ -107,4 +108,4 @@ jobs: timeout-minutes: 25 run: | source .venv/bin/activate - python -m pytest -n ${{ matrix.extras.workers }} -m "${{ matrix.extras.marker }} and not cpu_only" tests/integration_tests 
+ ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND=${{ matrix.extras.enable_auto_cuda_graphs_for_trt }} python -m pytest -n ${{ matrix.extras.workers }} -m "${{ matrix.extras.marker }} and not cpu_only" tests/integration_tests diff --git a/inference_models/development/profiling/profile_cudagraph_vram.py b/inference_models/development/profiling/profile_cudagraph_vram.py new file mode 100644 index 0000000000..d129fc38c1 --- /dev/null +++ b/inference_models/development/profiling/profile_cudagraph_vram.py @@ -0,0 +1,185 @@ +"""Profile GPU and CPU memory usage as CUDA graphs are cached and evicted. + +Loads yolov8n-640 as a TRT model with dynamic batch size, runs forward passes +with random batch sizes, and after each step records both GPU VRAM +(driver-level) and process CPU RSS. The cache capacity is smaller than the +number of distinct batch sizes, so eviction is exercised and memory usage +should plateau. + +Example invocation: + python profile_cudagraph_vram.py \ + --device cuda:0 \ + --num-steps 64 \ + --max-batch-size 16 \ + --cache-capacity 16 \ + --output vram_sequential.png + + python profile_cudagraph_vram.py \ + --device cuda:0 \ + --num-steps 64 \ + --max-batch-size 16 \ + --cache-capacity 16 \ + --shuffle \ + --output vram_shuffle.png + + python profile_cudagraph_vram.py \ + --device cuda:0 \ + --shuffle \ + --num-steps 64 \ + --max-batch-size 16 \ + --cache-capacity 8 \ + --output vram_shuffle_eviction.png + + python profile_cudagraph_vram.py \ + --device cuda:0 \ + --shuffle \ + --num-steps 64 \ + --max-batch-size 2 \ + --cache-capacity 2 \ + --output vram_two_batch_sizes.png +""" + +import argparse +import gc +import os +import random +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch + +from inference_models import AutoModel +from inference_models.models.common.trt import TRTCudaGraphCache + +MODEL_ID = "yolov8n-640" +MB = 1024 ** 2 + + +def gpu_used_bytes(device: torch.device) -> int: + free, total = 
torch.cuda.mem_get_info(device) + return total - free + + +def cpu_rss_bytes() -> int: + with open(f"/proc/{os.getpid()}/statm") as f: + pages = int(f.read().split()[1]) + return pages * os.sysconf("SC_PAGE_SIZE") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Profile GPU + CPU memory vs. number of cached CUDA graphs.", + ) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--max-batch-size", type=int, default=16) + parser.add_argument("--cache-capacity", type=int, default=8) + parser.add_argument("--num-steps", type=int, default=32) + parser.add_argument("--shuffle", action="store_true", help="Randomize batch size order instead of sequential cycling.") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--output", type=str, default=None) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + device = torch.device(args.device) + + rng = random.Random(args.seed) + + model = AutoModel.from_pretrained( + model_id_or_path=MODEL_ID, + device=device, + backend="trt", + batch_size=(1, args.max_batch_size), + cuda_graph_cache_capacity=args.cache_capacity, + ) + + image = (np.random.rand(640, 640, 3) * 255).astype(np.uint8) + single_preprocessed, _ = model.pre_process(image) + + model.forward(single_preprocessed, use_cuda_graph=False) + gc.collect() + torch.cuda.synchronize(device) + torch.cuda.empty_cache() + + baseline_gpu = gpu_used_bytes(device) + baseline_cpu = cpu_rss_bytes() + + model._trt_cuda_graph_cache = TRTCudaGraphCache( + capacity=args.cache_capacity, + ) + + if args.shuffle: + batch_size_sequence = [ + rng.randint(1, args.max_batch_size) for _ in range(args.num_steps) + ] + else: + all_sizes = list(range(1, args.max_batch_size + 1)) + batch_size_sequence = [ + all_sizes[i % len(all_sizes)] for i in range(args.num_steps) + ] + + batch_sizes = [] + cumulative_gpu_mb = [] + cumulative_cpu_mb = [] + + for i, bs in 
enumerate(batch_size_sequence): + batched = single_preprocessed.expand(bs, -1, -1, -1).contiguous() + output = model.forward(batched, use_cuda_graph=True) + del output + gc.collect() + torch.cuda.synchronize(device) + + gpu = gpu_used_bytes(device) + cpu = cpu_rss_bytes() + cache_size = len(model._trt_cuda_graph_cache.cache) + + batch_sizes.append(bs) + cumulative_gpu_mb.append((gpu - baseline_gpu) / MB) + cumulative_cpu_mb.append((cpu - baseline_cpu) / MB) + + print( + f"[{i + 1}/{args.num_steps}] bs={bs:>2d} | " + f"cache: {cache_size}/{args.cache_capacity} | " + f"GPU: {cumulative_gpu_mb[-1]:>7.1f} MB | " + f"CPU: {cumulative_cpu_mb[-1]:>7.1f} MB" + ) + + mode = "shuffle" if args.shuffle else "sequential" + autogenerated_name = f"vram_{MODEL_ID}_cap{args.cache_capacity}_{mode}.png" + output_path = Path(args.output) if args.output else Path(autogenerated_name) + + fig, ax = plt.subplots(figsize=(14, 6)) + fig.suptitle( + f"Memory vs. Step (cache capacity={args.cache_capacity}, " + f"batch sizes 1-{args.max_batch_size}) -- {MODEL_ID}", + fontsize=14, + ) + + steps = np.arange(len(batch_sizes)) + + ax.plot(steps, cumulative_gpu_mb, color="steelblue", marker=".", label="GPU VRAM") + ax.plot(steps, cumulative_cpu_mb, color="seagreen", marker=".", label="CPU RSS") + ax.set_ylabel("Memory above baseline (MB)") + ax.set_xlabel("Step") + for i, bs in enumerate(batch_sizes): + ax.annotate( + str(bs), (i, cumulative_gpu_mb[i]), + textcoords="offset points", xytext=(0, 6), + fontsize=6, ha="center", color="steelblue", + ) + ax.legend() + + plt.tight_layout() + fig.savefig(output_path, dpi=150) + print(f"\nPlot saved to {output_path}") + + print(f"\nFinal GPU VRAM above baseline: {cumulative_gpu_mb[-1]:.1f} MB") + print(f"Final CPU RSS above baseline: {cumulative_cpu_mb[-1]:.1f} MB") + print(f"Peak GPU VRAM above baseline: {max(cumulative_gpu_mb):.1f} MB") + print(f"Cache entries at end: {cache_size}/{args.cache_capacity}") + + +if __name__ == "__main__": + main() diff --git 
a/inference_models/development/profiling/profile_rfdetr_trt_cudagraphs.py b/inference_models/development/profiling/profile_rfdetr_trt_cudagraphs.py new file mode 100644 index 0000000000..2791e24e3e --- /dev/null +++ b/inference_models/development/profiling/profile_rfdetr_trt_cudagraphs.py @@ -0,0 +1,66 @@ +import os +import time + +import cv2 +import numpy as np +import torch +from tqdm import tqdm + +from inference_models import AutoModel +from inference_models.models.common.trt import TRTCudaGraphCache + +IMAGE_PATH = os.environ.get("IMAGE_PATH", None) +DEVICE = os.environ.get("DEVICE", "cuda:0") +CYCLES = int(os.environ.get("CYCLES", "10_000")) +WARMUP = int(os.environ.get("WARMUP", "50")) + + +def main() -> None: + + model = AutoModel.from_pretrained( + model_id_or_path="rfdetr-nano", device=torch.device(DEVICE), backend="trt" + ) + + if IMAGE_PATH is not None: + image = cv2.imread(IMAGE_PATH) + else: + image = (np.random.rand(224, 224, 3) * 255).astype(np.uint8) + + pre_processed, _ = model.pre_process(image) + + for _ in range(WARMUP): + model.forward(pre_processed, use_cuda_graph=False) + model.forward(pre_processed, use_cuda_graph=True) + + print("Timing without CUDA graphs...") + start = time.perf_counter() + for _ in range(CYCLES): + model.forward(pre_processed, use_cuda_graph=False) + baseline_fps = CYCLES / (time.perf_counter() - start) + + print("Timing with forced CUDA graph recapture each step...") + start = time.perf_counter() + for _ in range(100): # not using CYCLES here bc this is wayyyy slower than the non-graph or the replay modes + model._trt_cuda_graph_cache = TRTCudaGraphCache(capacity=16) + model.forward(pre_processed, use_cuda_graph=True) + + cudagraph_recapture_fps = 100 / (time.perf_counter() - start) + + print("Timing with CUDA graph caching and replaying...") + model.forward(pre_processed, use_cuda_graph=True) # initial capture + start = time.perf_counter() + for _ in range(CYCLES): + model.forward(pre_processed, use_cuda_graph=True) + 
cudagraph_replay_fps = CYCLES / (time.perf_counter() - start) + + print(f"\n{'='*50}") + print(f"Forward pass FPS (no CUDA graphs): {baseline_fps:.1f}") + print(f"Forward pass FPS (CUDA graphs recapture): {cudagraph_recapture_fps:.1f}") + print(f"Speed factor (recapture): {cudagraph_recapture_fps / baseline_fps:.2f}x") + print(f"Forward pass FPS (CUDA graphs replay): {cudagraph_replay_fps:.1f}") + print(f"Speed factor (replay): {cudagraph_replay_fps / baseline_fps:.2f}x") + print(f"{'='*50}") + + +if __name__ == "__main__": + main() diff --git a/inference_models/development/profiling/profile_yolov8_trt_cudagraphs.py b/inference_models/development/profiling/profile_yolov8_trt_cudagraphs.py new file mode 100644 index 0000000000..ebbe543a70 --- /dev/null +++ b/inference_models/development/profiling/profile_yolov8_trt_cudagraphs.py @@ -0,0 +1,94 @@ +import os +import time + +import numpy as np +import torch +from tqdm import tqdm + +from inference_models import AutoModel + +DEVICE = os.environ.get("DEVICE", "cuda:0") +CYCLES = int(os.environ.get("CYCLES", "10_000")) +WARMUP = int(os.environ.get("WARMUP", "50")) +RECAPTURE_CYCLES = int(os.environ.get("RECAPTURE_CYCLES", "100")) + +os.environ["USE_TRT_CUDA_GRAPHS"] = "True" + +BATCH_SIZES = [1, 2, 3] + + +def main() -> None: + + model = AutoModel.from_pretrained( + model_id_or_path="yolov8n-640", + device=torch.device(DEVICE), + backend="trt", + batch_size=(1, max(BATCH_SIZES)), + ) + + image = (np.random.rand(224, 224, 3) * 255).astype(np.uint8) + pre_processed_single, _ = model.pre_process(image) + + batches = { + bs: pre_processed_single.repeat(bs, 1, 1, 1) for bs in BATCH_SIZES + } + + # ── Warmup ────────────────────────────────────────────────────────── + for _ in range(WARMUP): + for batch in batches.values(): + model.forward(batch, use_cuda_graph=False) + model.forward(batch, use_cuda_graph=True) + + bs_label = "/".join(str(bs) for bs in BATCH_SIZES) + + # ── (1) Cycling batch sizes, no CUDA graphs 
───────────────────────── + print(f"Timing without CUDA graphs, cycling bs={bs_label}...") + torch.cuda.synchronize() + start = time.perf_counter() + for i in range(CYCLES): + batch = batches[BATCH_SIZES[i % len(BATCH_SIZES)]] + model.forward(batch, use_cuda_graph=False) + torch.cuda.synchronize() + baseline_fps = CYCLES / (time.perf_counter() - start) + + # ── (2) Cycling batch sizes, CUDA graphs with forced recapture ────── + print( + f"Timing with CUDA graph recapture every iteration, cycling bs={bs_label} " + f"({RECAPTURE_CYCLES} iters)..." + ) + torch.cuda.synchronize() + start = time.perf_counter() + for i in range(RECAPTURE_CYCLES): + model._trt_cuda_graph_cache.cache.clear() + batch = batches[BATCH_SIZES[i % len(BATCH_SIZES)]] + model.forward(batch, use_cuda_graph=True) + torch.cuda.synchronize() + recapture_fps = RECAPTURE_CYCLES / (time.perf_counter() - start) + + # ── (3) Cycling batch sizes, CUDA graphs with normal caching ──────── + model._trt_cuda_graph_cache.cache.clear() + for batch in batches.values(): + model.forward(batch, use_cuda_graph=True) + + print(f"Timing with CUDA graph cache replay, cycling bs={bs_label}...") + torch.cuda.synchronize() + start = time.perf_counter() + for i in range(CYCLES): + batch = batches[BATCH_SIZES[i % len(BATCH_SIZES)]] + model.forward(batch, use_cuda_graph=True) + torch.cuda.synchronize() + replay_fps = CYCLES / (time.perf_counter() - start) + + # ── Results ───────────────────────────────────────────────────────── + print(f"\n{'='*60}") + print(f" yolov8n-640 TRT — cycling batch sizes {BATCH_SIZES}") + print(f" {CYCLES} iterations (recapture: {RECAPTURE_CYCLES})") + print(f"{'='*60}") + print(f" No CUDA graphs: {baseline_fps:>8.1f} fwd/s") + print(f" CUDA graph recapture: {recapture_fps:>8.1f} fwd/s ({recapture_fps / baseline_fps:.2f}x)") + print(f" CUDA graph replay: {replay_fps:>8.1f} fwd/s ({replay_fps / baseline_fps:.2f}x)") + print(f"{'='*60}") + + +if __name__ == "__main__": + main() diff --git 
a/inference_models/docs/api-reference/developer-tools/trt/establish-trt-cuda-graph-cache.md b/inference_models/docs/api-reference/developer-tools/trt/establish-trt-cuda-graph-cache.md new file mode 100644 index 0000000000..3442d233ac --- /dev/null +++ b/inference_models/docs/api-reference/developer-tools/trt/establish-trt-cuda-graph-cache.md @@ -0,0 +1,6 @@ +# establish_trt_cuda_graph_cache + +::: inference_models.models.common.trt.establish_trt_cuda_graph_cache + options: + show_root_heading: true + show_source: false diff --git a/inference_models/docs/api-reference/developer-tools/trt/get-trt-engine-inputs-and-outputs.md b/inference_models/docs/api-reference/developer-tools/trt/get-trt-engine-inputs-and-outputs.md index 98179cf56c..301102ca68 100644 --- a/inference_models/docs/api-reference/developer-tools/trt/get-trt-engine-inputs-and-outputs.md +++ b/inference_models/docs/api-reference/developer-tools/trt/get-trt-engine-inputs-and-outputs.md @@ -1,4 +1,4 @@ -2# get_trt_engine_inputs_and_outputs +# get_trt_engine_inputs_and_outputs ::: inference_models.models.common.trt.get_trt_engine_inputs_and_outputs options: diff --git a/inference_models/docs/api-reference/developer-tools/trt/trt-cuda-graph-cache.md b/inference_models/docs/api-reference/developer-tools/trt/trt-cuda-graph-cache.md new file mode 100644 index 0000000000..e074a3c063 --- /dev/null +++ b/inference_models/docs/api-reference/developer-tools/trt/trt-cuda-graph-cache.md @@ -0,0 +1,6 @@ +# TRTCudaGraphCache + +::: inference_models.models.common.trt.TRTCudaGraphCache + options: + show_root_heading: true + show_source: false diff --git a/inference_models/docs/changelog.md b/inference_models/docs/changelog.md index 26180bc47f..0ca5d597b7 100644 --- a/inference_models/docs/changelog.md +++ b/inference_models/docs/changelog.md @@ -1,5 +1,14 @@ # Changelog +## `0.21.0` +### Added + +- Support for CUDA Graphs in TRT backend - all TRT models got upgraded - added ability to run with CUDA graphs, at +the expense 
of additional VRAM allocation, but with caller control on how many execution contexts for different +input shapes should be allowed. + +--- + ## `0.20.2` ### Added diff --git a/inference_models/docs/how-to/use-cuda-graphs.md b/inference_models/docs/how-to/use-cuda-graphs.md new file mode 100644 index 0000000000..58e1eacd8e --- /dev/null +++ b/inference_models/docs/how-to/use-cuda-graphs.md @@ -0,0 +1,210 @@ +# Using CUDA Graphs with TensorRT Models + +CUDA graphs capture a sequence of GPU operations and replay them as a single unit, eliminating per-call +CPU overhead. For TensorRT models in `inference_models`, this translates to a **7–12% FPS improvement** +on repeated inference with the same input shape. + +## Overview + +When CUDA graphs are enabled, the first `forward()` call for a given input shape captures the TensorRT +execution into a CUDA graph. Subsequent calls with the same shape replay the captured graph instead of +re-launching individual GPU kernels. Captured graphs are stored in an LRU cache keyed by +`(shape, dtype, device)`. + +CUDA graphs work with all TRT model classes that use `infer_from_trt_engine` — including object detection, +instance segmentation, keypoint detection, classification, and semantic segmentation models. + +## Prerequisites + +- A CUDA-capable GPU +- TensorRT installed (brought in by `trt-*` extras of `inference-models`) +- A TRT model package (`.plan` engine file) + +## Quick Start + +The simplest way to enable CUDA graphs is through the `ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND` environment +variable: + +```bash +export ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND=True +``` + +With this set, all TRT models loaded via `AutoModel.from_pretrained` will automatically create a CUDA +graph cache and use it during inference. No code changes required. 
+ +```python +import torch +from inference_models import AutoModel + +model = AutoModel.from_pretrained( + model_id_or_path="rfdetr-nano", + device=torch.device("cuda:0"), + backend="trt", +) + +# First call captures the CUDA graph for this input shape +results = model.predict(image) + +# Subsequent calls replay the captured graph — faster +results = model.predict(image) +``` + +## Manual Cache Control + +For more control over cache behavior, create a `TRTCudaGraphCache` explicitly and pass it +to `AutoModel.from_pretrained`: + +```python +import torch +from inference_models import AutoModel +from inference_models.developer_tools import TRTCudaGraphCache + +cache = TRTCudaGraphCache(capacity=16) + +model = AutoModel.from_pretrained( + model_id_or_path="rfdetr-nano", + device=torch.device("cuda:0"), + backend="trt", + trt_cuda_graph_cache=cache, +) +``` + +The `capacity` parameter controls how many distinct input shapes can be cached simultaneously. +When the cache is full, the least recently used graph is evicted automatically. + +### Inspecting the Cache + +You can query the cache at any time to see what's been captured: + +```python +# Check how many graphs are currently cached +print(cache.get_current_size()) # e.g. 3 + +# List all cached keys — each key is a (shape, dtype, device) tuple +for key in cache.list_keys(): + shape, dtype, device = key + print(f" shape={shape}, dtype={dtype}, device={device}") + +# Check if a specific shape is cached +key = ((1, 3, 384, 384), torch.float16, torch.device("cuda:0")) +if key in cache: + print("Graph is cached for this shape") +``` + +### Removing Specific Entries + +Use `safe_remove()` to evict a single cached graph by its key. This releases the associated +CUDA graph, execution context, and GPU buffers immediately. 
If the key doesn't exist, the +call is a no-op: + +```python +key = ((1, 3, 384, 384), torch.float16, torch.device("cuda:0")) +cache.safe_remove(key) +``` + +### Purging the Cache + +Use `purge()` to evict multiple entries at once. When called without arguments, it clears the +entire cache. You can also pass `n_oldest` to evict only the N least recently used entries: + +```python +# Evict the 4 oldest (least recently used) entries +cache.purge(n_oldest=4) + +# Clear the entire cache +cache.purge() +``` + +`purge()` is more efficient than calling `safe_remove()` in a loop because it batches the +GPU memory cleanup — `torch.cuda.empty_cache()` is called once at the end rather than after +each individual eviction. + +!!! tip "When to purge manually" + Manual purging is useful when you know the workload is about to change — for example, + switching from processing video at one resolution to another. Purging stale entries + frees VRAM for the new shapes before they're captured. + +### Sharing a Cache Across Models + +Please **do not share a single instance of `TRTCudaGraphCache`** across multiple models — the cache object is bound to +a specific model instance. + +### Choosing Cache Capacity + +Each cached graph holds its own TensorRT execution context and GPU memory buffers. A reasonable +default is **8–16 entries**. Consider: + +- **Fixed input shape** (e.g. always 1×3×640×640): `capacity=1` is sufficient. +- **Variable batch sizes** (e.g. batch 1–16): set capacity to the number of distinct batch sizes + you expect, or quantize to powers of two and set `capacity=4–5`. +- **Memory-constrained environments**: lower the capacity to reduce VRAM usage.
+ +## Disabling CUDA Graphs Per Call + +Even with a cache configured, you can bypass CUDA graphs for individual forward passes using the +`use_cuda_graph` flag: + +```python +pre_processed, meta = model.pre_process(image) + +# Standard path — uses CUDA graphs if cache is configured +output = model.forward(pre_processed) + +# Bypass CUDA graphs for this specific call +output = model.forward(pre_processed, use_cuda_graph=False) +``` + +This is useful for debugging, benchmarking, or when you need to compare graph vs. non-graph outputs. + + +## How It Works + +The lifecycle of a CUDA graph in `inference_models`: + +1. **Cache miss** — `infer_from_trt_engine` detects that no cached graph exists for the current + `(shape, dtype, device)` key. It creates a dedicated TensorRT execution context, allocates + input/output buffers, runs a warmup pass, then captures the execution into a `torch.cuda.CUDAGraph`. + The graph and its associated state are stored in the cache. + +2. **Cache hit** — On subsequent calls with the same key, the cached graph's input buffer is updated + via `copy_()`, the graph is replayed, and output buffers are cloned and returned. No TensorRT + context setup or kernel launches happen on the CPU side. + +3. **Eviction** — When the cache exceeds its capacity, the least recently used entry is evicted. + The associated CUDA graph, execution context, and GPU buffers are released, and + `torch.cuda.empty_cache()` is called to return memory to the CUDA driver. + + +## Important Considerations + +### VRAM Usage + +Each cache entry consumes GPU memory for input buffers, output buffers, and the TensorRT execution +context's internal workspace. With large models or high cache capacities, this can be significant. +Monitor VRAM usage when tuning `capacity`. + +### Thread Safety + +Cache entries and eviction may be managed from a separate thread than the one running the forward pass. +The cache state is synchronized with an internal thread lock.
+ +### Dynamic Batch Sizes + +CUDA graphs are shape-specific — a graph captured for batch size 4 cannot be replayed for batch size 8. +If your application uses variable batch sizes, each distinct size will trigger a separate graph capture. +The LRU cache handles this transparently, but be aware that frequent shape changes will cause cache +churn and recapture overhead. + +!!! tip "Quantize batch sizes for better cache utilization" + + If you control the batching logic, round batch sizes up to the nearest power of two + (1, 2, 4, 8, 16). This reduces the number of distinct shapes and keeps the cache small. + +### When CUDA Graphs Won't Help + +- **Cold start / single inference**: The first call for each shape pays the capture cost, which is + slower than a normal forward pass. CUDA graphs only pay off on subsequent replays. +- **Highly variable input shapes**: If every call has a unique shape, graphs are captured but + never replayed. +- **CPU-bound pipelines**: If your bottleneck is preprocessing or postprocessing, the GPU-side + speedup from graph replay won't be visible end-to-end. diff --git a/inference_models/inference_models/configuration.py b/inference_models/inference_models/configuration.py index c123c1c131..30e34a0c67 100644 --- a/inference_models/inference_models/configuration.py +++ b/inference_models/inference_models/configuration.py @@ -377,3 +377,8 @@ variable_name="INFERENCE_MODELS_YOLOLITE_DEFAULT_CLASS_AGNOSTIC_NMS", default=INFERENCE_MODELS_DEFAULT_CLASS_AGNOSTIC_NMS, ) + +ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND_ENV_NAME = ( + "ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND" +) +DEFAULT_ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND = False diff --git a/inference_models/inference_models/developer_tools.py b/inference_models/inference_models/developer_tools.py index bee4b5ca75..44cbd90da7 100644 --- a/inference_models/inference_models/developer_tools.py +++ b/inference_models/inference_models/developer_tools.py @@ -13,7 +13,7 @@ along with library. 
Utilities depending on optional dependencies are exposed as lazy imports. """ -from typing import Any, Dict +from typing import Any, Dict, Union from inference_models.models.common.model_packages import get_model_package_contents from inference_models.runtime_introspection.core import ( @@ -21,7 +21,7 @@ x_ray_runtime_environment, ) from inference_models.utils.download import download_files_to_directory -from inference_models.utils.imports import LazyFunction +from inference_models.utils.imports import LazyClass, LazyFunction from inference_models.utils.onnx_introspection import ( get_selected_onnx_execution_providers, ) @@ -42,7 +42,7 @@ TRTPackageDetails, ) -OPTIONAL_IMPORTS: Dict[str, LazyFunction] = { +OPTIONAL_IMPORTS: Dict[str, Union[LazyFunction, LazyClass]] = { "use_primary_cuda_context": LazyFunction( module_name="inference_models.models.common.cuda", function_name="use_primary_cuda_context", @@ -79,6 +79,14 @@ module_name="inference_models.models.common.trt", function_name="load_trt_model", ), + "establish_trt_cuda_graph_cache": LazyFunction( + module_name="inference_models.models.common.trt", + function_name="establish_trt_cuda_graph_cache", + ), + "TRTCudaGraphCache": LazyClass( + module_name="inference_models.models.common.trt", + class_name="TRTCudaGraphCache", + ), } diff --git a/inference_models/inference_models/models/common/trt.py b/inference_models/inference_models/models/common/trt.py index a13d44cac8..4a07f4c144 100644 --- a/inference_models/inference_models/models/common/trt.py +++ b/inference_models/inference_models/models/common/trt.py @@ -1,7 +1,14 @@ +import threading +from collections import OrderedDict +from dataclasses import dataclass from typing import List, Optional, Tuple import torch +from inference_models.configuration import ( + DEFAULT_ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND, + ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND_ENV_NAME, +) from inference_models.errors import ( CorruptedModelPackageError, MissingDependencyError, @@ -9,6 
+16,7 @@ ) from inference_models.logger import LOGGER from inference_models.models.common.roboflow.model_packages import TRTConfig +from inference_models.utils.environment import get_boolean_from_env try: import tensorrt as trt @@ -40,7 +48,6 @@ class InferenceTRTLogger(trt.ILogger): - def __init__(self, with_memory: bool = False): super().__init__() self._memory: List[Tuple[trt.ILogger.Severity, str]] = [] @@ -64,6 +71,299 @@ def get_memory(self) -> List[Tuple[trt.ILogger.Severity, str]]: return self._memory +@dataclass +class TRTCudaGraphState: + cuda_graph: torch.cuda.CUDAGraph + cuda_stream: torch.cuda.Stream + input_buffer: torch.Tensor + output_buffers: List[torch.Tensor] + execution_context: trt.IExecutionContext + + +class TRTCudaGraphCache: + + """LRU cache for captured CUDA graphs used in TensorRT inference. + + Stores captured ``torch.cuda.CUDAGraph`` objects keyed by input + ``(shape, dtype, device)`` tuples. When the cache exceeds its capacity, + the least recently used entry is evicted and its GPU resources are released. + + The cache is thread-safe — all mutating operations acquire an internal + ``threading.RLock``. + + Args: + capacity: Maximum number of CUDA graphs to store. Each entry holds + a dedicated TensorRT execution context and GPU memory buffers, + so higher values increase VRAM usage. + + Examples: + Create a cache and pass it to a model: + + >>> from inference_models.developer_tools import TRTCudaGraphCache + >>> from inference_models import AutoModel + >>> import torch + >>> + >>> cache = TRTCudaGraphCache(capacity=16) + >>> model = AutoModel.from_pretrained( + ... model_id_or_path="rfdetr-nano", + ... device=torch.device("cuda:0"), + ... backend="trt", + ... trt_cuda_graph_cache=cache, + ... 
) + + See Also: + - ``establish_trt_cuda_graph_cache()``: Factory that creates a cache + based on environment configuration + - ``infer_from_trt_engine()``: Uses the cache during TRT inference + """ + + def __init__(self, capacity: int): + self._cache: OrderedDict[ + Tuple[Tuple[int, ...], torch.dtype, torch.device], TRTCudaGraphState + ] = OrderedDict() + self._capacity = capacity + self._state_lock = threading.RLock() + + def get_current_size(self) -> int: + """Return the number of CUDA graphs currently stored in the cache. + + Returns: + Number of cached entries. + + Examples: + >>> cache = TRTCudaGraphCache(capacity=16) + >>> cache.get_current_size() + 0 + """ + with self._state_lock: + return len(self._cache) + + def list_keys(self) -> List[Tuple[Tuple[int, ...], torch.dtype, torch.device]]: + """Return a list of all keys currently in the cache. + + Each key is a ``(shape, dtype, device)`` tuple representing a cached + CUDA graph. Keys are returned in insertion order (oldest first), which + reflects eviction priority. + + Returns: + List of ``(shape, dtype, device)`` tuples for all cached entries. + + Examples: + >>> cache = TRTCudaGraphCache(capacity=16) + >>> # ... after some forward passes ... + >>> for shape, dtype, device in cache.list_keys(): + ... print(f"Cached: shape={shape}, dtype={dtype}") + """ + with self._state_lock: + return list(self._cache.keys()) + + def safe_remove( + self, key: Tuple[Tuple[int, ...], torch.dtype, torch.device] + ) -> None: + """Remove a single entry from the cache by its key. + + If the key exists, the associated CUDA graph, execution context, and + GPU buffers are released and ``torch.cuda.empty_cache()`` is called. + If the key does not exist, this method is a no-op. + + Args: + key: A ``(shape, dtype, device)`` tuple identifying the entry + to remove. 
+ + Examples: + Remove a cached graph for a specific input shape: + + >>> import torch + >>> key = ((1, 3, 384, 384), torch.float16, torch.device("cuda:0")) + >>> cache.safe_remove(key) + + Safe to call with a non-existent key: + + >>> cache.safe_remove(((99, 99), torch.float32, torch.device("cuda:0"))) + >>> # no error raised + + See Also: + - ``purge()``: Remove multiple entries at once with batched + GPU memory cleanup + """ + with self._state_lock: + if key not in self._cache: + return None + evicted = self._cache.pop(key) + self._evict(evicted=evicted) + return None + + def purge(self, n_oldest: Optional[int] = None) -> None: + """Remove entries from the cache, starting with the least recently used. + + When called without arguments, clears the entire cache. When + ``n_oldest`` is specified, only that many entries are evicted + (or all entries if the cache contains fewer). + + GPU memory cleanup (``torch.cuda.empty_cache()``) is called once + after all evictions, making this more efficient than calling + ``safe_remove()`` in a loop. + + Args: + n_oldest: Number of least recently used entries to evict. + When ``None`` (default), all entries are removed. 
+ + Examples: + Evict the 4 oldest entries: + + >>> cache.purge(n_oldest=4) + + Clear the entire cache: + + >>> cache.purge() + >>> cache.get_current_size() + 0 + + Note: + - Eviction order follows LRU policy — entries that haven't been + accessed recently are removed first + - Each evicted entry's CUDA graph, execution context, and GPU + buffers are released + + See Also: + - ``safe_remove()``: Remove a single entry by key + """ + with self._state_lock: + if n_oldest is None: + n_oldest = len(self._cache) + to_evict = min(len(self._cache), n_oldest) + for _ in range(to_evict): + _, evicted = self._cache.popitem(last=False) + self._evict(evicted=evicted, empty_cuda_cache=False) + torch.cuda.empty_cache() + + def __contains__( + self, key: Tuple[Tuple[int, ...], torch.dtype, torch.device] + ) -> bool: + with self._state_lock: + return key in self._cache + + def __getitem__( + self, key: Tuple[Tuple[int, ...], torch.dtype, torch.device] + ) -> TRTCudaGraphState: + with self._state_lock: + value = self._cache[key] + self._cache.move_to_end(key) + return value + + def __setitem__( + self, + key: Tuple[Tuple[int, ...], torch.dtype, torch.device], + value: TRTCudaGraphState, + ): + with self._state_lock: + self._cache[key] = value + self._cache.move_to_end(key) + if len(self._cache) > self._capacity: + _, evicted = self._cache.popitem(last=False) + self._evict(evicted=evicted) + + def _evict(self, evicted: TRTCudaGraphState, empty_cuda_cache: bool = True) -> None: + del evicted.cuda_graph + del evicted.input_buffer + del evicted.output_buffers + del evicted.execution_context + if empty_cuda_cache: + torch.cuda.empty_cache() + + +def establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size: int, + cuda_graph_cache: Optional[TRTCudaGraphCache] = None, +) -> Optional[TRTCudaGraphCache]: + """Establish a CUDA graph cache for TensorRT inference acceleration. + + Resolves which CUDA graph cache to use for a TRT model. 
If the caller + provides a cache instance, it is returned as-is. Otherwise, the function + checks the ``ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND`` environment variable + to decide whether to create a new cache automatically. When the environment + variable is disabled (the default), no cache is created and CUDA graphs + are not used. + + This function is typically called inside ``from_pretrained()`` of TRT model + classes. End users who want explicit control should create a + ``TRTCudaGraphCache`` themselves and pass it to ``AutoModel.from_pretrained``. + + Args: + default_cuda_graph_cache_size: Maximum number of CUDA graphs to cache + when a new cache is created automatically. Each entry holds a + dedicated TensorRT execution context and GPU memory buffers, so + higher values increase VRAM usage. + + cuda_graph_cache: Optional pre-existing cache instance. When provided, + it is returned directly and the environment variable is ignored. + This allows callers to share a single cache across multiple models + or to configure capacity explicitly. + + Returns: + A ``TRTCudaGraphCache`` instance if CUDA graphs should be used, or + ``None`` if they are disabled. When ``None`` is returned, the model + falls back to standard TensorRT execution without graph capture. + + Examples: + Automatic cache creation via environment variable: + + >>> import os + >>> os.environ["ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND"] = "True" + >>> + >>> from inference_models.developer_tools import ( + ... establish_trt_cuda_graph_cache, + ... ) + >>> + >>> cache = establish_trt_cuda_graph_cache(default_cuda_graph_cache_size=8) + >>> print(type(cache)) # + + Caller-provided cache takes priority: + + >>> from inference_models.models.common.trt import ( + ... TRTCudaGraphCache, + ... establish_trt_cuda_graph_cache, + ... ) + >>> + >>> my_cache = TRTCudaGraphCache(capacity=32) + >>> result = establish_trt_cuda_graph_cache( + ... default_cuda_graph_cache_size=8, + ... cuda_graph_cache=my_cache, + ... 
) + >>> assert result is my_cache # returned as-is + + Typical usage inside a model's from_pretrained: + + >>> cache = establish_trt_cuda_graph_cache( + ... default_cuda_graph_cache_size=8, + ... cuda_graph_cache=None, # let env var decide + ... ) + >>> # cache is None when env var is disabled (default) + + Note: + - The environment variable ``ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND`` + defaults to ``False`` + - When a caller-provided cache is given, the environment variable + is not checked + - CUDA graphs require TensorRT and a CUDA-capable GPU + - Each cached graph consumes VRAM proportional to the model's + execution context size + + See Also: + - ``TRTCudaGraphCache``: The LRU cache class for CUDA graph state + - ``infer_from_trt_engine()``: Uses the cache during TRT inference + """ + if cuda_graph_cache is not None: + return cuda_graph_cache + auto_cuda_graphs_enabled = get_boolean_from_env( + variable_name=ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND_ENV_NAME, + default=DEFAULT_ENABLE_AUTO_CUDA_GRAPHS_FOR_TRT_BACKEND, + ) + if not auto_cuda_graphs_enabled: + return None + return TRTCudaGraphCache(capacity=default_cuda_graph_cache_size) + + def get_trt_engine_inputs_and_outputs( engine: trt.ICudaEngine, ) -> Tuple[List[str], List[str]]: @@ -135,12 +435,20 @@ def infer_from_trt_engine( input_name: str, outputs: List[str], stream: Optional[torch.cuda.Stream] = None, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, ) -> List[torch.Tensor]: - """Run inference using a TensorRT engine. + """Run inference using a TensorRT engine, optionally with CUDA graph acceleration. + + Executes inference on preprocessed images using a TensorRT engine. Handles both + static and dynamic batch sizes, automatically splitting large batches if needed. - Executes inference on preprocessed images using a TensorRT engine and execution - context. Handles both static and dynamic batch sizes, automatically splitting - large batches if needed. 
+ When ``trt_cuda_graph_cache`` is provided, CUDA graphs are captured and replayed + for improved performance on repeated inference with the same input shape. Each + graph is keyed by (shape, dtype, device) and stored in the cache. The cache + itself must be created by the caller (typically in the model class). + + When ``trt_cuda_graph_cache`` is ``None``, inference runs through the standard + TRT execution path using the provided ``context``. Args: pre_processed_images: Preprocessed input tensor on CUDA device. @@ -151,22 +459,28 @@ def infer_from_trt_engine( engine: TensorRT CUDA engine (ICudaEngine) to use for inference. - context: TensorRT execution context (IExecutionContext) for running inference. - device: PyTorch CUDA device to use for inference. input_name: Name of the input tensor in the TensorRT engine. outputs: List of output tensor names to retrieve from the engine. - stream: CUDA stream to use for inference. + context: TensorRT execution context (IExecutionContext) for running inference. + Required when ``trt_cuda_graph_cache`` is ``None``. Ignored when using + CUDA graphs (each cached graph owns its own execution context). + + trt_cuda_graph_cache: Optional CUDA graph cache. When provided, CUDA graphs + are used for inference. When ``None``, standard TRT execution is used. + + stream: CUDA stream to use for inference. Defaults to the current stream + for the given device. Returns: List of output tensors from the TensorRT engine, in the order specified by the outputs parameter. Examples: - Run TensorRT inference: + Run TensorRT inference (standard path): >>> from inference_models.developer_tools import ( ... load_trt_model, @@ -197,7 +511,7 @@ def infer_from_trt_engine( ... context=context, ... device=torch.device("cuda:0"), ... input_name=inputs[0], - ... outputs=outputs + ... outputs=outputs, ... ) Handle large batches: @@ -212,10 +526,25 @@ def infer_from_trt_engine( ... context=context, ... device=torch.device("cuda:0"), ... 
input_name=inputs[0], - ... outputs=outputs + ... outputs=outputs, ... ) >>> # Results are automatically concatenated + Run with CUDA graph acceleration: + + >>> from inference_models.models.common.trt import TRTCudaGraphCache + >>> cache = TRTCudaGraphCache(capacity=16) + >>> + >>> results = infer_from_trt_engine( + ... pre_processed_images=images, + ... trt_config=trt_config, + ... engine=engine, + ... device=torch.device("cuda:0"), + ... input_name=inputs[0], + ... outputs=outputs, + ... trt_cuda_graph_cache=cache, + ... ) + Note: - Requires TensorRT and PyCUDA to be installed - Input must be on CUDA device @@ -241,6 +570,7 @@ def infer_from_trt_engine( device=device, input_name=input_name, outputs=outputs, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) stream.synchronize() return results @@ -254,18 +584,14 @@ def _infer_from_trt_engine( device: torch.device, input_name: str, outputs: List[str], + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, ) -> List[torch.Tensor]: if trt_config.static_batch_size is not None: - return _infer_from_trt_engine_with_batch_size_boundaries( - pre_processed_images=pre_processed_images, - engine=engine, - context=context, - device=device, - input_name=input_name, - outputs=outputs, - min_batch_size=trt_config.static_batch_size, - max_batch_size=trt_config.static_batch_size, - ) + min_batch_size = trt_config.static_batch_size + max_batch_size = trt_config.static_batch_size + else: + min_batch_size = trt_config.dynamic_batch_size_min + max_batch_size = trt_config.dynamic_batch_size_max return _infer_from_trt_engine_with_batch_size_boundaries( pre_processed_images=pre_processed_images, engine=engine, @@ -273,8 +599,9 @@ def _infer_from_trt_engine( device=device, input_name=input_name, outputs=outputs, - min_batch_size=trt_config.dynamic_batch_size_min, - max_batch_size=trt_config.dynamic_batch_size_max, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) @@ -287,6 
+614,7 @@ def _infer_from_trt_engine_with_batch_size_boundaries( outputs: List[str], min_batch_size: int, max_batch_size: int, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, ) -> List[torch.Tensor]: if pre_processed_images.shape[0] <= max_batch_size: reminder = min_batch_size - pre_processed_images.shape[0] @@ -309,6 +637,7 @@ def _infer_from_trt_engine_with_batch_size_boundaries( device=device, input_name=input_name, outputs=outputs, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) if reminder > 0: results = [r[:-reminder] for r in results] @@ -338,6 +667,7 @@ def _infer_from_trt_engine_with_batch_size_boundaries( device=device, input_name=input_name, outputs=outputs, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) if reminder > 0: results = [r[:-reminder] for r in results] @@ -353,39 +683,149 @@ def _execute_trt_engine( device: torch.device, input_name: str, outputs: List[str], + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, ) -> List[torch.Tensor]: - batch_size = pre_processed_images.shape[0] - results = [] - for output in outputs: - output_tensor_shape = engine.get_tensor_shape(output) - output_tensor_type = _trt_dtype_to_torch(engine.get_tensor_dtype(output)) - result = torch.empty( - (batch_size,) + output_tensor_shape[1:], - dtype=output_tensor_type, - device=device, - ) - context.set_tensor_address(output, result.data_ptr()) - results.append(result) - status = context.set_input_shape(input_name, tuple(pre_processed_images.shape)) + if trt_cuda_graph_cache is not None: + input_shape = tuple(pre_processed_images.shape) + input_dtype = pre_processed_images.dtype + cache_key = (input_shape, input_dtype, device) + + if cache_key not in trt_cuda_graph_cache: + LOGGER.debug("Capturing CUDA graph for shape %s", input_shape) + + results, trt_cuda_graph = _capture_cuda_graph( + pre_processed_images=pre_processed_images, + engine=engine, + device=device, + input_name=input_name, + outputs=outputs, + ) + trt_cuda_graph_cache[cache_key] = 
trt_cuda_graph + return results + + else: + trt_cuda_graph_state = trt_cuda_graph_cache[cache_key] + stream = trt_cuda_graph_state.cuda_stream + with torch.cuda.stream(stream): + trt_cuda_graph_state.input_buffer.copy_(pre_processed_images) + trt_cuda_graph_state.cuda_graph.replay() + results = [buf.clone() for buf in trt_cuda_graph_state.output_buffers] + stream.synchronize() + return results + + else: + status = context.set_input_shape(input_name, tuple(pre_processed_images.shape)) + if not status: + raise ModelRuntimeError( + message="Failed to set TRT model input shape during forward pass from the model.", + help_url="https://inference-models.roboflow.com/errors/models-runtime/#modelruntimeerror", + ) + status = context.set_tensor_address(input_name, pre_processed_images.data_ptr()) + if not status: + raise ModelRuntimeError( + message="Failed to set input tensor data pointer during forward pass from the model.", + help_url="https://inference-models.roboflow.com/errors/models-runtime/#modelruntimeerror", + ) + results = [] + for output in outputs: + output_tensor_shape = context.get_tensor_shape(output) + output_tensor_type = _trt_dtype_to_torch(engine.get_tensor_dtype(output)) + result = torch.empty( + tuple(output_tensor_shape), + dtype=output_tensor_type, + device=device, + ) + context.set_tensor_address(output, result.data_ptr()) + results.append(result) + stream = torch.cuda.current_stream(device) + status = context.execute_async_v3(stream_handle=stream.cuda_stream) + if not status: + raise ModelRuntimeError( + message="Failed to complete inference from TRT model", + help_url="https://inference-models.roboflow.com/errors/models-runtime/#modelruntimeerror", + ) + return results + + +def _capture_cuda_graph( + pre_processed_images: torch.Tensor, + engine: trt.ICudaEngine, + device: torch.device, + input_name: str, + outputs: List[str], +) -> Tuple[List[torch.Tensor], TRTCudaGraphState]: + # Each CUDA graph needs its own execution context. 
Sharing a single context + # across graphs for different input shapes causes TRT to reallocate internal + # workspace buffers, invalidating GPU addresses baked into earlier graphs. + graph_context = engine.create_execution_context() + + input_buffer = torch.empty_like(pre_processed_images, device=device) + input_buffer.copy_(pre_processed_images) + + status = graph_context.set_input_shape( + input_name, tuple(pre_processed_images.shape) + ) if not status: raise ModelRuntimeError( - message="Failed to set TRT model input shape during forward pass from the model.", + message="Failed to set TRT model input shape during CUDA graph capture.", help_url="https://inference-models.roboflow.com/errors/models-runtime/#modelruntimeerror", ) - status = context.set_tensor_address(input_name, pre_processed_images.data_ptr()) + status = graph_context.set_tensor_address(input_name, input_buffer.data_ptr()) if not status: raise ModelRuntimeError( - message="Failed to set input tensor data pointer during forward pass from the model.", + message="Failed to set input tensor data pointer during CUDA graph capture.", help_url="https://inference-models.roboflow.com/errors/models-runtime/#modelruntimeerror", ) - stream = torch.cuda.current_stream(device) - status = context.execute_async_v3(stream_handle=stream.cuda_stream) - if not status: - raise ModelRuntimeError( - message="Failed to complete inference from TRT model", - help_url="https://inference-models.roboflow.com/errors/models-runtime/#modelruntimeerror", + + output_buffers = [] + for output in outputs: + output_tensor_shape = graph_context.get_tensor_shape(output) + output_tensor_type = _trt_dtype_to_torch(engine.get_tensor_dtype(output)) + output_buffer = torch.empty( + tuple(output_tensor_shape), + dtype=output_tensor_type, + device=device, ) - return results + graph_context.set_tensor_address(output, output_buffer.data_ptr()) + output_buffers.append(output_buffer) + + stream = torch.cuda.Stream(device=device) + with 
torch.cuda.stream(stream): + status = graph_context.execute_async_v3(stream_handle=stream.cuda_stream) + if not status: + raise ModelRuntimeError( + message="Failed to execute TRT model warmup before CUDA graph capture.", + help_url="https://inference-models.roboflow.com/errors/models-runtime/#modelruntimeerror", + ) + stream.synchronize() + + cuda_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cuda_graph, stream=stream): + status = graph_context.execute_async_v3(stream_handle=stream.cuda_stream) + if not status: + raise ModelRuntimeError( + message="Failed to capture CUDA graph from TRT model execution.", + help_url="https://inference-models.roboflow.com/errors/models-runtime/#modelruntimeerror", + ) + with torch.cuda.stream(stream): + results = [buf.clone() for buf in output_buffers] + stream.synchronize() + + # in order to avoid drift of results - it's better to replay to get the results + with torch.cuda.stream(stream): + cuda_graph.replay() + results = [buf.clone() for buf in output_buffers] + stream.synchronize() + + trt_cuda_graph_state = TRTCudaGraphState( + cuda_graph=cuda_graph, + cuda_stream=stream, + input_buffer=input_buffer, + output_buffers=output_buffers, + execution_context=graph_context, + ) + + return results, trt_cuda_graph_state def _trt_dtype_to_torch(trt_dtype): diff --git a/inference_models/inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py b/inference_models/inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py index c807f4a641..2bc949760a 100644 --- a/inference_models/inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +++ b/inference_models/inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py @@ -38,6 +38,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -81,6 +83,8 @@ 
def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "DeepLabV3PlusForSemanticSegmentationTRT": if device.type != "cuda": @@ -146,6 +150,10 @@ def from_pretrained( message=f"Implementation assume single model output, found: {len(outputs)}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -157,6 +165,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -171,6 +180,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -182,6 +192,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._lock = Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -212,8 +223,10 @@ def pre_process( def forward( self, pre_processed_images: PreprocessedInputs, + disable_cuda_graphs: bool = False, **kwargs, ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._lock: with use_cuda_context(context=self._cuda_context): return infer_from_trt_engine( @@ -225,6 +238,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, )[0] def post_process( 
diff --git a/inference_models/inference_models/models/resnet/resnet_classification_trt.py b/inference_models/inference_models/models/resnet/resnet_classification_trt.py index e55a999515..8bad13c294 100644 --- a/inference_models/inference_models/models/resnet/resnet_classification_trt.py +++ b/inference_models/inference_models/models/resnet/resnet_classification_trt.py @@ -40,6 +40,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -81,6 +83,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "ResNetForClassificationTRT": if device.type != "cuda": @@ -147,6 +151,10 @@ def from_pretrained( message=f"Implementation assume single model output, found: {len(outputs)}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -157,6 +165,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -170,6 +179,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -180,6 +190,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._lock = Lock() self._inference_stream 
= torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -212,8 +223,10 @@ def pre_process( def forward( self, pre_processed_images: PreprocessedInputs, + disable_cuda_graphs: bool = False, **kwargs, ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._lock: with use_cuda_context(context=self._cuda_context): return infer_from_trt_engine( @@ -225,6 +238,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, )[0] def post_process( @@ -271,6 +285,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "ResNetForMultiLabelClassificationTRT": if device.type != "cuda": @@ -337,6 +353,10 @@ def from_pretrained( message=f"Implementation assume single model output, found: {len(outputs)}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -347,6 +367,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -360,6 +381,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -370,6 +392,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._lock = Lock() 
self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -402,8 +425,10 @@ def pre_process( def forward( self, pre_processed_images: PreprocessedInputs, + disable_cuda_graphs: bool = False, **kwargs, ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._lock: with use_cuda_context(context=self._cuda_context): return infer_from_trt_engine( @@ -415,6 +440,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, )[0] def post_process( diff --git a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py index b4da6efbe5..ebc59bfdf9 100644 --- a/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +++ b/inference_models/inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py @@ -34,6 +34,8 @@ parse_trt_config, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -85,13 +87,14 @@ class RFDetrForInstanceSegmentationTRT( Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ] ): - @classmethod def from_pretrained( cls, model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "RFDetrForInstanceSegmentationTRT": if device.type != "cuda": @@ -160,6 +163,10 @@ def from_pretrained( message=f"Implementation assume 3 model outputs, found: {len(outputs)}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + 
default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -171,6 +178,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -185,6 +193,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -196,6 +205,7 @@ def __init__( self._cuda_context = cuda_context self._execution_context = execution_context self._trt_config = trt_config + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._lock = threading.Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -228,8 +238,10 @@ def pre_process( def forward( self, pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._lock: with use_cuda_context(context=self._cuda_context): detections, labels, masks = infer_from_trt_engine( @@ -241,6 +253,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, ) return detections, labels, masks diff --git a/inference_models/inference_models/models/rfdetr/rfdetr_object_detection_trt.py b/inference_models/inference_models/models/rfdetr/rfdetr_object_detection_trt.py index 3ed19f32cb..cecc0e4c9d 100644 --- a/inference_models/inference_models/models/rfdetr/rfdetr_object_detection_trt.py +++ b/inference_models/inference_models/models/rfdetr/rfdetr_object_detection_trt.py @@ -33,6 +33,8 @@ rescale_image_detections, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + 
establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -78,13 +80,14 @@ class RFDetrForObjectDetectionTRT( ] ) ): - @classmethod def from_pretrained( cls, model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "RFDetrForObjectDetectionTRT": if device.type != "cuda": @@ -158,6 +161,10 @@ def from_pretrained( message=f"Expected model outputs to be named `output0` and `output1`, but found: {outputs}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -169,6 +176,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -183,6 +191,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -194,6 +203,7 @@ def __init__( self._cuda_context = cuda_context self._execution_context = execution_context self._trt_config = trt_config + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._lock = threading.Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -224,8 +234,10 @@ def pre_process( def forward( self, pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._lock: with 
use_cuda_context(context=self._cuda_context): detections, labels = infer_from_trt_engine( @@ -237,6 +249,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, ) return detections, labels diff --git a/inference_models/inference_models/models/vit/vit_classification_trt.py b/inference_models/inference_models/models/vit/vit_classification_trt.py index 948d544d56..0ed60ff0f0 100644 --- a/inference_models/inference_models/models/vit/vit_classification_trt.py +++ b/inference_models/inference_models/models/vit/vit_classification_trt.py @@ -40,6 +40,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -81,6 +83,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "VITForClassificationTRT": if device.type != "cuda": @@ -147,6 +151,10 @@ def from_pretrained( message=f"Implementation assume single model output, found: {len(outputs)}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -157,6 +165,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -170,6 +179,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name 
= input_name @@ -180,6 +190,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._lock = Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -210,8 +221,10 @@ def pre_process( def forward( self, pre_processed_images: PreprocessedInputs, + disable_cuda_graphs: bool = False, **kwargs, ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._lock: with use_cuda_context(context=self._cuda_context): return infer_from_trt_engine( @@ -223,6 +236,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, )[0] def post_process( @@ -270,6 +284,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "VITForMultiLabelClassificationTRT": if device.type != "cuda": @@ -336,6 +352,10 @@ def from_pretrained( message=f"Implementation assume single model output, found: {len(outputs)}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -346,6 +366,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -359,6 +380,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine 
self._input_name = input_name @@ -369,6 +391,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._lock = Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -399,8 +422,10 @@ def pre_process( def forward( self, pre_processed_images: PreprocessedInputs, + disable_cuda_graphs: bool = False, **kwargs, ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._lock: with use_cuda_context(context=self._cuda_context): return infer_from_trt_engine( @@ -412,6 +437,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, )[0] def post_process( diff --git a/inference_models/inference_models/models/yolact/yolact_instance_segmentation_trt.py b/inference_models/inference_models/models/yolact/yolact_instance_segmentation_trt.py index ab2f2648b4..dfdfaf4a29 100644 --- a/inference_models/inference_models/models/yolact/yolact_instance_segmentation_trt.py +++ b/inference_models/inference_models/models/yolact/yolact_instance_segmentation_trt.py @@ -46,6 +46,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -93,6 +95,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLOACTForInstanceSegmentationTRT": if device.type != "cuda": @@ -154,6 +158,10 @@ def from_pretrained( message=f"Implementation assume 5 model outputs, found: {len(outputs)}.", 
help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -164,6 +172,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -177,6 +186,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -187,6 +197,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._lock = Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -217,8 +228,10 @@ def pre_process( def forward( self, pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._lock: with use_cuda_context(context=self._cuda_context): ( @@ -239,6 +252,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, ) ) all_loc_data.append(loc_data) diff --git a/inference_models/inference_models/models/yolo26/yolo26_instance_segmentation_trt.py b/inference_models/inference_models/models/yolo26/yolo26_instance_segmentation_trt.py index ca2cbf454f..cf26334653 100644 --- a/inference_models/inference_models/models/yolo26/yolo26_instance_segmentation_trt.py +++ 
b/inference_models/inference_models/models/yolo26/yolo26_instance_segmentation_trt.py @@ -44,6 +44,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -89,6 +91,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLO26ForInstanceSegmentationTRT": if device.type != "cuda": @@ -155,6 +159,10 @@ def from_pretrained( message=f"Expected model outputs to be named `output0` and `output1`, but found: {outputs}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -165,6 +173,7 @@ def from_pretrained( device=device, execution_context=execution_context, cuda_context=cuda_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -178,6 +187,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -188,6 +198,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._session_thread_lock = Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -218,8 +229,10 @@ def pre_process( def forward( self, pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, **kwargs, ) -> 
Tuple[torch.Tensor, torch.Tensor]: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._session_thread_lock: with use_cuda_context(context=self._cuda_context): instances, protos = infer_from_trt_engine( @@ -231,6 +244,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, ) return instances, protos diff --git a/inference_models/inference_models/models/yolo26/yolo26_key_points_detection_trt.py b/inference_models/inference_models/models/yolo26/yolo26_key_points_detection_trt.py index 5dd7bdc141..ee944775cc 100644 --- a/inference_models/inference_models/models/yolo26/yolo26_key_points_detection_trt.py +++ b/inference_models/inference_models/models/yolo26/yolo26_key_points_detection_trt.py @@ -45,6 +45,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -88,6 +90,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLO26ForKeyPointsDetectionTRT": if device.type != "cuda": @@ -153,6 +157,10 @@ def from_pretrained( message=f"Implementation assume single model output, found: {len(outputs)}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -165,6 +173,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -180,12 +189,14 @@ 
def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name self._output_names = [output_name] self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._class_names = class_names self._skeletons = skeletons self._inference_config = inference_config @@ -193,7 +204,7 @@ def __init__( self._trt_config = trt_config self._device = device self._session_thread_lock = Lock() self._parsed_key_points_metadata = parsed_key_points_metadata self._key_points_classes_for_instances = torch.tensor( [len(e) for e in self._parsed_key_points_metadata], device=device ) @@ -237,8 +248,10 @@ def pre_process( def forward( self, pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, **kwargs, ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._session_thread_lock: with use_cuda_context(context=self._cuda_context): return infer_from_trt_engine( @@ -250,6 +263,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, )[0] def post_process( diff --git a/inference_models/inference_models/models/yolo26/yolo26_object_detection_trt.py b/inference_models/inference_models/models/yolo26/yolo26_object_detection_trt.py index f7c299aa9a..b87666d40d 100644 --- a/inference_models/inference_models/models/yolo26/yolo26_object_detection_trt.py +++ b/inference_models/inference_models/models/yolo26/yolo26_object_detection_trt.py @@ -37,6 +37,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -80,6 +82,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device =
DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLO26ForObjectDetectionTRT": if device.type != "cuda": @@ -141,6 +145,10 @@ def from_pretrained( message=f"Implementation assume single model output, found: {len(outputs)}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -151,6 +159,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -164,6 +173,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -174,6 +184,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._lock = threading.Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -204,8 +215,10 @@ def pre_process( def forward( self, pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, **kwargs, ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._lock: with use_cuda_context(context=self._cuda_context): return infer_from_trt_engine( @@ -217,6 +230,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, )[0] def post_process( diff --git 
a/inference_models/inference_models/models/yolonas/yolonas_object_detection_trt.py b/inference_models/inference_models/models/yolonas/yolonas_object_detection_trt.py index fd8c3c1c59..d74bcc3cb5 100644 --- a/inference_models/inference_models/models/yolonas/yolonas_object_detection_trt.py +++ b/inference_models/inference_models/models/yolonas/yolonas_object_detection_trt.py @@ -38,6 +38,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -83,6 +85,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLONasForObjectDetectionTRT": if device.type != "cuda": @@ -155,6 +159,10 @@ def from_pretrained( help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) # git rid of outputs order and names verification, as YOLO-NAS clearly produces different outputs + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -165,6 +173,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -178,6 +187,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -188,6 +198,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache 
self._session_thread_lock = Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -215,7 +226,13 @@ def pre_process( self._pre_process_stream.synchronize() return pre_processed_images, pre_processing_meta - def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor: + def forward( + self, + pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, + **kwargs, + ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._session_thread_lock: with use_cuda_context(context=self._cuda_context): results = infer_from_trt_engine( @@ -227,6 +244,7 @@ def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor: input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, ) return torch.cat(results, dim=-1) diff --git a/inference_models/inference_models/models/yolov10/yolov10_object_detection_trt.py b/inference_models/inference_models/models/yolov10/yolov10_object_detection_trt.py index fb1ec11c73..0950f3fd5a 100644 --- a/inference_models/inference_models/models/yolov10/yolov10_object_detection_trt.py +++ b/inference_models/inference_models/models/yolov10/yolov10_object_detection_trt.py @@ -38,6 +38,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -80,6 +82,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLOv10ForObjectDetectionTRT": if device.type != "cuda": @@ -141,6 +145,10 @@ def from_pretrained( message=f"Implementation assume single model output, found: {len(outputs)}.", 
help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -151,6 +159,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -164,6 +173,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -174,6 +184,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._session_thread_lock = Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -204,8 +215,10 @@ def pre_process( def forward( self, pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, **kwargs, ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._session_thread_lock: with use_cuda_context(context=self._cuda_context): return infer_from_trt_engine( @@ -217,6 +230,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, )[0] def post_process( diff --git a/inference_models/inference_models/models/yolov5/yolov5_instance_segmentation_trt.py b/inference_models/inference_models/models/yolov5/yolov5_instance_segmentation_trt.py index 71da6f20d7..f3b7af3559 100644 --- a/inference_models/inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +++ b/inference_models/inference_models/models/yolov5/yolov5_instance_segmentation_trt.py @@ -46,6 +46,8 
@@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -92,6 +94,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLOv5ForInstanceSegmentationTRT": if device.type != "cuda": @@ -158,6 +162,10 @@ def from_pretrained( message=f"Expected model outputs to be named `output0` and `output1`, but found: {outputs}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -168,6 +176,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -181,6 +190,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -191,6 +201,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._session_thread_lock = Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -219,8 +230,12 @@ def pre_process( return pre_processed_images, pre_processing_meta def forward( - self, pre_processed_images: torch.Tensor, **kwargs + self, + pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, + **kwargs, ) -> 
Tuple[torch.Tensor, torch.Tensor]: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._session_thread_lock: with use_cuda_context(context=self._cuda_context): instances, protos = infer_from_trt_engine( @@ -232,6 +247,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, ) return instances, protos diff --git a/inference_models/inference_models/models/yolov5/yolov5_object_detection_trt.py b/inference_models/inference_models/models/yolov5/yolov5_object_detection_trt.py index d7f671afd1..c61078e3f9 100644 --- a/inference_models/inference_models/models/yolov5/yolov5_object_detection_trt.py +++ b/inference_models/inference_models/models/yolov5/yolov5_object_detection_trt.py @@ -38,6 +38,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -82,6 +84,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLOv5ForObjectDetectionTRT": if device.type != "cuda": @@ -143,6 +147,10 @@ def from_pretrained( message=f"Implementation assume single model output, found: {len(outputs)}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -153,6 +161,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -166,6 +175,7 @@ def __init__( device: 
torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -176,6 +186,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._session_thread_lock = Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -203,7 +214,13 @@ def pre_process( self._pre_process_stream.synchronize() return pre_processed_images, pre_processing_meta - def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor: + def forward( + self, + pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, + **kwargs, + ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._session_thread_lock: with use_cuda_context(context=self._cuda_context): return infer_from_trt_engine( @@ -215,6 +232,7 @@ def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor: input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, )[0] def post_process( diff --git a/inference_models/inference_models/models/yolov7/yolov7_instance_segmentation_trt.py b/inference_models/inference_models/models/yolov7/yolov7_instance_segmentation_trt.py index 9d8090b34e..044295646f 100644 --- a/inference_models/inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +++ b/inference_models/inference_models/models/yolov7/yolov7_instance_segmentation_trt.py @@ -47,6 +47,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -92,6 +94,8 @@ def from_pretrained( model_name_or_path: str, device: 
torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLOv7ForInstanceSegmentationTRT": if device.type != "cuda": @@ -154,6 +158,10 @@ def from_pretrained( help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) output_tensors = [outputs[0], outputs[4]] + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -164,6 +172,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -177,6 +186,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -189,6 +199,7 @@ def __init__( self._execution_context = execution_context self._session_thread_lock = Lock() self._inference_stream = torch.cuda.Stream(device=self._device) + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._thread_local_storage = threading.local() @property @@ -215,8 +226,12 @@ def pre_process( return pre_processed_images, pre_processing_meta def forward( - self, pre_processed_images: torch.Tensor, **kwargs + self, + pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._session_thread_lock: with use_cuda_context(context=self._cuda_context): instances, protos = infer_from_trt_engine( @@ -228,6 +243,7 @@ def forward( input_name=self._input_name, outputs=self._output_tensors, stream=self._inference_stream, + trt_cuda_graph_cache=cache, 
) return instances, protos diff --git a/inference_models/inference_models/models/yolov8/yolov8_instance_segmentation_trt.py b/inference_models/inference_models/models/yolov8/yolov8_instance_segmentation_trt.py index 56c430ccf5..28aa7f5b39 100644 --- a/inference_models/inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +++ b/inference_models/inference_models/models/yolov8/yolov8_instance_segmentation_trt.py @@ -48,6 +48,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -93,6 +95,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLOv8ForInstanceSegmentationTRT": if device.type != "cuda": @@ -164,6 +168,10 @@ def from_pretrained( message=f"Expected model outputs to be named `output0` and `output1`, but found: {outputs}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -174,6 +182,7 @@ def from_pretrained( device=device, execution_context=execution_context, cuda_context=cuda_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -187,6 +196,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -200,6 +210,7 @@ def __init__( self._session_thread_lock = Lock() self._inference_stream = torch.cuda.Stream(device=self._device) 
self._thread_local_storage = threading.local() + self._trt_cuda_graph_cache = trt_cuda_graph_cache @property def class_names(self) -> List[str]: @@ -227,8 +238,10 @@ def pre_process( def forward( self, pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._session_thread_lock: with use_cuda_context(context=self._cuda_context): instances, protos = infer_from_trt_engine( @@ -240,6 +253,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, ) return instances, protos diff --git a/inference_models/inference_models/models/yolov8/yolov8_key_points_detection_trt.py b/inference_models/inference_models/models/yolov8/yolov8_key_points_detection_trt.py index 4adf21965b..cb98489a0c 100644 --- a/inference_models/inference_models/models/yolov8/yolov8_key_points_detection_trt.py +++ b/inference_models/inference_models/models/yolov8/yolov8_key_points_detection_trt.py @@ -49,6 +49,8 @@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -92,6 +94,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLOv8ForKeyPointsDetectionTRT": if device.type != "cuda": @@ -162,6 +166,10 @@ def from_pretrained( message=f"Implementation assume single model output, found: {len(outputs)}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + 
cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -174,6 +182,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -189,12 +198,14 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name self._output_names = [output_name] self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._class_names = class_names self._skeletons = skeletons self._inference_config = inference_config @@ -202,7 +213,7 @@ def __init__( self._trt_config = trt_config self._device = device self._session_thread_lock = Lock() self._parsed_key_points_metadata = parsed_key_points_metadata self._key_points_classes_for_instances = torch.tensor( [len(e) for e in self._parsed_key_points_metadata], device=device ) @@ -246,8 +257,10 @@ def pre_process( def forward( self, pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, **kwargs, ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._session_thread_lock: with use_cuda_context(context=self._cuda_context): return infer_from_trt_engine( @@ -259,6 +272,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, )[0] def post_process( diff --git a/inference_models/inference_models/models/yolov8/yolov8_object_detection_trt.py b/inference_models/inference_models/models/yolov8/yolov8_object_detection_trt.py index 8dbef7c2f8..1c099fd67d 100644 --- a/inference_models/inference_models/models/yolov8/yolov8_object_detection_trt.py +++ b/inference_models/inference_models/models/yolov8/yolov8_object_detection_trt.py @@ -41,6 +41,8
@@ pre_process_network_input, ) from inference_models.models.common.trt import ( + TRTCudaGraphCache, + establish_trt_cuda_graph_cache, get_trt_engine_inputs_and_outputs, infer_from_trt_engine, load_trt_model, @@ -84,6 +86,8 @@ def from_pretrained( model_name_or_path: str, device: torch.device = DEFAULT_DEVICE, engine_host_code_allowed: bool = False, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache] = None, + default_trt_cuda_graph_cache_size: int = 8, **kwargs, ) -> "YOLOv8ForObjectDetectionTRT": if device.type != "cuda": @@ -150,6 +154,10 @@ def from_pretrained( message=f"Implementation assume single model output, found: {len(outputs)}.", help_url="https://inference-models.roboflow.com/errors/model-loading/#corruptedmodelpackageerror", ) + trt_cuda_graph_cache = establish_trt_cuda_graph_cache( + default_cuda_graph_cache_size=default_trt_cuda_graph_cache_size, + cuda_graph_cache=trt_cuda_graph_cache, + ) return cls( engine=engine, input_name=inputs[0], @@ -160,6 +168,7 @@ def from_pretrained( device=device, cuda_context=cuda_context, execution_context=execution_context, + trt_cuda_graph_cache=trt_cuda_graph_cache, ) def __init__( @@ -173,6 +182,7 @@ def __init__( device: torch.device, cuda_context: cuda.Context, execution_context: trt.IExecutionContext, + trt_cuda_graph_cache: Optional[TRTCudaGraphCache], ): self._engine = engine self._input_name = input_name @@ -183,6 +193,7 @@ def __init__( self._device = device self._cuda_context = cuda_context self._execution_context = execution_context + self._trt_cuda_graph_cache = trt_cuda_graph_cache self._lock = threading.Lock() self._inference_stream = torch.cuda.Stream(device=self._device) self._thread_local_storage = threading.local() @@ -213,8 +224,10 @@ def pre_process( def forward( self, pre_processed_images: torch.Tensor, + disable_cuda_graphs: bool = False, **kwargs, ) -> torch.Tensor: + cache = self._trt_cuda_graph_cache if not disable_cuda_graphs else None with self._lock: with 
use_cuda_context(context=self._cuda_context): return infer_from_trt_engine( @@ -226,6 +239,7 @@ def forward( input_name=self._input_name, outputs=self._output_names, stream=self._inference_stream, + trt_cuda_graph_cache=cache, )[0] def post_process( diff --git a/inference_models/mkdocs.yml b/inference_models/mkdocs.yml index b68348e4e8..983cc25323 100644 --- a/inference_models/mkdocs.yml +++ b/inference_models/mkdocs.yml @@ -103,6 +103,7 @@ nav: - Load Models Locally: how-to/local-packages.md - Understand Roboflow Model Packages: how-to/roboflow-model-packages.md - Manage Cache: how-to/cache-management.md + - Use CUDA Graphs: how-to/use-cuda-graphs.md - Contributors: - Development Environment: contributors/dev-environment.md - Core Architecture: contributors/core-architecture.md @@ -147,6 +148,8 @@ nav: - get_trt_engine_inputs_and_outputs: api-reference/developer-tools/trt/get-trt-engine-inputs-and-outputs.md - infer_from_trt_engine: api-reference/developer-tools/trt/infer-from-trt-engine.md - load_trt_model: api-reference/developer-tools/trt/load-trt-model.md + - establish_trt_cuda_graph_cache: api-reference/developer-tools/trt/establish-trt-cuda-graph-cache.md + - TRTCudaGraphCache: api-reference/developer-tools/trt/trt-cuda-graph-cache.md - Entities: - RuntimeXRayResult: api-reference/developer-tools/runtime-xray-result.md - ModelMetadata: api-reference/developer-tools/model-metadata.md diff --git a/inference_models/pyproject.toml b/inference_models/pyproject.toml index 37aedc9c11..b601421715 100644 --- a/inference_models/pyproject.toml +++ b/inference_models/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "inference-models" -version = "0.20.2" +version = "0.21.0" description = "The new inference engine for Computer Vision models" readme = "README.md" requires-python = ">=3.10,<3.13" diff --git a/inference_models/tests/integration_tests/models/conftest.py b/inference_models/tests/integration_tests/models/conftest.py index 8e81196ae6..a5e2917319 100644 --- 
a/inference_models/tests/integration_tests/models/conftest.py +++ b/inference_models/tests/integration_tests/models/conftest.py @@ -192,6 +192,10 @@ SAM2_PACKAGE_URL = ( "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/sam2.zip" ) + +RFDETR_NANO_T4_TRT_PACKAGE_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/rfdetr-nano-t4-trt.zip" +RFDETR_SEG_NANO_T4_TRT_PACKAGE_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/rfdetr-seg-nano-t4-trt.zip" +YOLOV8N_640_T4_TRT_PACKAGE_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/yolov8n-640-t4-trt.zip" COIN_COUNTING_TRT_PACKAGE_YOLO_V8_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/yolov8-coin-counting-trt-t4-package.zip" COIN_COUNTING_TRT_PACKAGE_RF_DETR_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/rfdetr-coin-counting-trt-t4-package.zip" COIN_COUNTING_TRT_PACKAGE_YOLO_NAS_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/yolo-nas-coin-counting-trt-t4-package.zip" @@ -451,6 +455,30 @@ def coin_counting_rfdetr_nano_torch_static_crop_center_crop_package() -> str: ) +@pytest.fixture(scope="module") +def rfdetr_nano_t4_trt_package() -> str: + return download_model_package( + model_package_zip_url=RFDETR_NANO_T4_TRT_PACKAGE_URL, + package_name="rfdetr-nano-t4-trt", + ) + + +@pytest.fixture(scope="module") +def rfdetr_seg_nano_t4_trt_package() -> str: + return download_model_package( + model_package_zip_url=RFDETR_SEG_NANO_T4_TRT_PACKAGE_URL, + package_name="rfdetr-seg-nano-t4-trt", + ) + + +@pytest.fixture(scope="module") +def yolov8n_640_t4_trt_package() -> str: + return download_model_package( + model_package_zip_url=YOLOV8N_640_T4_TRT_PACKAGE_URL, + package_name="yolov8n-640-t4-trt", + ) + + @pytest.fixture(scope="module") def coin_counting_rfdetr_nano_onnx_static_bs_nonsquare_letterbox_package() -> str: return 
download_model_package( diff --git a/inference_models/tests/integration_tests/models/test_resnet_classifier_predictions_trt.py b/inference_models/tests/integration_tests/models/test_resnet_classifier_predictions_trt.py index 528135627d..9c0b059f6e 100644 --- a/inference_models/tests/integration_tests/models/test_resnet_classifier_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_resnet_classifier_predictions_trt.py @@ -73,6 +73,30 @@ def test_single_label_trt_package_torch( assert abs(predictions.confidence[0, 2].item() - 0.9999516010284424) < 1e-3 +@pytest.mark.slow +@pytest.mark.trt_extras +def test_single_label_trt_package_torch_multiple_predictions_in_row( + resnet_single_label_cls_trt_package: str, + bike_image_torch: np.ndarray, +) -> None: + # given + from inference_models.models.resnet.resnet_classification_trt import ( + ResNetForClassificationTRT, + ) + + model = ResNetForClassificationTRT.from_pretrained( + model_name_or_path=resnet_single_label_cls_trt_package, + engine_host_code_allowed=True, + ) + + for _ in range(8): + # when + predictions = model(bike_image_torch) + + # then + assert abs(predictions.confidence[0, 2].item() - 0.9999516010284424) < 1e-3 + + @pytest.mark.slow @pytest.mark.trt_extras def test_single_label_trt_package_torch_list( @@ -191,6 +215,30 @@ def test_multi_label_trt_package_torch( assert abs(predictions[0].confidence[2].item() - 0.99951171875) < 1e-3 +@pytest.mark.slow +@pytest.mark.trt_extras +def test_multi_label_trt_package_torch_multiple_predictions_in_row( + resnet_multi_label_cls_trt_package: str, + dog_image_torch: torch.Tensor, +) -> None: + # given + from inference_models.models.resnet.resnet_classification_trt import ( + ResNetForMultiLabelClassificationTRT, + ) + + model = ResNetForMultiLabelClassificationTRT.from_pretrained( + model_name_or_path=resnet_multi_label_cls_trt_package, + engine_host_code_allowed=True, + ) + + for _ in range(8): + # when + predictions = model(dog_image_torch) + 
+ # then + assert abs(predictions[0].confidence[2].item() - 0.99951171875) < 1e-3 + + @pytest.mark.slow @pytest.mark.trt_extras def test_multi_label_trt_package_torch_list( diff --git a/inference_models/tests/integration_tests/models/test_rfdetr_predictions_trt.py b/inference_models/tests/integration_tests/models/test_rfdetr_predictions_trt.py index 322481ed85..b067349920 100644 --- a/inference_models/tests/integration_tests/models/test_rfdetr_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_rfdetr_predictions_trt.py @@ -243,6 +243,71 @@ def test_trt_package_torch( ) +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_package_torch_multiple_predictions_in_row( + rfdetr_coin_counting_trt_package: str, + coins_counting_image_torch: torch.Tensor, +) -> None: + # given + from inference_models.models.rfdetr.rfdetr_object_detection_trt import ( + RFDetrForObjectDetectionTRT, + ) + + model = RFDetrForObjectDetectionTRT.from_pretrained( + model_name_or_path=rfdetr_coin_counting_trt_package, + engine_host_code_allowed=True, + ) + + for _ in range(8): + # when + predictions = model(coins_counting_image_torch) + + # then + assert torch.allclose( + predictions[0].confidence.cpu(), + torch.tensor( + [ + 0.9815, + 0.9674, + 0.9638, + 0.9620, + 0.9584, + 0.9565, + 0.9560, + 0.9543, + 0.9520, + 0.9491, + ] + ).cpu(), + atol=0.01, + ) + assert torch.allclose( + predictions[0].class_id.cpu(), + torch.tensor([4, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32).cpu(), + ) + expected_xyxy = torch.tensor( + [ + [1323, 533, 3071, 1970], + [1708, 2572, 1887, 2760], + [1172, 2635, 1372, 2850], + [1744, 2296, 1914, 2472], + [1464, 2305, 1627, 2475], + [1255, 2063, 1423, 2233], + [1091, 2354, 1253, 2524], + [1508, 1884, 1721, 2093], + [929, 1843, 1091, 2004], + [2681, 802, 2867, 976], + ], + dtype=torch.int32, + ) + assert torch.allclose( + predictions[0].xyxy.cpu(), + expected_xyxy.cpu(), + atol=5, + ) + + @pytest.mark.slow @pytest.mark.trt_extras def 
test_trt_package_torch_list( @@ -455,3 +520,90 @@ def test_trt_package_torch_batch( expected_xyxy.cpu(), atol=5, ) + + +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_cudagraph_output_matches_non_cudagraph_output( + rfdetr_nano_t4_trt_package: str, + dog_image_numpy: np.ndarray, + bike_image_numpy: np.ndarray, +) -> None: + from inference_models import AutoModel + from inference_models.models.common.trt import TRTCudaGraphCache + + trt_cuda_graph_cache = TRTCudaGraphCache(capacity=16) + model = AutoModel.from_pretrained( + model_id_or_path=rfdetr_nano_t4_trt_package, + device=torch.device("cuda:0"), + trt_cuda_graph_cache=trt_cuda_graph_cache, + ) + + pre_processed_1, _ = model.pre_process(dog_image_numpy) + pre_processed_2, _ = model.pre_process(bike_image_numpy) + + outputs = [] + for pre_processed in [pre_processed_1, pre_processed_2]: + no_graph = model.forward(pre_processed, disable_cuda_graphs=True) + capture_graph = model.forward(pre_processed) + replay_graph = model.forward(pre_processed) + + outputs.append((no_graph, capture_graph, replay_graph)) + + for image_outputs in outputs: + no_graph, capture_graph, replay_graph = image_outputs + for result_idx in range(2): + assert torch.allclose( + no_graph[result_idx], + capture_graph[result_idx], + atol=1e-6, + ) + assert torch.allclose( + no_graph[result_idx], + replay_graph[result_idx], + atol=1e-6, + ) + + # make sure that the allcloses aren't true because of buffer aliasing or something weird + # outputs should be different between images and the same between execution branches. 
+ for execution_branch_idx in range(3): + for result_idx in range(2): + assert not torch.allclose( + outputs[0][execution_branch_idx][result_idx], + outputs[1][execution_branch_idx][result_idx], + atol=1e-6, + ) + + +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_outputs_match_expected_shapes( + rfdetr_nano_t4_trt_package: str, + dog_image_numpy: np.ndarray, +) -> None: + from inference_models import AutoModel + from inference_models.models.common.trt import TRTCudaGraphCache + + trt_cuda_graph_cache = TRTCudaGraphCache(capacity=16) + model = AutoModel.from_pretrained( + model_id_or_path=rfdetr_nano_t4_trt_package, + device=torch.device("cuda:0"), + trt_cuda_graph_cache=trt_cuda_graph_cache, + ) + + pre_processed, _ = model.pre_process(dog_image_numpy) + + output = model.forward(pre_processed, disable_cuda_graphs=True) + + assert output[0].shape == (1, 300, 4) + assert output[1].shape == (1, 300, 91) + + output = model.forward(pre_processed) # capture + + assert output[0].shape == (1, 300, 4) + assert output[1].shape == (1, 300, 91) + + output = model.forward(pre_processed) # replay + + assert output[0].shape == (1, 300, 4) + assert output[1].shape == (1, 300, 91) diff --git a/inference_models/tests/integration_tests/models/test_rfdetr_seg_predictions_trt.py b/inference_models/tests/integration_tests/models/test_rfdetr_seg_predictions_trt.py index de7594a1c8..04befce4c7 100644 --- a/inference_models/tests/integration_tests/models/test_rfdetr_seg_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_rfdetr_seg_predictions_trt.py @@ -145,6 +145,48 @@ def test_trt_package_torch( assert 16050 <= predictions[0].mask.cpu().sum().item() <= 16100 +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_package_torch_multiple_predictions_in_row( + rfdetr_seg_asl_trt_package: str, + asl_image_torch: torch.Tensor, +) -> None: + # given + from inference_models.models.rfdetr.rfdetr_instance_segmentation_trt import ( + 
RFDetrForInstanceSegmentationTRT, + ) + + model = RFDetrForInstanceSegmentationTRT.from_pretrained( + model_name_or_path=rfdetr_seg_asl_trt_package, + engine_host_code_allowed=True, + ) + + for _ in range(8): + # when + predictions = model(asl_image_torch) + + # then + assert torch.allclose( + predictions[0].confidence.cpu(), + torch.tensor([0.9527]).cpu(), + atol=0.01, + ) + assert torch.allclose( + predictions[0].class_id.cpu(), + torch.tensor([20], dtype=torch.int32).cpu(), + ) + expected_xyxy = torch.tensor( + [[63, 173, 187, 374]], + dtype=torch.int32, + ) + assert torch.allclose( + predictions[0].xyxy.cpu(), + expected_xyxy.cpu(), + atol=5, + ) + assert 16050 <= predictions[0].mask.cpu().sum().item() <= 16100 + + @pytest.mark.slow @pytest.mark.trt_extras def test_trt_package_torch_list( @@ -263,3 +305,53 @@ def test_trt_package_torch_batch( atol=5, ) assert 16050 <= predictions[1].mask.cpu().sum().item() <= 16100 + + +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_cudagraph_output_matches_non_cudagraph_output( + rfdetr_seg_nano_t4_trt_package: str, + snake_image_numpy: np.ndarray, + dog_image_numpy: np.ndarray, +) -> None: + from inference_models import AutoModel + from inference_models.models.common.trt import TRTCudaGraphCache + + trt_cuda_graph_cache = TRTCudaGraphCache(capacity=16) + model = AutoModel.from_pretrained( + model_id_or_path=rfdetr_seg_nano_t4_trt_package, + device=torch.device("cuda:0"), + trt_cuda_graph_cache=trt_cuda_graph_cache, + ) + + pre_processed_1, _ = model.pre_process(snake_image_numpy) + pre_processed_2, _ = model.pre_process(dog_image_numpy) + + outputs = [] + for pre_processed in [pre_processed_1, pre_processed_2]: + no_graph = model.forward(pre_processed, disable_cuda_graphs=True) + capture_graph = model.forward(pre_processed) + replay_graph = model.forward(pre_processed) + outputs.append((no_graph, capture_graph, replay_graph)) + + for image_outputs in outputs: + no_graph, capture_graph, replay_graph = image_outputs + 
for result_idx in range(3): + assert torch.allclose( + no_graph[result_idx], + capture_graph[result_idx], + atol=1e-6, + ) + assert torch.allclose( + no_graph[result_idx], + replay_graph[result_idx], + atol=1e-6, + ) + + for execution_branch_idx in range(3): + for result_idx in range(3): + assert not torch.allclose( + outputs[0][execution_branch_idx][result_idx], + outputs[1][execution_branch_idx][result_idx], + atol=1e-6, + ) diff --git a/inference_models/tests/integration_tests/models/test_vit_classifier_predictions_trt.py b/inference_models/tests/integration_tests/models/test_vit_classifier_predictions_trt.py index 5ea6481333..70b6985ae3 100644 --- a/inference_models/tests/integration_tests/models/test_vit_classifier_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_vit_classifier_predictions_trt.py @@ -73,6 +73,30 @@ def test_single_label_trt_package_torch( assert abs(predictions.confidence[0, 2].item() - 0.7300973534584045) < 2e-2 +@pytest.mark.slow +@pytest.mark.trt_extras +def test_single_label_trt_package_torch_multiple_predictions_in_row( + vit_single_label_cls_trt_package: str, + bike_image_torch: np.ndarray, +) -> None: + # given + from inference_models.models.vit.vit_classification_trt import ( + VITForClassificationTRT, + ) + + model = VITForClassificationTRT.from_pretrained( + model_name_or_path=vit_single_label_cls_trt_package, + engine_host_code_allowed=True, + ) + + for _ in range(8): + # when + predictions = model(bike_image_torch) + + # then + assert abs(predictions.confidence[0, 2].item() - 0.7300973534584045) < 2e-2 + + @pytest.mark.slow @pytest.mark.trt_extras def test_single_label_trt_package_torch_list( @@ -191,6 +215,30 @@ def test_multi_label_trt_package_torch( assert abs(predictions[0].confidence[2].item() - 0.833984375) < 1e-3 +@pytest.mark.slow +@pytest.mark.trt_extras +def test_multi_label_trt_package_torch_multiple_predictions_in_row( + vit_multi_label_cls_trt_package: str, + dog_image_torch: torch.Tensor, 
+) -> None: + # given + from inference_models.models.vit.vit_classification_trt import ( + VITForMultiLabelClassificationTRT, + ) + + model = VITForMultiLabelClassificationTRT.from_pretrained( + model_name_or_path=vit_multi_label_cls_trt_package, + engine_host_code_allowed=True, + ) + + for _ in range(8): + # when + predictions = model(dog_image_torch) + + # then + assert abs(predictions[0].confidence[2].item() - 0.833984375) < 1e-3 + + @pytest.mark.slow @pytest.mark.trt_extras def test_multi_label_trt_package_torch_list( diff --git a/inference_models/tests/integration_tests/models/test_yolo26_instance_segmentation_predictions_trt.py b/inference_models/tests/integration_tests/models/test_yolo26_instance_segmentation_predictions_trt.py index 65873c080a..14f32ad0a9 100644 --- a/inference_models/tests/integration_tests/models/test_yolo26_instance_segmentation_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_yolo26_instance_segmentation_predictions_trt.py @@ -145,6 +145,48 @@ def test_trt_package_torch( assert 16500 <= predictions[0].mask.cpu().sum().item() <= 16600 +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_package_torch_multiple_predictions_in_row( + yolo26_seg_asl_trt_package: str, + asl_image_torch: torch.Tensor, +) -> None: + # given + from inference_models.models.yolo26.yolo26_instance_segmentation_trt import ( + YOLO26ForInstanceSegmentationTRT, + ) + + model = YOLO26ForInstanceSegmentationTRT.from_pretrained( + model_name_or_path=yolo26_seg_asl_trt_package, + engine_host_code_allowed=True, + ) + + # when + for _ in range(8): + predictions = model(asl_image_torch) + + # then + assert torch.allclose( + predictions[0].confidence.cpu(), + torch.tensor([0.9671]).cpu(), + atol=0.01, + ) + assert torch.allclose( + predictions[0].class_id.cpu(), + torch.tensor([20], dtype=torch.int32).cpu(), + ) + expected_xyxy = torch.tensor( + [[63, 174, 186, 368]], + dtype=torch.int32, + ) + assert torch.allclose( + 
predictions[0].xyxy.cpu(), + expected_xyxy.cpu(), + atol=5, + ) + assert 16500 <= predictions[0].mask.cpu().sum().item() <= 16600 + + @pytest.mark.slow @pytest.mark.trt_extras def test_trt_package_torch_list( diff --git a/inference_models/tests/integration_tests/models/test_yolo26_keypoints_detection_predictions_trt.py b/inference_models/tests/integration_tests/models/test_yolo26_keypoints_detection_predictions_trt.py index ce74b631f6..c4d8083077 100644 --- a/inference_models/tests/integration_tests/models/test_yolo26_keypoints_detection_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_yolo26_keypoints_detection_predictions_trt.py @@ -144,6 +144,50 @@ def test_trt_package_torch( assert abs(predictions[0][0].confidence.sum().item() - 26.268831253051758) < 1e-2 +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_package_torch_multiple_predictions_in_row( + yolo26_pose_trt_package: str, + people_walking_image_torch: torch.Tensor, +) -> None: + # given + from inference_models.models.yolo26.yolo26_key_points_detection_trt import ( + YOLO26ForKeyPointsDetectionTRT, + ) + + model = YOLO26ForKeyPointsDetectionTRT.from_pretrained( + model_name_or_path=yolo26_pose_trt_package, + engine_host_code_allowed=True, + ) + + for _ in range(8): + # when + predictions = model(people_walking_image_torch) + + # then + assert torch.allclose( + predictions[1][0].confidence.cpu(), + torch.tensor([0.9271, 0.9230]).cpu(), + atol=0.01, + ) + assert torch.allclose( + predictions[1][0].class_id.cpu(), + torch.tensor([0, 0], dtype=torch.int32).cpu(), + ) + expected_xyxy = torch.tensor( + [[353, 129, 539, 758], [618, 123, 822, 771]], + dtype=torch.int32, + ) + assert torch.allclose( + predictions[1][0].xyxy.cpu(), + expected_xyxy.cpu(), + atol=5, + ) + assert ( + abs(predictions[0][0].confidence.sum().item() - 26.268831253051758) < 1e-2 + ) + + @pytest.mark.slow @pytest.mark.trt_extras def test_trt_package_torch_list( diff --git 
a/inference_models/tests/integration_tests/models/test_yolo26_object_detection_predictions_trt.py b/inference_models/tests/integration_tests/models/test_yolo26_object_detection_predictions_trt.py index ddd5823858..811f32f9cb 100644 --- a/inference_models/tests/integration_tests/models/test_yolo26_object_detection_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_yolo26_object_detection_predictions_trt.py @@ -247,6 +247,75 @@ def test_trt_package_torch( ) +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_package_torch_multiple_predictions_in_row( + yolo26_object_detections_coin_counting_trt_package: str, + coins_counting_image_torch: torch.Tensor, +) -> None: + # given + from inference_models.models.yolo26.yolo26_object_detection_trt import ( + YOLO26ForObjectDetectionTRT, + ) + + model = YOLO26ForObjectDetectionTRT.from_pretrained( + model_name_or_path=yolo26_object_detections_coin_counting_trt_package, + engine_host_code_allowed=True, + ) + + # when + for _ in range(8): + predictions = model(coins_counting_image_torch) + + # then + assert torch.allclose( + predictions[0].confidence.cpu(), + torch.tensor( + [ + 0.9837, + 0.9707, + 0.9196, + 0.8495, + 0.8418, + 0.8408, + 0.5737, + 0.4922, + 0.4282, + 0.4273, + 0.2606, + ] + ).cpu(), + atol=0.01, + ) + assert torch.allclose( + predictions[0].class_id.cpu(), + torch.tensor([2, 2, 2, 1, 3, 0, 0, 0, 3, 1, 3], dtype=torch.int32).cpu(), + ) + expected_xyxy = torch.tensor( + [ + [ + [1252, 2049, 1431, 2241], + [1741, 2286, 1921, 2480], + [1707, 2565, 1896, 2770], + [1164, 2624, 1382, 2856], + [1502, 1867, 1728, 2096], + [1459, 2296, 1633, 2476], + [923, 1836, 1100, 2009], + [1090, 2346, 1268, 2525], + [1256, 2059, 1425, 2234], + [1164, 2626, 1381, 2857], + [2671, 792, 2875, 979], + ] + ], + dtype=torch.int32, + ) + assert torch.allclose( + predictions[0].xyxy.cpu(), + expected_xyxy.cpu(), + atol=5, + ) + + @pytest.mark.slow @pytest.mark.trt_extras def test_trt_package_torch_list( diff 
--git a/inference_models/tests/integration_tests/models/test_yolonas_predictions_trt.py b/inference_models/tests/integration_tests/models/test_yolonas_predictions_trt.py index 21dadc4b49..3ee3ab4535 100644 --- a/inference_models/tests/integration_tests/models/test_yolonas_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_yolonas_predictions_trt.py @@ -20,6 +20,9 @@ def test_trt_package_numpy( ) # when + # warmup + for _ in range(5): + _ = model(coins_counting_image_numpy) predictions = model(coins_counting_image_numpy) # then @@ -88,6 +91,9 @@ def test_trt_package_batch_numpy( ) # when + # warmup + for _ in range(5): + _ = model([coins_counting_image_numpy, coins_counting_image_numpy]) predictions = model([coins_counting_image_numpy, coins_counting_image_numpy]) # then @@ -202,6 +208,9 @@ def test_trt_package_torch( ) # when + # warmup + for _ in range(5): + _ = model(coins_counting_image_torch) predictions = model(coins_counting_image_torch) # then @@ -253,6 +262,78 @@ def test_trt_package_torch( ) +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_package_torch_multiple_predictions_in_row( + yolo_nas_coin_counting_trt_package: str, + coins_counting_image_torch: torch.Tensor, +) -> None: + # given + from inference_models.models.yolonas.yolonas_object_detection_trt import ( + YOLONasForObjectDetectionTRT, + ) + + model = YOLONasForObjectDetectionTRT.from_pretrained( + model_name_or_path=yolo_nas_coin_counting_trt_package, + engine_host_code_allowed=True, + ) + + # when + # warmup + for _ in range(5): + _ = model(coins_counting_image_torch) + for _ in range(8): + predictions = model(coins_counting_image_torch) + + # then + assert torch.allclose( + predictions[0].confidence.cpu(), + torch.tensor( + [ + 0.8929, + 0.8762, + 0.8625, + 0.8573, + 0.8434, + 0.7718, + 0.7705, + 0.7628, + 0.6723, + 0.6343, + 0.4533, + 0.4388, + ] + ).cpu(), + atol=0.01, + ) + assert torch.allclose( + predictions[0].class_id.cpu(), + torch.tensor([2, 1, 0, 
0, 0, 0, 3, 3, 2, 2, 0, 1], dtype=torch.int32).cpu(), + ) + expected_xyxy = torch.tensor( + [ + [1693, 2548, 1910, 2774], + [1161, 2618, 1389, 2868], + [1445, 2291, 1641, 2483], + [913, 1823, 1110, 2017], + [1080, 2334, 1275, 2537], + [1727, 2285, 1931, 2482], + [2664, 763, 2887, 1001], + [1491, 1862, 1740, 2101], + [1727, 2283, 1932, 2487], + [1238, 2041, 1438, 2243], + [1485, 1864, 1743, 2106], + [1236, 2040, 1439, 2245], + ], + dtype=torch.int32, + ) + assert torch.allclose( + predictions[0].xyxy.cpu(), + expected_xyxy.cpu(), + atol=5, + ) + + @pytest.mark.slow @pytest.mark.trt_extras def test_trt_package_torch_list( @@ -270,6 +351,9 @@ def test_trt_package_torch_list( ) # when + # warmup + for _ in range(5): + _ = model([coins_counting_image_torch, coins_counting_image_torch]) predictions = model([coins_counting_image_torch, coins_counting_image_torch]) # then @@ -384,6 +468,9 @@ def test_trt_package_torch_batch( ) # when + # warmup + for _ in range(5): + _ = model(torch.stack([coins_counting_image_torch, coins_counting_image_torch], dim=0)) predictions = model( torch.stack([coins_counting_image_torch, coins_counting_image_torch], dim=0) ) diff --git a/inference_models/tests/integration_tests/models/test_yolov10_object_detection_predictions_trt.py b/inference_models/tests/integration_tests/models/test_yolov10_object_detection_predictions_trt.py index 00159c653c..e35b16c3a5 100644 --- a/inference_models/tests/integration_tests/models/test_yolov10_object_detection_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_yolov10_object_detection_predictions_trt.py @@ -255,3 +255,44 @@ def test_trt_package_torch_batch( expected_xyxy.cpu(), atol=5, ) + + +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_package_torch_multiple_predictions_in_row( + yolov10_object_detection_trt_package: str, + dog_image_torch: torch.Tensor, +) -> None: + # given + from inference_models.models.yolov10.yolov10_object_detection_trt import ( + 
YOLOv10ForObjectDetectionTRT, + ) + + model = YOLOv10ForObjectDetectionTRT.from_pretrained( + model_name_or_path=yolov10_object_detection_trt_package, + engine_host_code_allowed=True, + ) + + # when + for _ in range(8): + predictions = model(dog_image_torch) + + # then + assert torch.allclose( + predictions[0].confidence.cpu(), + torch.tensor([0.5039]).cpu(), + atol=0.01, + ) + assert torch.allclose( + predictions[0].class_id.cpu(), + torch.tensor([16], dtype=torch.int32).cpu(), + ) + expected_xyxy = torch.tensor( + [[71, 253, 646, 970]], + dtype=torch.int32, + ) + assert torch.allclose( + predictions[0].xyxy.cpu(), + expected_xyxy.cpu(), + atol=5, + ) diff --git a/inference_models/tests/integration_tests/models/test_yolov8_instance_segmentation_predictions_trt.py b/inference_models/tests/integration_tests/models/test_yolov8_instance_segmentation_predictions_trt.py index 39a27c75df..01c6bd6ee7 100644 --- a/inference_models/tests/integration_tests/models/test_yolov8_instance_segmentation_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_yolov8_instance_segmentation_predictions_trt.py @@ -145,6 +145,48 @@ def test_trt_package_torch( assert 16100 <= predictions[0].mask.cpu().sum().item() <= 16200 +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_package_torch_multiple_predictions_in_row( + yolov8_seg_asl_trt_package: str, + asl_image_torch: torch.Tensor, +) -> None: + # given + from inference_models.models.yolov8.yolov8_instance_segmentation_trt import ( + YOLOv8ForInstanceSegmentationTRT, + ) + + model = YOLOv8ForInstanceSegmentationTRT.from_pretrained( + model_name_or_path=yolov8_seg_asl_trt_package, + engine_host_code_allowed=True, + ) + + # when + for _ in range(8): + predictions = model(asl_image_torch) + + # then + assert torch.allclose( + predictions[0].confidence.cpu(), + torch.tensor([0.9795]).cpu(), + atol=0.01, + ) + assert torch.allclose( + predictions[0].class_id.cpu(), + torch.tensor([20], dtype=torch.int32).cpu(), + ) 
+ expected_xyxy = torch.tensor( + [[63, 174, 187, 368]], + dtype=torch.int32, + ) + assert torch.allclose( + predictions[0].xyxy.cpu(), + expected_xyxy.cpu(), + atol=5, + ) + assert 16100 <= predictions[0].mask.cpu().sum().item() <= 16200 + + @pytest.mark.slow @pytest.mark.trt_extras def test_trt_package_torch_list( diff --git a/inference_models/tests/integration_tests/models/test_yolov8_keypoints_detection_predictions_trt.py b/inference_models/tests/integration_tests/models/test_yolov8_keypoints_detection_predictions_trt.py index 03f6e40db0..a6e60b8bd1 100644 --- a/inference_models/tests/integration_tests/models/test_yolov8_keypoints_detection_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_yolov8_keypoints_detection_predictions_trt.py @@ -144,6 +144,48 @@ def test_trt_package_torch( assert abs(predictions[0][0].confidence.sum().item() - 26.07147979736328) < 1e-2 +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_package_torch_multiple_predictions_in_row( + yolov8_pose_trt_package: str, + people_walking_image_torch: torch.Tensor, +) -> None: + # given + from inference_models.models.yolov8.yolov8_key_points_detection_trt import ( + YOLOv8ForKeyPointsDetectionTRT, + ) + + model = YOLOv8ForKeyPointsDetectionTRT.from_pretrained( + model_name_or_path=yolov8_pose_trt_package, + engine_host_code_allowed=True, + ) + + for _ in range(8): + # when + predictions = model(people_walking_image_torch) + + # then + assert torch.allclose( + predictions[1][0].confidence.cpu(), + torch.tensor([0.8783, 0.8719]).cpu(), + atol=0.01, + ) + assert torch.allclose( + predictions[1][0].class_id.cpu(), + torch.tensor([0, 0], dtype=torch.int32).cpu(), + ) + expected_xyxy = torch.tensor( + [[351, 124, 540, 756], [619, 120, 824, 767]], + dtype=torch.int32, + ) + assert torch.allclose( + predictions[1][0].xyxy.cpu(), + expected_xyxy.cpu(), + atol=5, + ) + assert abs(predictions[0][0].confidence.sum().item() - 26.07147979736328) < 1e-2 + + @pytest.mark.slow 
@pytest.mark.trt_extras def test_trt_package_torch_list( diff --git a/inference_models/tests/integration_tests/models/test_yolov8_object_detection_predictions_trt.py b/inference_models/tests/integration_tests/models/test_yolov8_object_detection_predictions_trt.py index 780dc63f5d..1648beac82 100644 --- a/inference_models/tests/integration_tests/models/test_yolov8_object_detection_predictions_trt.py +++ b/inference_models/tests/integration_tests/models/test_yolov8_object_detection_predictions_trt.py @@ -237,6 +237,71 @@ def test_trt_package_torch( ) +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_package_torch_multiple_predictions_in_row( + yolov8_coin_counting_trt_package: str, + coins_counting_image_torch: torch.Tensor, +) -> None: + # given + from inference_models.models.yolov8.yolov8_object_detection_trt import ( + YOLOv8ForObjectDetectionTRT, + ) + + model = YOLOv8ForObjectDetectionTRT.from_pretrained( + model_name_or_path=yolov8_coin_counting_trt_package, + engine_host_code_allowed=True, + ) + + # when + for _ in range(8): + predictions = model(coins_counting_image_torch) + + # then + assert torch.allclose( + predictions[0].confidence.cpu(), + torch.tensor( + [ + 0.9956, + 0.9727, + 0.9653, + 0.9468, + 0.9448, + 0.9390, + 0.9302, + 0.9287, + 0.9155, + 0.9019, + ] + ).cpu(), + atol=0.01, + ) + assert torch.allclose( + predictions[0].class_id.cpu(), + torch.tensor([4, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32).cpu(), + ) + expected_xyxy = torch.tensor( + [ + [1304, 614, 3024, 1918], + [1714, 2571, 1884, 2759], + [2678, 806, 2866, 974], + [1744, 2294, 1914, 2469], + [1260, 2058, 1424, 2233], + [1469, 2302, 1624, 2467], + [929, 1843, 1091, 1997], + [1514, 1880, 1718, 2089], + [1177, 2632, 1374, 2846], + [1099, 2348, 1260, 2522], + ], + dtype=torch.int32, + ) + assert torch.allclose( + predictions[0].xyxy.cpu(), + expected_xyxy.cpu(), + atol=5, + ) + + @pytest.mark.slow @pytest.mark.trt_extras def test_trt_package_torch_list( @@ -449,3 +514,140 @@ def 
test_trt_package_torch_batch( expected_xyxy.cpu(), atol=5, ) + + +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_cudagraph_cache_reuses_previously_seen_input_shapes( + yolov8n_640_t4_trt_package: str, + dog_image_numpy: np.ndarray, +) -> None: + from inference_models import AutoModel + from inference_models.models.common.trt import TRTCudaGraphCache + + device = torch.device("cuda:0") + trt_cuda_graph_cache = TRTCudaGraphCache(capacity=16) + model = AutoModel.from_pretrained( + model_id_or_path=yolov8n_640_t4_trt_package, + device=device, + trt_cuda_graph_cache=trt_cuda_graph_cache, + ) + + pre_processed_single, _ = model.pre_process(dog_image_numpy) + + seen_shapes = set() + capture_outputs = {} + test_sequence = [1, 2, 1, 4, 2, 1, 4, 3, 3] + + for batch_size in test_sequence: + batch = pre_processed_single.repeat(batch_size, 1, 1, 1) + cache_key = (tuple(batch.shape), batch.dtype, device) + + cache_size_before = trt_cuda_graph_cache.get_current_size() + + output = model.forward(batch) + + cache_size_after = trt_cuda_graph_cache.get_current_size() + + if cache_key not in seen_shapes: + assert cache_size_after == cache_size_before + 1 + seen_shapes.add(cache_key) + capture_outputs[cache_key] = output.clone() + continue + + assert cache_size_after == cache_size_before + assert torch.allclose(capture_outputs[cache_key], output, atol=1e-6) + + assert set(trt_cuda_graph_cache.list_keys()) == seen_shapes + + +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_cudagraph_output_matches_non_cudagraph_output( + yolov8n_640_t4_trt_package: str, + dog_image_numpy: np.ndarray, +) -> None: + from inference_models import AutoModel + from inference_models.models.common.trt import TRTCudaGraphCache + + device = torch.device("cuda:0") + trt_cuda_graph_cache = TRTCudaGraphCache(capacity=16) + model = AutoModel.from_pretrained( + model_id_or_path=yolov8n_640_t4_trt_package, + device=device, + trt_cuda_graph_cache=trt_cuda_graph_cache, + ) + pre_processed_single, _ = 
model.pre_process(dog_image_numpy) + + for batch_size in [1, 4]: + batch = pre_processed_single.repeat(batch_size, 1, 1, 1) + + no_graph = model.forward(batch, disable_cuda_graphs=True) + + capture_graph = model.forward(batch) + replay_graph = model.forward(batch) + + assert torch.allclose(no_graph, capture_graph, atol=1e-6) + assert torch.allclose(no_graph, replay_graph, atol=1e-6) + + +@pytest.mark.slow +@pytest.mark.trt_extras +def test_trt_cudagraph_cache_eviction( + yolov8n_640_t4_trt_package: str, + dog_image_numpy: np.ndarray, +) -> None: + from inference_models import AutoModel + from inference_models.models.common.trt import TRTCudaGraphCache + + device = torch.device("cuda:0") + trt_cuda_graph_cache = TRTCudaGraphCache(capacity=3) + model = AutoModel.from_pretrained( + model_id_or_path=yolov8n_640_t4_trt_package, + device=device, + trt_cuda_graph_cache=trt_cuda_graph_cache, + ) + + pre_processed_single, _ = model.pre_process(dog_image_numpy) + + batch_sizes = [1, 2, 3] + for bs in batch_sizes: + batch = pre_processed_single.repeat(bs, 1, 1, 1) + model.forward(batch) + + assert trt_cuda_graph_cache.get_current_size() == 3 + keys_before = list(trt_cuda_graph_cache.list_keys()) + + batch_4 = pre_processed_single.repeat(4, 1, 1, 1) + model.forward(batch_4) + + assert trt_cuda_graph_cache.get_current_size() == 3 + keys_after = trt_cuda_graph_cache.list_keys() + assert keys_before[0] not in keys_after + for key in keys_before[1:]: + assert key in keys_after + key_4 = (tuple(batch_4.shape), batch_4.dtype, device) + assert key_4 in trt_cuda_graph_cache + + batch_2 = pre_processed_single.repeat(2, 1, 1, 1) + model.forward(batch_2) + + batch_5 = pre_processed_single.repeat(5, 1, 1, 1) + model.forward(batch_5) + + assert trt_cuda_graph_cache.get_current_size() == 3 + key_3 = ( + tuple(pre_processed_single.repeat(3, 1, 1, 1).shape), + batch_2.dtype, + device, + ) + remaining_keys = trt_cuda_graph_cache.list_keys() + assert key_3 not in remaining_keys + + key_2 = 
(tuple(batch_2.shape), batch_2.dtype, device) + key_5 = (tuple(batch_5.shape), batch_5.dtype, device) + assert remaining_keys == [key_4, key_2, key_5] + + no_graph = model.forward(batch_5, disable_cuda_graphs=True) + replay = model.forward(batch_5) + assert torch.allclose(no_graph, replay, atol=1e-6) diff --git a/inference_models/uv.lock b/inference_models/uv.lock index f539a595de..4708931782 100644 --- a/inference_models/uv.lock +++ b/inference_models/uv.lock @@ -916,7 +916,7 @@ wheels = [ [[package]] name = "inference-models" -version = "0.20.2" +version = "0.21.0" source = { virtual = "." } dependencies = [ { name = "accelerate" }, diff --git a/requirements/_requirements.txt b/requirements/_requirements.txt index 6c47cf8365..18df5cb1cf 100644 --- a/requirements/_requirements.txt +++ b/requirements/_requirements.txt @@ -50,4 +50,4 @@ filelock>=3.12.0,<=3.17.0 onvif-zeep-async==2.0.0 # versions > 2.0.0 will not work with Python 3.9 despite docs simple-pid~=2.0.1 qrcode~=8.0.0 -inference-models~=0.20.2 +inference-models~=0.21.0