@@ -1,75 +1,79 @@
 import cv2
 import onnx
 import torch
-from albumentations import (
-    Compose,
-    Resize,
-)
+from albumentations import (Compose,Resize,)
 from albumentations.augmentations.transforms import Normalize
 from albumentations.pytorch.transforms import ToTensor
 from torchvision import models
 
-# load pre-trained model ------------------------------------------------------
-model = models.resnet50(pretrained=True)
 
-# preprocessing stage ---------------------------------------------------------
-# transformations for the input data
-transforms = Compose(
-    [
+def preprocess_image(img_path):
+    # transformations for the input data
+    transforms = Compose([
         Resize(224, 224, interpolation=cv2.INTER_NEAREST),
         Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         ToTensor(),
-    ],
-)
-
-# read input image
-input_img = cv2.imread("turkish_coffee.jpg")
-# do transformations
-input_data = transforms(image=input_img)["image"]
-# prepare batch
-batch_data = torch.unsqueeze(input_data, 0).cuda()
-
-# inference stage -------------------------------------------------------------
-model.eval()
-model.cuda()
-output_data = model(batch_data)
-
-# post-processing stage -------------------------------------------------------
-# get class names
-with open("imagenet_classes.txt") as f:
-    classes = [line.strip() for line in f.readlines()]
-# calculate human-readable value by softmax
-confidences = torch.nn.functional.softmax(output_data, dim=1)[0] * 100
-# find top predicted classes
-_, indices = torch.sort(output_data, descending=True)
-i = 0
-# print the top classes predicted by the model
-while confidences[indices[0][i]] > 0.5:
-    class_idx = indices[0][i]
-    print(
-        "class:",
-        classes[class_idx],
-        ", confidence:",
-        confidences[class_idx].item(),
-        "%, index:",
-        class_idx.item(),
-    )
-    i += 1
-
-# convert to ONNX -------------------------------------------------------------
-onnx_filename = "resnet50.onnx"
-torch.onnx.export(
-    model,
-    batch_data,
-    onnx_filename,
-    input_names=["input"],
-    output_names=["output"],
-    export_params=True,
-)
-
-onnx_model = onnx.load(onnx_filename)
-# check that the model converted fine
-onnx.checker.check_model(onnx_model)
-
-print("Model was successfully converted to ONNX format.")
-print("It was saved to", onnx_filename)
+    ])
+
+    # read input image
+    input_img = cv2.imread(img_path)
+    # do transformations
+    input_data = transforms(image=input_img)["image"]
+    # prepare batch
+    batch_data = torch.unsqueeze(input_data, 0)
+
+    return batch_data
+
+
+def postprocess(output_data):
+    # get class names
+    with open("imagenet_classes.txt") as f:
+        classes = [line.strip() for line in f.readlines()]
+    # calculate human-readable value by softmax
+    confidences = torch.nn.functional.softmax(output_data, dim=1)[0] * 100
+    # find top predicted classes
+    _, indices = torch.sort(output_data, descending=True)
+    i = 0
+    # print the top classes predicted by the model
+    while confidences[indices[0][i]] > 0.5:
+        class_idx = indices[0][i]
+        print(
+            "class:",
+            classes[class_idx],
+            ", confidence:",
+            confidences[class_idx].item(),
+            "%, index:",
+            class_idx.item(),
+        )
+        i += 1
+
+
+def main():
+    # load pre-trained model --------------------------------------------------
+    model = models.resnet50(pretrained=True)
+
+    # preprocessing stage -----------------------------------------------------
+    input = preprocess_image("turkish_coffee.jpg").cuda()
+
+    # inference stage ---------------------------------------------------------
+    model.eval()
+    model.cuda()
+    output = model(input)
+
+    # post-processing stage ---------------------------------------------------
+    postprocess(output)
+
+    # convert to ONNX ---------------------------------------------------------
+    ONNX_FILE_PATH = "resnet50.onnx"
+    torch.onnx.export(model, input, ONNX_FILE_PATH, input_names=["input"], output_names=["output"], export_params=True)
+
+    onnx_model = onnx.load(ONNX_FILE_PATH)
+    # check that the model converted fine
+    onnx.checker.check_model(onnx_model)
+
+    print("Model was successfully converted to ONNX format.")
+    print("It was saved to", ONNX_FILE_PATH)
+
+
+if __name__ == '__main__':
+    main()
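
Beyond onnx.checker.check_model, a quick way to confirm the exported resnet50.onnx is numerically faithful is to run the same batch through ONNX Runtime and through the PyTorch model and compare the outputs. A minimal sketch, assuming the onnxruntime package is installed (it is not a dependency of the script above) and reusing the preprocess_image helper from it; the snippet itself is not part of this commit:

import numpy as np
import onnxruntime as ort  # assumed extra dependency
import torch
from torchvision import models

# same preprocessing as the export script; helper defined in the diff above
batch = preprocess_image("turkish_coffee.jpg")

# PyTorch reference output (CPU is enough for a correctness check)
model = models.resnet50(pretrained=True).eval()
with torch.no_grad():
    torch_out = model(batch).numpy()

# ONNX Runtime output; "input"/"output" are the names passed to torch.onnx.export
session = ort.InferenceSession("resnet50.onnx", providers=["CPUExecutionProvider"])
onnx_out = session.run(["output"], {"input": batch.numpy()})[0]

# the two outputs should agree to within floating-point tolerance
print("max abs difference:", np.abs(torch_out - onnx_out).max())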
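The export call in main() freezes every dimension to the shape of the example batch, so the ONNX model only accepts batch size 1. If variable batch sizes are needed downstream, torch.onnx.export also accepts a dynamic_axes mapping; a sketch of that variant (the axis label "batch_size" is arbitrary, and this is not what the commit does):

# drop-in replacement for the export call in main(); marks dim 0 as dynamic
torch.onnx.export(
    model,
    input,                       # the example batch prepared in main()
    ONNX_FILE_PATH,
    input_names=["input"],
    output_names=["output"],
    export_params=True,
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)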