|
52 | 52 | "metadata": {}, |
53 | 53 | "outputs": [], |
54 | 54 | "source": [ |
55 | | - "model = smp.Unet('resnet34', encoder_weights='imagenet', classes=1)\n", |
| 55 | + "model = smp.Unet(\"resnet34\", encoder_weights=\"imagenet\", classes=1)\n", |
56 | 56 | "model = model.eval()" |
57 | 57 | ] |
58 | 58 | }, |
|
70 | 70 | "outputs": [], |
71 | 71 | "source": [ |
72 | 72 | "# dynamic_axes is used to specify the variable length axes. it can be just batch size\n", |
73 | | - "dynamic_axes = {0: 'batch_size', 2: \"height\", 3: \"width\"}\n", |
| 73 | + "dynamic_axes = {0: \"batch_size\", 2: \"height\", 3: \"width\"}\n", |
74 | 74 | "\n", |
75 | | - "onnx_model_name = 'unet_resnet34.onnx'\n", |
| 75 | + "onnx_model_name = \"unet_resnet34.onnx\"\n", |
76 | 76 | "\n", |
77 | 77 | "onnx_model = torch.onnx.export(\n", |
78 | | - " model, # model being run\n", |
79 | | - " torch.randn(1, 3, 224, 224), # model input\n", |
80 | | - " onnx_model_name, # where to save the model (can be a file or file-like object) \n", |
81 | | - " export_params=True, # store the trained parameter weights inside the model file\n", |
82 | | - " opset_version=17, # the ONNX version to export\n", |
83 | | - " do_constant_folding=True, # whether to execute constant folding for optimization\n", |
84 | | - " input_names=['input'], # the model's input names\n", |
85 | | - " output_names=['output'], # the model's output names\n", |
86 | | - " dynamic_axes={ # variable length axes\n", |
87 | | - " 'input': dynamic_axes,\n", |
88 | | - " 'output': dynamic_axes\n", |
89 | | - " }\n", |
| 78 | + " model, # model being run\n", |
| 79 | + " torch.randn(1, 3, 224, 224), # model input\n", |
| 80 | + " onnx_model_name, # where to save the model (can be a file or file-like object)\n", |
| 81 | + " export_params=True, # store the trained parameter weights inside the model file\n", |
| 82 | + " opset_version=17, # the ONNX version to export\n", |
| 83 | + " do_constant_folding=True, # whether to execute constant folding for optimization\n", |
| 84 | + " input_names=[\"input\"], # the model's input names\n", |
| 85 | + " output_names=[\"output\"], # the model's output names\n", |
| 86 | + " dynamic_axes={ # variable length axes\n", |
| 87 | + " \"input\": dynamic_axes,\n", |
| 88 | + " \"output\": dynamic_axes,\n", |
| 89 | + " },\n", |
90 | 90 | ")" |
91 | 91 | ] |
92 | 92 | }, |
|
153 | 153 | } |
154 | 154 | ], |
155 | 155 | "source": [ |
156 | | - "# create sample with different batch size, height and width \n", |
| 156 | + "# create sample with different batch size, height and width\n", |
157 | 157 | "# from what we used in export above\n", |
158 | | - "sample = torch.randn(2, 3, 512, 512) \n", |
| 158 | + "sample = torch.randn(2, 3, 512, 512)\n", |
159 | 159 | "\n", |
160 | | - "ort_session = onnxruntime.InferenceSession(onnx_model_name, providers=[\"CPUExecutionProvider\"])\n", |
| 160 | + "ort_session = onnxruntime.InferenceSession(\n", |
| 161 | + " onnx_model_name, providers=[\"CPUExecutionProvider\"]\n", |
| 162 | + ")\n", |
161 | 163 | "\n", |
162 | 164 | "# compute ONNX Runtime output prediction\n", |
163 | 165 | "ort_inputs = {\"input\": sample.numpy()}\n", |
|
0 commit comments