Commit e4a6730

Author: David Pichler

feat: add support for compiling onnx models with trtexec

1 parent 7d25e54

10 files changed: +1844 -2 lines

examples/trt_compile/.gitignore

Lines changed: 2 additions & 0 deletions

```
*.onnx
model-repo
```

examples/trt_compile/README.md

Lines changed: 96 additions & 0 deletions

# TensorRT Compilation — ResNet-18 Image Classification

This example compiles a pretrained ResNet-18 ONNX model to a TensorRT engine during the build phase using `trt_compile`, then serves it on Triton Inference Server.

During build, `tsbk` will:

1. Download the ONNX model artifact from MLflow
2. Compile it to a `.plan` file using `trtexec` with fp16 precision (via Docker or a Kubernetes Job)
3. Set the backend to `tensorrt` in the generated `config.pbtxt`
4. Cache the compiled engine locally so subsequent builds skip compilation

## Prerequisites

- Install example requirements:

```bash
pip install -r requirements.txt
```

- **Docker with GPU access** (for local compilation), or
- **Kubernetes cluster with GPU nodes** plus the `TSBK_S3_PREFIX` env var set (for remote compilation)

## Setup

Export a pretrained ResNet-18 to ONNX and register it with MLflow:

```bash
python create-model.py
```

This exports `resnet18.onnx` with:

- Input `image`: `[batch, 3, 224, 224]` float32 (ImageNet-normalized RGB)
- Output `logits`: `[batch, 1000]` float32 (class scores)
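The ImageNet normalization is applied on the client side before sending the tensor. A minimal numpy sketch of building a valid `image` input from a 224x224 RGB frame (the `preprocess` helper is illustrative, not part of this example's code):

```python
import numpy as np

# Standard torchvision ImageNet normalization constants
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(rgb_uint8: np.ndarray) -> np.ndarray:
    """Turn a [224, 224, 3] uint8 RGB image into the [1, 3, 224, 224]
    float32 tensor the model's `image` input expects."""
    x = rgb_uint8.astype(np.float32) / 255.0   # scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD     # per-channel normalization
    x = np.transpose(x, (2, 0, 1))             # HWC -> CHW
    return x[np.newaxis, ...]                  # add batch dimension
```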
## Build and Run (local GPU via Docker)

```bash
python server.py --test
```

This will:

- Build the model repo, compiling the ONNX model to TensorRT with fp16 precision
- Launch Triton server in a Docker container with GPU access
- Run the MLflow registered input example as a test case
- Stop the server

## Build and Run (remote GPU via Kubernetes)

If you don't have a local GPU but have access to a Kubernetes cluster with GPU nodes, pass `--gpu-name` to target a specific GPU type via Karpenter:

```bash
export TSBK_S3_PREFIX=s3://your-bucket/tsbk-cache
python server.py --test --gpu-name a10g
```

The `--gpu-name` value maps to a Karpenter node selector (`karpenter.k8s.aws/instance-gpu-name`) so the compilation job is scheduled on the correct hardware.

## Build Only

```bash
python server.py --build-only
```

After building, the model repo will look like:

```
model-repo/
└── resnet18-trt/
    └── resnet18/
        ├── config.pbtxt   # backend: "tensorrt", max_batch_size: 8
        └── 1/
            └── model.plan # compiled TensorRT engine (fp16)
```
## SDK Usage

The key addition compared to the standard SDK example is the `trt_compile` dict on the model version:

```python
tsbk.TritonModel(
    max_batch_size=8,
    versions=[
        tsbk.TritonModelVersion(
            artifact_uri="models:/resnet18-imagenet/1",
            trt_compile={
                "enabled": True,
                "precision": "fp16",         # optional: fp16, int8, best
                "workspace_size": 4096,      # optional: max workspace in MB
                "gpu_name": "a10g",          # optional: Karpenter GPU node selector for K8s
                "trt_image": "nvcr.io/...",  # optional: override TRT container image
                "extra_args": "--verbose",   # optional: raw trtexec flags
            },
        )
    ],
)
```
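For orientation, these options translate roughly into a `trtexec` command line. A hedged sketch of that mapping (the real command is assembled inside `tsbk.utils.trtexec.build_trt_engine`, which this diff does not show, and exact flag spellings vary between TensorRT versions):

```python
def trtexec_args(onnx_path: str, plan_path: str, trt_compile: dict) -> list[str]:
    """Illustrative reconstruction of the trtexec invocation implied by a
    trt_compile config. Not tsbk's actual implementation."""
    args = [f"--onnx={onnx_path}", f"--saveEngine={plan_path}"]
    precision = trt_compile.get("precision")
    if precision in ("fp16", "int8", "best"):
        args.append(f"--{precision}")  # trtexec precision flags
    if trt_compile.get("workspace_size"):
        # Older trtexec releases spell this --workspace=<MB>; newer ones
        # use --memPoolSize=workspace:<N>M
        args.append(f"--memPoolSize=workspace:{trt_compile['workspace_size']}M")
    if trt_compile.get("extra_args"):
        args.extend(trt_compile["extra_args"].split())
    return ["trtexec", *args]
```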
examples/trt_compile/create-model.py

Lines changed: 39 additions & 0 deletions

```python
import mlflow
import numpy as np
import onnx
import torch
import torchvision.models as models

# Load a pretrained ResNet-18 for image classification
resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
resnet.eval()

# Export to ONNX with a batch dimension and standard ImageNet input shape
# Input: [batch, 3, 224, 224] RGB image normalized to ImageNet stats
# Output: [batch, 1000] class logits
model_path = "resnet18.onnx"
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    resnet,
    dummy_input,
    model_path,
    input_names=["image"],
    output_names=["logits"],
    dynamic_axes={
        "image": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
    opset_version=17,
)

onnx_model = onnx.load(model_path)

# Log model to MLflow with a sample input (a random "image" tensor)
with mlflow.start_run() as run:
    mlflow.onnx.log_model(
        onnx_model,
        artifact_path="resnet18",
        registered_model_name="resnet18-imagenet",
        input_example={"image": np.random.randn(1, 3, 224, 224).astype(np.float32)},
    )
print(f"Model registered: models:/resnet18-imagenet/1 (run_id={run.info.run_id})")
```
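On the serving side, the raw `logits` output can be turned into a class prediction client-side. A small numpy sketch (the `top1` helper is illustrative, not part of this example's code):

```python
import numpy as np

def top1(logits: np.ndarray) -> tuple[int, float]:
    """Convert a [batch, 1000] logits array into the top-1 class id and
    softmax probability for the first batch element."""
    z = logits[0] - logits[0].max()        # subtract max for numerical stability
    probs = np.exp(z) / np.exp(z).sum()    # softmax over the 1000 classes
    idx = int(probs.argmax())
    return idx, float(probs[idx])
```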
examples/trt_compile/requirements.txt

Lines changed: 5 additions & 0 deletions

```
tsbk
onnx
torch
torchvision
mlflow-skinny
```

examples/trt_compile/server.py

Lines changed: 84 additions & 0 deletions

```python
import argparse

import tsbk


def model_repo(model_repo_path: str, artifact_uri: str, gpu_name: str | None = None) -> tsbk.TritonModelRepo:
    """Build a model repo that compiles a ResNet-18 ONNX model to TensorRT.

    The trt_compile config on the version tells tsbk to:
    1. Download the ONNX model artifact
    2. Compile it to a TensorRT .plan file using trtexec (via Docker or K8s)
    3. Replace the .onnx with the .plan and set backend to tensorrt

    The compiled engine is cached under TSBK_DIR/trt_engines/ — subsequent
    builds with the same ONNX model and compile params skip compilation.

    When gpu_name is set, Kubernetes compilation uses it as a Karpenter node
    selector (karpenter.k8s.aws/instance-gpu-name) to schedule on the right
    GPU hardware. Requires TSBK_S3_PREFIX to be set for artifact transfer.
    """
    trt_compile = {
        "enabled": True,
        "precision": "fp16",
    }
    if gpu_name:
        trt_compile["gpu_name"] = gpu_name

    return tsbk.TritonModelRepo(
        "resnet18-trt",
        path=model_repo_path,
        models={
            "resnet18": tsbk.TritonModel(
                max_batch_size=8,
                versions=[
                    tsbk.TritonModelVersion(
                        artifact_uri=artifact_uri,
                        trt_compile=trt_compile,
                    )
                ],
            )
        },
    )


def main(args):
    repo = model_repo(args.model_repo, args.model_artifact_uri, gpu_name=args.gpu_name)
    repo.build()

    if args.build_only:
        return

    repo.run(detach=args.test, gpus=True)

    if args.test:
        repo.test(url=repo.http_url)
        repo.stop()
        print("Tests passed!")
        return


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Build and serve a TensorRT-compiled ResNet-18 with tsbk")
    parser.add_argument(
        "--model_artifact_uri",
        type=str,
        default="models:/resnet18-imagenet/1",
        help="MLflow model URI for the ONNX ResNet-18",
    )
    parser.add_argument("--model-repo", type=str, default="./model-repo", help="Path to the model repository")
    parser.add_argument(
        "--build-only", action="store_true", help="Only build the model repository without starting the server"
    )
    parser.add_argument("--test", action="store_true", help="Run in test mode")
    parser.add_argument(
        "--gpu-name",
        type=str,
        default=None,
        help="Target GPU for compilation, used as Karpenter node selector (e.g. a10g, t4)",
    )
    args = parser.parse_args()

    assert not (args.build_only and args.test), "Cannot use --build-only and --test together"

    main(args)
```
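The docstring above notes that engines are cached by ONNX model and compile params. The actual cache-key scheme is not shown in this diff; a hypothetical sketch of the idea, hashing the model bytes together with the compile config:

```python
import hashlib
import json

def engine_cache_key(onnx_bytes: bytes, trt_compile: dict) -> str:
    """Hypothetical cache key for a compiled TensorRT engine: identical
    model bytes + identical compile params -> identical key, so a rebuild
    can reuse the cached .plan. Illustrative only, not tsbk's scheme."""
    h = hashlib.sha256()
    h.update(onnx_bytes)  # the ONNX model contents
    h.update(json.dumps(trt_compile, sort_keys=True).encode())  # compile params
    return h.hexdigest()
```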

src/tsbk/model.py

Lines changed: 13 additions & 0 deletions

```diff
@@ -206,6 +206,19 @@ def init(
                 "Please specify a model backend via triton config or by passing the backend directly."
             )
 
+        # Override backend to tensorrt if any version has trt_compile enabled
+        has_trt_compile = any(
+            mv.trt_compile and mv.trt_compile.get("enabled")
+            for mv in self.versions
+            if hasattr(mv, "trt_compile") and mv.trt_compile
+        )
+        if has_trt_compile:
+            if self.backend not in ["tensorrt", "onnxruntime"]:
+                raise ValueError(
+                    "Must specify tensorrt or onnxruntime backend if trt_compile is enabled for any model version."
+                )
+            self.backend = "tensorrt"
+
         if self.python_version and self.backend not in {"python", "mlflow"}:
             raise ValueError("python_version can only be specified for python models")
```

src/tsbk/model_version.py

Lines changed: 30 additions & 2 deletions

```diff
@@ -10,6 +10,7 @@
 from tsbk.utils import link_or_copy
 from tsbk.utils.dbx import download_mlflow_model, get_input_example_from_model
 from tsbk.utils.s3 import download_s3_path, s3_path_exists
+from tsbk.utils.trtexec import _find_onnx_file, build_trt_engine
 
 
 class TritonModelVersion:
@@ -19,17 +20,20 @@ def __init__(
         python_model_file: str | None = None,
         version: int | None = None,
         test_cases: list[TestCase | dict] | None = None,
+        trt_compile: dict | None = None,
     ):
         """A Triton model version.
 
         Args:
             artifact_uri: The URI of the model artifact, which can be an MLflow model or an S3 object.
             python_model_file: The path to the Python model file, which is required for Python models.
             version: The version number of the model.
+            trt_compile: Configuration for compiling ONNX models to TensorRT engines.
         """
         self.artifact_uri = artifact_uri
         self.python_model_file = python_model_file
         self.version = version
+        self.trt_compile = trt_compile
         self.test_cases = test_cases or []
         self.test_cases = [
             TestCase(**test_case) if isinstance(test_case, dict) else test_case for test_case in self.test_cases
@@ -145,8 +149,32 @@ def build(self) -> None:
                 copy_func(dst_path=output_file_path)
 
             case "tensorrt":
-                output_file_path = self.path.joinpath("model.plan").as_posix()
-                copy_func(dst_path=output_file_path)
+                if self.trt_compile and self.trt_compile.get("enabled"):
+                    # Get the ONNX model from cache (no copy into version dir)
+                    if source == "mlflow":
+                        cached_model = copy_func()
+                    else:
+                        cached_model = Path(self.artifact_uri)
+                        if not cached_model.exists():
+                            copy_func(dst_path=cached_model)
+
+                    # Find the .onnx file in the cached model and compile it
+                    onnx_path = _find_onnx_file(cached_model)
+                    plan_path = build_trt_engine(
+                        onnx_path=onnx_path,
+                        precision=self.trt_compile.get("precision"),
+                        workspace_size=self.trt_compile.get("workspace_size"),
+                        extra_args=self.trt_compile.get("extra_args"),
+                        trt_image=self.trt_compile.get("trt_image"),
+                        gpu_name=self.trt_compile.get("gpu_name"),
+                    )
+
+                    # Only place the compiled plan in the version directory
+                    link_or_copy(plan_path, self.path.joinpath("model.plan"))
+                else:
+                    # Standard tensorrt: just copy the .plan file
+                    output_file_path = self.path.joinpath("model.plan").as_posix()
+                    copy_func(dst_path=output_file_path)
 
             case _:
                 output_file_path = self.path.joinpath(self.artifact_uri.split("/")[-1])
```

src/tsbk/spec.py

Lines changed: 19 additions & 0 deletions

```diff
@@ -42,6 +42,23 @@ class TritonDTypeSpec(BaseModel):
     """The dimensions of the Triton data type, which can be used to specify the shape of the input or output tensors."""
 
 
+class TrtCompileSpec(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    enabled: bool = False
+    """Whether to compile ONNX models to TensorRT .plan files using trtexec."""
+    trt_image: str | None = None
+    """Override the TensorRT container image. Defaults to nvcr.io/nvidia/tensorrt:{triton_version}-py3."""
+    precision: str | None = None
+    """Precision mode for trtexec: 'fp16', 'int8', 'best', or None for default (fp32)."""
+    workspace_size: int | None = None
+    """Max workspace size in MB for trtexec."""
+    extra_args: str | None = None
+    """Additional raw trtexec CLI arguments as a string."""
+    gpu_name: str | None = None
+    """Target GPU architecture for compilation. Used as Karpenter node selector for K8s scheduling (e.g., 'A10G', 'T4')."""
+
+
 class TritonModelVersionSpec(BaseModel):
     model_config = ConfigDict(extra="forbid")
@@ -53,6 +70,8 @@
     """The version number of the model"""
     test_cases: list[TestCaseSpec] | None = None
     """A list of test cases for the model version, which can be used to validate the model's behavior."""
+    trt_compile: TrtCompileSpec | None = None
+    """Configuration for compiling ONNX models to TensorRT engines during build."""
 
 
 class TritonModelSpec(BaseModel):
```
