
Commit 7a28ca9

Merge pull request #512 from xlabd/patch-1
Proposed changes in the readme.md
2 parents: 3a3aaab + 34a52de

8 files changed (+1462, -9 lines)

DnnOpenCV.java: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Rect;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.dnn.Dnn;
import org.opencv.dnn.Net;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class DnnOpenCV {
    private static final int TARGET_IMG_WIDTH = 224;
    private static final int TARGET_IMG_HEIGHT = 224;

    private static final double SCALE_FACTOR = 1 / 255.0;

    private static final String IMAGENET_CLASSES = "imagenet_classes.txt";
    private static final String MODEL_PATH = "models/pytorch_mobilenet.onnx";

    // ImageNet per-channel statistics, in RGB order
    private static final Scalar MEAN = new Scalar(0.485, 0.456, 0.406);
    private static final Scalar STD = new Scalar(0.229, 0.224, 0.225);

    // read the ImageNet class labels, one label per line
    public static ArrayList<String> getImgLabels(String imgLabelsFilePath) throws IOException {
        ArrayList<String> imgLabels;
        try (Stream<String> lines = Files.lines(Paths.get(imgLabelsFilePath))) {
            imgLabels = lines.collect(Collectors.toCollection(ArrayList::new));
        }
        return imgLabels;
    }

    // crop the central TARGET_IMG_WIDTH x TARGET_IMG_HEIGHT region
    public static Mat centerCrop(Mat inputImage) {
        int y1 = (inputImage.rows() - TARGET_IMG_HEIGHT) / 2;
        int x1 = (inputImage.cols() - TARGET_IMG_WIDTH) / 2;
        Rect centerRect = new Rect(x1, y1, TARGET_IMG_WIDTH, TARGET_IMG_HEIGHT);
        return new Mat(inputImage, centerRect);
    }

    public static Mat getPreprocessedImage(String imagePath) {
        // read the image from disk (OpenCV loads it in BGR channel order)
        Mat image = Imgcodecs.imread(imagePath);

        // resize the input image
        Imgproc.resize(image, image, new Size(256, 256));

        // convert the image to float and scale pixel values to [0, 1]
        Mat imgFloat = new Mat(image.rows(), image.cols(), CvType.CV_32FC3);
        image.convertTo(imgFloat, CvType.CV_32FC3, SCALE_FACTOR);

        // crop the central 224x224 region
        imgFloat = centerCrop(imgFloat);

        // swap BGR to RGB so the channel order matches MEAN and STD
        Imgproc.cvtColor(imgFloat, imgFloat, Imgproc.COLOR_BGR2RGB);

        // normalize while the image is still 3-channel: subtract the mean and
        // divide by the standard deviation per channel; dividing the packed
        // NCHW blob by a Scalar would apply only the first std value to all
        // channels
        Core.subtract(imgFloat, MEAN, imgFloat);
        Core.divide(imgFloat, STD, imgFloat);

        // pack the normalized image into an NCHW blob for the DNN module
        return Dnn.blobFromImage(imgFloat);
    }

    public static String getPredictedClass(Mat classificationResult) {
        ArrayList<String> imgLabels = new ArrayList<>();
        try {
            imgLabels = getImgLabels(IMAGENET_CLASSES);
        } catch (IOException ex) {
            System.out.printf("Could not read %s file:%n", IMAGENET_CLASSES);
            ex.printStackTrace();
        }
        if (imgLabels.isEmpty()) {
            return "";
        }
        // the column of the maximum score is the predicted class id
        Core.MinMaxLocResult mm = Core.minMaxLoc(classificationResult);
        return imgLabels.get((int) mm.maxLoc.x);
    }

    public static void main(String[] args) {
        String imageLocation = "images/coffee.jpg";

        // load the OpenCV native library
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);

        // read and preprocess the input image
        Mat inputBlob = DnnOpenCV.getPreprocessedImage(imageLocation);

        // read the generated ONNX model into an org.opencv.dnn.Net object
        Net dnnNet = Dnn.readNetFromONNX(DnnOpenCV.MODEL_PATH);
        System.out.println("DNN from ONNX was successfully loaded!");

        // set the network input and run inference
        dnnNet.setInput(inputBlob);
        Mat classification = dnnNet.forward();

        // decode the classification result
        String label = DnnOpenCV.getPredictedClass(classification);
        System.out.println("Predicted Class: " + label);
    }
}
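
For a quick cross-check of the same model outside the JVM, the sketch below mirrors the Java pipeline using OpenCV's Python bindings. It is a minimal illustration, not part of this commit: it assumes the same file layout as the code above (models/pytorch_mobilenet.onnx, imagenet_classes.txt, images/coffee.jpg) and reproduces the resize, center-crop, and per-channel normalization steps before a cv2.dnn forward pass.

# hedged sketch: Python equivalent of the Java inference above
import cv2
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def classify(image_path="images/coffee.jpg",
             model_path="models/pytorch_mobilenet.onnx",
             labels_path="imagenet_classes.txt"):
    # read (BGR), resize, and convert to RGB
    image = cv2.imread(image_path)
    image = cv2.resize(image, (256, 256))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # center-crop to 224x224
    y = (image.shape[0] - 224) // 2
    x = (image.shape[1] - 224) // 2
    image = image[y:y + 224, x:x + 224]

    # scale to [0, 1] and normalize per channel
    image = (image.astype(np.float32) / 255.0 - MEAN) / STD

    # HWC -> NCHW blob, then run the network
    blob = image.transpose(2, 0, 1)[np.newaxis, ...]
    net = cv2.dnn.readNetFromONNX(model_path)
    net.setInput(blob)
    scores = net.forward()

    with open(labels_path) as f:
        labels = [line.strip() for line in f]
    return labels[int(np.argmax(scores))]

if __name__ == "__main__":
    print("Predicted Class:", classify())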
Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
import argparse
import os

import cv2
import numpy as np
import onnx
import onnxruntime
import torch
from albumentations import (
    CenterCrop,
    Compose,
    Normalize,
    Resize,
)
from torchvision import models


def compare_pytorch_onnx(
    original_model_preds, onnx_model_path, input_image,
):
    # get the onnx result
    session = onnxruntime.InferenceSession(onnx_model_path)
    input_name = session.get_inputs()[0].name
    onnx_result = session.run([], {input_name: input_image})
    onnx_result = np.squeeze(onnx_result, axis=0)

    print("Checking PyTorch model and converted ONNX model outputs ... ")
    for test_onnx_result, gold_result in zip(onnx_result, original_model_preds):
        np.testing.assert_almost_equal(
            gold_result, test_onnx_result, decimal=3,
        )
    print("PyTorch and ONNX output values are equal! \n")


def get_onnx_model(
    original_model, input_image, model_path="models", model_name="pytorch_mobilenet",
):
    # create the model root dir
    os.makedirs(model_path, exist_ok=True)

    model_name = os.path.join(model_path, model_name + ".onnx")

    torch.onnx.export(
        original_model, torch.Tensor(input_image), model_name, verbose=True,
    )
    print("ONNX model was successfully generated: {} \n".format(model_name))

    return model_name


def get_preprocessed_image(image_name):
    # read the image
    original_image = cv2.imread(image_name)

    # convert the original image to RGB format
    image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    # transform the input image:
    # 1. resize the image
    # 2. crop the image
    # 3. normalize: subtract the mean and divide by the standard deviation
    transform = Compose(
        [
            Resize(height=256, width=256),
            CenterCrop(224, 224),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ],
    )
    image = transform(image=image)["image"]

    # change the order of channels from HWC to CHW
    image = image.transpose(2, 0, 1)
    return np.expand_dims(image, axis=0)


def get_predicted_class(pytorch_preds):
    # read the ImageNet class id to name mapping
    with open("imagenet_classes.txt") as f:
        labels = [line.strip() for line in f.readlines()]

    # find the class with the maximum score
    pytorch_class_idx = np.argmax(pytorch_preds, axis=1)
    predicted_pytorch_label = labels[pytorch_class_idx[0]]

    # print the top predicted class
    print("Predicted class by PyTorch model: ", predicted_pytorch_label)


def get_execution_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_image",
        type=str,
        help="Define the full input image path, including its name",
        default="images/coffee.jpg",
    )
    return parser.parse_args()


if __name__ == "__main__":
    # get the test case parameters
    args = get_execution_arguments()

    # read and process the input image
    image = get_preprocessed_image(image_name=args.input_image)

    # obtain the original model
    pytorch_model = models.mobilenet_v2(pretrained=True)

    # run inference with the original PyTorch model
    pytorch_model.eval()
    pytorch_predictions = pytorch_model(torch.Tensor(image)).detach().numpy()

    # export the PyTorch model to ONNX
    onnx_model_path = get_onnx_model(original_model=pytorch_model, input_image=image)

    # check that the conversion succeeded
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)

    # check the onnx model output against the original predictions
    compare_pytorch_onnx(
        pytorch_predictions, onnx_model_path, image,
    )

    # decode the classification results
    get_predicted_class(pytorch_preds=pytorch_predictions)
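
One possible extension, shown here as a hedged sketch rather than part of the committed script: get_predicted_class reports only the argmax of the raw scores, but mobilenet_v2 returns unnormalized logits, so a numerically stable softmax can turn them into probabilities and attach a confidence to the prediction.

import numpy as np

def softmax_confidence(preds):
    # preds: raw logits of shape (1, 1000), as returned by the model
    shifted = preds - np.max(preds, axis=1, keepdims=True)  # stability shift
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    idx = int(np.argmax(probs, axis=1)[0])
    return idx, float(probs[0, idx])

# usage with the variables above: idx, conf = softmax_confidence(pytorch_predictions)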
