Skip to content

Commit 7436507

Browse files
committed
Add notebook with example
1 parent 1fe4964 commit 7436507

File tree

1 file changed

+221
-0
lines changed

1 file changed

+221
-0
lines changed

examples/convert_to_onnx.ipynb

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb)"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"# to make onnx export work\n",
17+
"!pip install onnx onnxruntime"
18+
]
19+
},
20+
{
21+
"cell_type": "markdown",
22+
"metadata": {},
23+
"source": [
24+
"See complete tutorial in Pytorch docs:\n",
25+
" - https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": 1,
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"import onnx\n",
35+
"import onnxruntime\n",
36+
"import numpy as np\n",
37+
"\n",
38+
"import torch\n",
39+
"import segmentation_models_pytorch as smp"
40+
]
41+
},
42+
{
43+
"cell_type": "markdown",
44+
"metadata": {},
45+
"source": [
46+
"### Create random model (or load your own model)"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": 2,
52+
"metadata": {},
53+
"outputs": [],
54+
"source": [
55+
"model = smp.Unet('resnet34', encoder_weights='imagenet', classes=1)\n",
56+
"model = model.eval()"
57+
]
58+
},
59+
{
60+
"cell_type": "markdown",
61+
"metadata": {},
62+
"source": [
63+
"### Export the model to ONNX"
64+
]
65+
},
66+
{
67+
"cell_type": "code",
68+
"execution_count": 3,
69+
"metadata": {},
70+
"outputs": [],
71+
"source": [
72+
"# dynamic_axes is used to specify the variable length axes. it can be just batch size\n",
73+
"dynamic_axes = {0: 'batch_size', 2: \"height\", 3: \"width\"}\n",
74+
"\n",
75+
"onnx_model_name = 'unet_resnet34.onnx'\n",
76+
"\n",
77+
"onnx_model = torch.onnx.export(\n",
78+
" model, # model being run\n",
79+
" torch.randn(1, 3, 224, 224), # model input\n",
80+
" onnx_model_name, # where to save the model (can be a file or file-like object) \n",
81+
" export_params=True, # store the trained parameter weights inside the model file\n",
82+
" opset_version=17, # the ONNX version to export\n",
83+
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
84+
" input_names=['input'], # the model's input names\n",
85+
" output_names=['output'], # the model's output names\n",
86+
" dynamic_axes={ # variable length axes\n",
87+
" 'input': dynamic_axes,\n",
88+
" 'output': dynamic_axes\n",
89+
" }\n",
90+
")"
91+
]
92+
},
93+
{
94+
"cell_type": "code",
95+
"execution_count": 4,
96+
"metadata": {},
97+
"outputs": [],
98+
"source": [
99+
"# check with onnx first\n",
100+
"onnx_model = onnx.load(onnx_model_name)\n",
101+
"onnx.checker.check_model(onnx_model)"
102+
]
103+
},
104+
{
105+
"cell_type": "markdown",
106+
"metadata": {},
107+
"source": [
108+
"### Run with onnxruntime"
109+
]
110+
},
111+
{
112+
"cell_type": "code",
113+
"execution_count": 5,
114+
"metadata": {},
115+
"outputs": [
116+
{
117+
"data": {
118+
"text/plain": [
119+
"[array([[[[-1.41701847e-01, -4.63768840e-03, 1.21411584e-01, ...,\n",
120+
" 5.22197843e-01, 3.40217263e-01, 8.52423906e-02],\n",
121+
" [-2.29843616e-01, 2.19401851e-01, 3.53053480e-01, ...,\n",
122+
" 2.79466838e-01, 3.20288718e-01, -2.22393833e-02],\n",
123+
" [-3.12503517e-01, -3.66358161e-02, 1.19251609e-02, ...,\n",
124+
" -5.48991561e-02, 3.71140465e-02, -1.82842150e-01],\n",
125+
" ...,\n",
126+
" [-3.02772015e-01, -4.22928065e-01, -1.49621412e-01, ...,\n",
127+
" -1.42241001e-01, -9.90390778e-02, -1.33311331e-01],\n",
128+
" [-1.08293816e-01, -1.28070369e-01, -5.43620177e-02, ...,\n",
129+
" -8.64556879e-02, -1.74177170e-01, 6.03154302e-03],\n",
130+
" [-1.29619062e-01, -2.96604559e-02, -2.86361389e-03, ...,\n",
131+
" -1.91345289e-01, -1.82653710e-01, 1.17175849e-02]]],\n",
132+
" \n",
133+
" \n",
134+
" [[[-6.16237633e-02, 1.12350248e-01, 1.59193069e-01, ...,\n",
135+
" 4.03313845e-01, 2.26862252e-01, 7.33022243e-02],\n",
136+
" [-1.60109222e-01, 1.21696621e-01, 1.84655115e-01, ...,\n",
137+
" 1.20978586e-01, 2.45723248e-01, 1.00066036e-01],\n",
138+
" [-2.11992145e-01, 1.71708465e-02, -1.57656223e-02, ...,\n",
139+
" -1.11918494e-01, -1.64519548e-01, -1.73958957e-01],\n",
140+
" ...,\n",
141+
" [-2.79706120e-01, -2.87421644e-01, -5.19880295e-01, ...,\n",
142+
" -8.30744207e-02, -3.48939300e-02, 1.26617640e-01],\n",
143+
" [-2.62198627e-01, -2.91804910e-01, -2.82318443e-01, ...,\n",
144+
" 1.81179233e-02, 2.32534595e-02, 1.85002953e-01],\n",
145+
" [-9.28771719e-02, -5.16399741e-05, -9.53909755e-03, ...,\n",
146+
" -2.28582099e-02, -5.09671569e-02, 2.05268264e-02]]]],\n",
147+
" dtype=float32)]"
148+
]
149+
},
150+
"execution_count": 5,
151+
"metadata": {},
152+
"output_type": "execute_result"
153+
}
154+
],
155+
"source": [
156+
"# create sample with different batch size, height and width \n",
157+
"# from what we used in export above\n",
158+
"sample = torch.randn(2, 3, 512, 512) \n",
159+
"\n",
160+
"ort_session = onnxruntime.InferenceSession(onnx_model_name, providers=[\"CPUExecutionProvider\"])\n",
161+
"\n",
162+
"# compute ONNX Runtime output prediction\n",
163+
"ort_inputs = {\"input\": sample.numpy()}\n",
164+
"ort_outputs = ort_session.run(output_names=None, input_feed=ort_inputs)\n",
165+
"ort_outputs"
166+
]
167+
},
168+
{
169+
"cell_type": "markdown",
170+
"metadata": {},
171+
"source": [
172+
"### Verify it's the same as for pytorch model"
173+
]
174+
},
175+
{
176+
"cell_type": "code",
177+
"execution_count": 6,
178+
"metadata": {},
179+
"outputs": [
180+
{
181+
"name": "stdout",
182+
"output_type": "stream",
183+
"text": [
184+
"Exported model has been tested with ONNXRuntime, and the result looks good!\n"
185+
]
186+
}
187+
],
188+
"source": [
189+
"# compute PyTorch output prediction\n",
190+
"with torch.no_grad():\n",
191+
" torch_out = model(sample)\n",
192+
"\n",
193+
"# compare ONNX Runtime and PyTorch results\n",
194+
"np.testing.assert_allclose(torch_out.numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)\n",
195+
"\n",
196+
"print(\"Exported model has been tested with ONNXRuntime, and the result looks good!\")"
197+
]
198+
}
199+
],
200+
"metadata": {
201+
"kernelspec": {
202+
"display_name": ".venv",
203+
"language": "python",
204+
"name": "python3"
205+
},
206+
"language_info": {
207+
"codemirror_mode": {
208+
"name": "ipython",
209+
"version": 3
210+
},
211+
"file_extension": ".py",
212+
"mimetype": "text/x-python",
213+
"name": "python",
214+
"nbconvert_exporter": "python",
215+
"pygments_lexer": "ipython3",
216+
"version": "3.10.12"
217+
}
218+
},
219+
"nbformat": 4,
220+
"nbformat_minor": 2
221+
}

0 commit comments

Comments
 (0)