This is a handwritten TensorRT implementation of the Vision Transformer (ViT) paper ([arXiv:2010.11929](https://arxiv.org/abs/2010.11929)).
Note:
- The GeLU activation layer is natively supported since TensorRT SDK 10.0; on older SDKs we use the same approximation TensorRT does, see below for details.
- Supports TensorRT SDK 8.5.1 ~ 10.15.1
- Supports Windows 11
- Supports native or self-implemented GeLU
- Supports native or self-implemented multi-head self-attention
- Supports a dummy profiler by default
- Supports a dummy output allocator by default
- Uses an optimization profile by default
- Cannot use `IAttention` with TensorRT SDK 10.14 ~ 10.15 because of bugs in TensorRT
- TensorRT < 8 is not supported because some ops are not implemented in cuDNN
- SM < 86, TensorRT < 10, and CUDA < 12 configurations are NOT fully tested yet
- Use `gen_wts.py` to generate the `.wts` file:
```shell
python gen_wts.py
```
- Build the C++ code:
```shell
pushd tensorrtx/vit
cmake -S . -B build -G Ninja --fresh
cmake --build build
```
- Serialize the `.wts` model to an engine file:
```shell
./build/vit -s
```
- Run inference:
```shell
./build/vit -d
```
On an RTX 4080 with the TensorRT 10.15.1 SDK, the output looks like:
```
...
====
1880us
-1.125, 0.4623, -0.1215, -0.007384, -0.004307, -0.7021, -0.748, 0.2031, -0.4862, -0.008939, -1.151, -0.408, -0.3259, 0.2202, 0.04537, -2.008, -0.2832, 0.04394, 0.5326, 0.1724, 0.5655,
====
prediction result:
Top: 0 idx: 285, logits: 8.262, label: Egyptian cat
Top: 1 idx: 281, logits: 7.872, label: tabby, tabby cat
Top: 2 idx: 282, logits: 6.477, label: tiger cat
========== VisionTransformerProfiler ==========
TensorRT layer name Runtime, % Invocations Runtime, ms
Reformatting CopyNode for Input Tensor 0 to patch embedding 3.2% 20 0.95
patch embedding 1.5% 20 0.45
Reformatting CopyNode for Input Tensor 0 to {ForeignNode[(Unnamed Layer* 3) [Constant]...(Unnamed Layer* 518) [ElementWise]]} 0.2% 20 0.06
__myl_ReshTran_myl3_0 0.8% 20 0.24
__myl_ConcAddCastMeanSubMulMeanAddSqrtDivMulCastMulAdd_myl3_1 0.3% 20 0.08
vit.encoder.layer.0.attentionvalue+vit.encoder.layer.0.attentionkey+vit.encoder.layer.0.attentionquery_myl3_2 1.4% 20 0.40
__myl_TranReshMove_myl3_3 0.2% 20 0.06
__myl_TranReshMove_myl3_4 0.2% 20 0.07
__myl_TranReshMove_myl3_5 0.2% 20 0.06
_gemm_mha_v2_myl3_6 0.5% 20 0.14
__myl_MoveReshTran_myl3_7 0.2% 20 0.06
...
========== VisionTransformerProfiler total runtime = 29.67 ms ==========
```
As shown above, we successfully triggered the internal MHA kernel-fusion pass inside TensorRT (i.e., "Myelin", or "myl" for short), in particular the fused MHA kernel `_gemm_mha_v2_myl3_6`.
ViTLayer() builds one ViT encoder block (Transformer encoder layer) using TensorRT primitives. The implementation corresponds to a Pre-LayerNorm Transformer layer (typical for ViT), including:
- LayerNorm before attention
- Multi-Head Self-Attention (MHSA): QKV projections → scaled dot-product attention → output projection
- Residual connection
- LayerNorm after attention
- Feed-Forward Network (FFN / MLP): dense → GeLU → dense
- Residual connection
The function returns the final residual output tensor.
Let the input tensor (the TensorRT input) be:
$$ \mathbf{X} \in \mathbb{R}^{N \times L \times D} $$
Where:
- $N$: batch size (represented by `N` in the code)
- $L$: sequence length (number of tokens; dynamic in the code via `-1`)
- $D$: hidden size, fixed at 768 in this implementation
The attention head configuration:
$$ D = H \cdot d $$
where $H$ is the number of heads and $d$ is the per-head size (for ViT-Base, $H = 12$ and $d = 64$).
For a standard Transformer block:
- Q/K/V projection weights: $$ \mathbf{W}_Q, \mathbf{W}_K, \mathbf{W}_V \in \mathbb{R}^{D \times D} $$
- Q/K/V biases (NOTE: not used by the native NVIDIA interface): $$ \mathbf{b}_Q, \mathbf{b}_K, \mathbf{b}_V \in \mathbb{R}^{D} $$
- Output projection: $$ \mathbf{W}_O \in \mathbb{R}^{D \times D}, \quad \mathbf{b}_O \in \mathbb{R}^{D} $$
- FFN (MLP) with expansion ratio 4:
$$
\mathbf{W}_1 \in \mathbb{R}^{D \times 4D}, \ \mathbf{b}_1 \in \mathbb{R}^{4D}
$$
$$
\mathbf{W}_2 \in \mathbb{R}^{4D \times D}, \ \mathbf{b}_2 \in \mathbb{R}^{D}
$$
Here $4D = 3072$.
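To make these shapes concrete, here is a quick parameter count for one encoder layer at $D = 768$ (a sketch; the variable names are illustrative, not from the repo):

```python
# Per-layer parameter count for a ViT-Base encoder block (D = 768, 4D = 3072).
D, D_FFN = 768, 3072

# Q, K, V and output projections: four D x D weight matrices plus biases.
attn_params = 4 * (D * D + D)

# FFN: expand D -> 4D, then project 4D -> D, each with a bias.
ffn_params = (D * D_FFN + D_FFN) + (D_FFN * D + D)

# Two LayerNorms, each with a scale (gamma) and shift (beta) of size D.
ln_params = 2 * (D + D)

total = attn_params + ffn_params + ln_params
print(total)  # 7087872 parameters per encoder layer
```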
The Pre-LN Transformer encoder layer implements the following canonical computation:
$$ \mathbf{Z} = \mathbf{X} + \mathrm{MHSA}\big(\mathrm{LN}_{1}(\mathbf{X})\big) $$
$$ \mathbf{Y} = \mathbf{Z} + \mathrm{FFN}\big(\mathrm{LN}_{2}(\mathbf{Z})\big) $$
The function returns $\mathbf{Y}$.
LayerNorm is applied over the last dimension $D$ (hidden size), independently for each $(n, \ell)$ position.
For a token vector $\mathbf{x} \in \mathbb{R}^{D}$:
$$ \mathrm{LN}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^{2} + \varepsilon}} + \beta, \quad \mu = \frac{1}{D}\sum_{i=1}^{D} x_i, \quad \sigma^{2} = \frac{1}{D}\sum_{i=1}^{D} (x_i - \mu)^{2} $$
Where:
- $\gamma$ corresponds to `.weight`
- $\beta$ corresponds to `.bias`
- $\varepsilon = \texttt{param.lnorm\_eps}$
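A pure-Python sketch of this per-token LayerNorm (illustrative only; in the engine the work is done by TensorRT's normalization layer, and `eps` stands in for `param.lnorm_eps`):

```python
import math

def layer_norm(x, gamma, beta, eps=1e-6):
    """LayerNorm over one token vector x of length D."""
    mu = sum(x) / len(x)
    var = sum((xi - mu) ** 2 for xi in x) / len(x)  # biased variance, as in LN
    return [g * (xi - mu) / math.sqrt(var + eps) + b
            for xi, g, b in zip(x, gamma, beta)]

y = layer_norm([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.0] * 4)
# y is approximately [-1.3416, -0.4472, 0.4472, 1.3416]: zero mean, unit variance
```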
Let:
$$ \mathbf{X}' = \mathrm{LN}_{1}(\mathbf{X}) \in \mathbb{R}^{N \times L \times D} $$
Compute:
$$ \mathbf{Q} = \mathbf{X}'\mathbf{W}_Q + \mathbf{b}_Q, \quad \mathbf{K} = \mathbf{X}'\mathbf{W}_K + \mathbf{b}_K, \quad \mathbf{V} = \mathbf{X}'\mathbf{W}_V + \mathbf{b}_V $$
Multi-head attention splits the hidden dimension $D$ into $H$ heads of size $d$.
Starting from:
$$ \mathbf{Q} \in \mathbb{R}^{N \times L \times D} $$
Reshape:
$$ \mathbf{Q} \rightarrow \mathbb{R}^{N \times L \times H \times d} $$
Transpose (swap axes to put heads first):
$$ \mathbf{Q} \rightarrow \mathbb{R}^{N \times H \times L \times d} $$
Same for $\mathbf{K}$ and $\mathbf{V}$.
Code:
```cpp
q_s->setReshapeDimensions(Dims4{N, -1, H, d});
q_s->setSecondTranspose({0, 2, 1, 3}); // (N,H,L,d)
```
For each batch $n$ and head $h$, define:
$$ \mathbf{S}^{(n,h)} = \mathbf{Q}^{(n,h)} \left(\mathbf{K}^{(n,h)}\right)^{\top} \in \mathbb{R}^{L \times L} $$
In tensor form:
$$ \mathbf{S} \in \mathbb{R}^{N \times H \times L \times L} $$
Code:
```cpp
qk = MatMul(q_s, NONE, k_s, TRANSPOSE); // (N,H,L,d) x (N,H,d,L) -> (N,H,L,L)
```
Scaled dot-product uses:
$$ \mathbf{A} = \frac{\mathbf{S}}{\sqrt{d}} $$
Code:
```cpp
scale_val = 1 / sqrt(d);
attn_qk = qk * scale; // ElementWise PROD
```
Softmax is applied on the last dimension (the keys index), for each query position. So:
$$ \mathbf{P}_{n,h,i,j} = \frac{\exp(\mathbf{A}_{n,h,i,j})}{\sum_{j'=1}^{L} \exp(\mathbf{A}_{n,h,i,j'})} $$
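The same row-wise softmax in plain Python (a sketch of the math, not of the TensorRT kernel; subtracting the row max first is the usual numerical-stability trick):

```python
import math

def softmax(row):
    """Numerically stable softmax over the last (keys) axis."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

weights = softmax([1.0, 2.0, 3.0])
# weights sum to 1.0, and larger logits get larger attention weights
```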
Code:
qk_softmax = SoftMax(attn_qk);
qk_softmax->setAxes(1U << (nbDims-1)); // last axisEach head output:
Thus:
$$ \mathbf{O} \in \mathbb{R}^{N \times H \times L \times d} $$
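End to end, the per-head computation can be sketched in pure Python with toy sizes (illustrative only, not the TensorRT graph):

```python
import math

def matmul(A, B):
    """Naive (rows x cols) matrix multiply on nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def attention_head(Q, K, V):
    """Scaled dot-product attention for one (L, d) head."""
    d = len(Q[0])
    KT = [list(col) for col in zip(*K)]
    S = matmul(Q, KT)                                   # (L, L) logits
    A = [[s / math.sqrt(d) for s in row] for row in S]  # scale by 1/sqrt(d)
    P = []
    for row in A:                                       # softmax per query row
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        z = sum(exps)
        P.append([e / z for e in exps])
    return matmul(P, V)                                 # (L, d) head output

O = attention_head(Q=[[1.0, 0.0], [0.0, 1.0]],
                   K=[[1.0, 0.0], [0.0, 1.0]],
                   V=[[1.0, 2.0], [3.0, 4.0]])
# O has shape (L, d) = (2, 2); each output row is a convex combination of V rows
```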
Code:
```cpp
attn_qkv = MatMul(qk_softmax, NONE, v_s, NONE); // (N,H,L,L) x (N,H,L,d) -> (N,H,L,d)
```
Transpose back:
$$ \mathbf{O} \rightarrow \mathbb{R}^{N \times L \times H \times d} $$
Then reshape (merge heads):
$$ \mathbf{O} \rightarrow \mathbb{R}^{N \times L \times D} $$
Code:
attn_out->setFirstTranspose({0, 2, 1, 3}); // (N,L,H,d)
attn_out->setReshapeDimensions(Dims3{N, -1, 768}); // (N,L,D)Code:
attn_fcw = MatMul(attn_out, out_proj_w^T);
attn_fcb = attn_fcw + out_proj_b;Code:
attn_residual = input + attn_fcb;This identity path is crucial for gradient flow and stability; at inference time it preserves a “direct” signal path even if attention becomes sharp or noisy.
Code:
```cpp
post_lnorm = Normalization(attn_residual, post_ln_scale, post_ln_bias);
```
ViT uses a 2-layer MLP with expansion ratio 4 and GeLU activation.
Code:
```cpp
inter0 = MatMul(post_lnorm, iw^T); // iw shape conceptually (4D, D)
inter1 = inter0 + ib;
```
GeLU is defined as:
$$ \mathrm{GeLU}(x) = x \, \Phi(x) $$
Where $\Phi$ is the standard normal CDF.
A common tanh approximation (widely used in implementations):
$$ \mathrm{GeLU}(x) \approx \frac{x}{2} \left( 1 + \tanh\!\left[ \sqrt{\tfrac{2}{\pi}} \left( x + 0.044715\, x^{3} \right) \right] \right) $$
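Both forms side by side in plain Python, using `math.erf` for the exact CDF (an illustrative sketch, not the TensorRT layer graph):

```python
import math

def gelu_exact(x):
    """GeLU(x) = x * Phi(x), with Phi the standard normal CDF."""
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    """Widely used tanh approximation of GeLU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi)
                                      * (x + 0.044715 * x ** 3)))

# The two agree closely over typical activation ranges, e.g. both about 0.841 at x = 1:
print(gelu_exact(1.0), gelu_tanh(1.0))
```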
The code calls:
```cpp
inter_act = addGeLU(net, inter1);
```
Code:
```cpp
out0 = MatMul(inter_act, ow^T); // ow conceptually (D, 4D)
out1 = out0 + ob;
```
The final residual:
```cpp
output_residual = out1 + attn_residual;
return output_residual;
```
Below is a shape trace aligned with the main operations (assuming dynamic $L$):
- Input: $(N, L, 768)$
- Pre-LN: $(N, L, 768)$
- Q/K/V projections: $(N, L, 768)$ each
- Reshape + transpose to heads: $(N, H, L, d)$
- Attention logits: $(N, H, L, L)$
- Softmax weights: $(N, H, L, L)$
- Head outputs: $(N, H, L, d)$
- Merge heads: $(N, L, 768)$
- Output projection: $(N, L, 768)$
- Residual: $(N, L, 768)$
- Post-LN: $(N, L, 768)$
- FFN expand: $(N, L, 3072)$
- FFN project: $(N, L, 768)$
- Final residual: $(N, L, 768)$
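The trace above can be checked mechanically. A small Python sketch that propagates symbolic shapes through the layer ($H = 12$ and $d = 64$ are the ViT-Base values, an assumption of this sketch):

```python
# Symbolic shape trace through one encoder layer; "N" and "L" stay symbolic.
D, H, d, D_FFN = 768, 12, 64, 3072
assert H * d == D  # heads must tile the hidden dimension exactly

x = ("N", "L", D)             # input / pre-LN / Q, K, V projections
heads = ("N", H, "L", d)      # reshape + transpose to heads
logits = ("N", H, "L", "L")   # Q @ K^T, same shape as the softmax weights
head_out = ("N", H, "L", d)   # softmax(Q K^T / sqrt(d)) @ V
merged = ("N", "L", H * d)    # transpose back + merge heads
ffn_expand = ("N", "L", D_FFN)  # first FFN matmul
final = ("N", "L", D)         # output projection / residuals / post-LN

print(merged == x)  # True: residual additions stay shape-compatible
```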