Skip to content

Commit b2f0ab2

Browse files
committed
Fixup
1 parent dbf2a1e commit b2f0ab2

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

examples/convert_to_onnx.ipynb

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
"metadata": {},
5353
"outputs": [],
5454
"source": [
55-
"model = smp.Unet('resnet34', encoder_weights='imagenet', classes=1)\n",
55+
"model = smp.Unet(\"resnet34\", encoder_weights=\"imagenet\", classes=1)\n",
5656
"model = model.eval()"
5757
]
5858
},
@@ -70,23 +70,23 @@
7070
"outputs": [],
7171
"source": [
7272
"# dynamic_axes is used to specify the variable length axes. it can be just batch size\n",
73-
"dynamic_axes = {0: 'batch_size', 2: \"height\", 3: \"width\"}\n",
73+
"dynamic_axes = {0: \"batch_size\", 2: \"height\", 3: \"width\"}\n",
7474
"\n",
75-
"onnx_model_name = 'unet_resnet34.onnx'\n",
75+
"onnx_model_name = \"unet_resnet34.onnx\"\n",
7676
"\n",
7777
"onnx_model = torch.onnx.export(\n",
78-
" model, # model being run\n",
79-
" torch.randn(1, 3, 224, 224), # model input\n",
80-
" onnx_model_name, # where to save the model (can be a file or file-like object) \n",
81-
" export_params=True, # store the trained parameter weights inside the model file\n",
82-
" opset_version=17, # the ONNX version to export\n",
83-
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
84-
" input_names=['input'], # the model's input names\n",
85-
" output_names=['output'], # the model's output names\n",
86-
" dynamic_axes={ # variable length axes\n",
87-
" 'input': dynamic_axes,\n",
88-
" 'output': dynamic_axes\n",
89-
" }\n",
78+
" model, # model being run\n",
79+
" torch.randn(1, 3, 224, 224), # model input\n",
80+
" onnx_model_name, # where to save the model (can be a file or file-like object)\n",
81+
" export_params=True, # store the trained parameter weights inside the model file\n",
82+
" opset_version=17, # the ONNX version to export\n",
83+
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
84+
" input_names=[\"input\"], # the model's input names\n",
85+
" output_names=[\"output\"], # the model's output names\n",
86+
" dynamic_axes={ # variable length axes\n",
87+
" \"input\": dynamic_axes,\n",
88+
" \"output\": dynamic_axes,\n",
89+
" },\n",
9090
")"
9191
]
9292
},
@@ -153,11 +153,13 @@
153153
}
154154
],
155155
"source": [
156-
"# create sample with different batch size, height and width \n",
156+
"# create sample with different batch size, height and width\n",
157157
"# from what we used in export above\n",
158-
"sample = torch.randn(2, 3, 512, 512) \n",
158+
"sample = torch.randn(2, 3, 512, 512)\n",
159159
"\n",
160-
"ort_session = onnxruntime.InferenceSession(onnx_model_name, providers=[\"CPUExecutionProvider\"])\n",
160+
"ort_session = onnxruntime.InferenceSession(\n",
161+
" onnx_model_name, providers=[\"CPUExecutionProvider\"]\n",
162+
")\n",
161163
"\n",
162164
"# compute ONNX Runtime output prediction\n",
163165
"ort_inputs = {\"input\": sample.numpy()}\n",

0 commit comments

Comments (0)