Skip to content

Commit bf3d8cc

Browse files
author
Anastasia
committed
Added code for image classification with OpenCV Java
1 parent 439c6dc commit bf3d8cc

File tree

8 files changed

+1456
-3
lines changed

8 files changed

+1456
-3
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import org.opencv.core.Core;
2+
import org.opencv.core.Mat;
3+
import org.opencv.core.Rect;
4+
import org.opencv.core.Scalar;
5+
import org.opencv.core.Size;
6+
import org.opencv.dnn.Net;
7+
import org.opencv.dnn.Dnn;
8+
import org.opencv.imgproc.Imgproc;
9+
import org.opencv.imgcodecs.Imgcodecs;
10+
11+
import java.io.IOException;
12+
import java.util.ArrayList;
13+
import java.util.stream.Collectors;
14+
import java.util.stream.Stream;
15+
import java.nio.file.Files;
16+
import java.nio.file.Paths;
17+
18+
import org.opencv.core.CvType;
19+
20+
21+
public class DnnOpenCV {
22+
private static final int TARGET_IMG_WIDTH = 224;
23+
private static final int TARGET_IMG_HEIGHT = 224;
24+
25+
private static final double SCALE_FACTOR = 1 / 255.0;
26+
27+
private static final String IMAGENET_CLASSES = "imagenet_classes.txt";
28+
private static final String MODEL_PATH = "models/pytorch_mobilenet.onnx";
29+
30+
31+
public static ArrayList<String> getImgLabels(String imgLabelsFilePath) throws IOException {
32+
ArrayList<String> imgLabels;
33+
try (Stream<String> lines = Files.lines(Paths.get(imgLabelsFilePath))) {
34+
imgLabels = lines.collect(Collectors.toCollection(ArrayList::new));
35+
}
36+
return imgLabels;
37+
}
38+
39+
public static Mat centerCrop(Mat inputImage) {
40+
int y1 = Math.round((inputImage.rows() - TARGET_IMG_HEIGHT) / 2);
41+
int y2 = Math.round(y1 + TARGET_IMG_HEIGHT);
42+
int x1 = Math.round((inputImage.cols() - TARGET_IMG_WIDTH) / 2);
43+
int x2 = Math.round(x1 + TARGET_IMG_WIDTH);
44+
45+
Rect centerRect = new Rect(x1, y1, (x2 - x1), (y2 - y1));
46+
Mat croppedImage = new Mat(inputImage, centerRect);
47+
48+
return croppedImage;
49+
}
50+
51+
public static Mat getPreprocessedImage(String imagePath) {
52+
// define mean and standard deviation
53+
Scalar mean = new Scalar(0.485, 0.456, 0.406);
54+
Scalar std = new Scalar(0.229, 0.224, 0.225);
55+
56+
// get the image from the internal resource folder
57+
Mat image = Imgcodecs.imread(imagePath);
58+
59+
// resize input image
60+
Imgproc.resize(image, image, new Size(256, 256));
61+
62+
// create empty Mat images for float conversions
63+
Mat imgFloat = new Mat(image.rows(), image.cols(), CvType.CV_32FC3);
64+
65+
// convert input image to float type
66+
image.convertTo(imgFloat, CvType.CV_32FC3, SCALE_FACTOR);
67+
68+
// crop input image
69+
imgFloat = centerCrop(imgFloat);
70+
71+
// prepare DNN input
72+
Mat blob = Dnn.blobFromImage(
73+
imgFloat,
74+
1.0, /* default scalefactor */
75+
new Size(TARGET_IMG_WIDTH, TARGET_IMG_HEIGHT), /* target size */
76+
mean, /* mean */
77+
true, /* swapRB */
78+
false /* crop */
79+
);
80+
81+
// divide on std
82+
Core.divide(blob, std, blob);
83+
84+
return blob;
85+
}
86+
87+
public static void getPredictedClass(Mat classificationResult) {
88+
ArrayList<String> imgLabels = new ArrayList<String>();
89+
try {
90+
imgLabels = getImgLabels(IMAGENET_CLASSES);
91+
} catch (IOException ex) {
92+
93+
}
94+
// obtain max prediction result
95+
Core.MinMaxLocResult mm = Core.minMaxLoc(classificationResult);
96+
double maxValIndex = mm.maxLoc.x;
97+
System.out.println("Predicted Class: " + imgLabels.get((int) maxValIndex));
98+
}
99+
100+
public static void main(String[] args) {
101+
String imageLocation = "images/coffee.jpg";
102+
103+
// load the OpenCV native library
104+
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
105+
106+
// read and process the input image
107+
Mat inputBlob = DnnOpenCV.getPreprocessedImage(imageLocation);
108+
109+
// read generated ONNX model into org.opencv.dnn.Net object
110+
Net dnnNet = Dnn.readNetFromONNX(DnnOpenCV.MODEL_PATH);
111+
System.out.println("DNN from ONNX was successfully loaded!");
112+
113+
// set OpenCV model input
114+
dnnNet.setInput(inputBlob);
115+
116+
// provide inference
117+
Mat classification = dnnNet.forward();
118+
119+
// decode classification results
120+
DnnOpenCV.getPredictedClass(classification);
121+
}
122+
}
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import argparse
2+
import os
3+
4+
import cv2
5+
import numpy as np
6+
import onnx
7+
import onnxruntime
8+
import torch
9+
from albumentations import (
10+
CenterCrop,
11+
Compose,
12+
Normalize,
13+
Resize,
14+
)
15+
from torchvision import models
16+
17+
18+
def compare_pytorch_onnx(
19+
original_model_preds, onnx_model_path, input_image,
20+
):
21+
# get onnx result
22+
session = onnxruntime.InferenceSession(onnx_model_path)
23+
input_name = session.get_inputs()[0].name
24+
onnx_result = session.run([], {input_name: input_image})
25+
onnx_result = np.squeeze(onnx_result, axis=0)
26+
27+
print("Checking PyTorch model and converted ONNX model outputs ... ")
28+
for test_onnx_result, gold_result in zip(onnx_result, original_model_preds):
29+
np.testing.assert_almost_equal(
30+
gold_result, test_onnx_result, decimal=3,
31+
)
32+
print("PyTorch and ONNX output values are equal! \n")
33+
34+
35+
def get_onnx_model(
36+
original_model, input_image, model_path="models", model_name="pytorch_mobilenet",
37+
):
38+
# create model root dir
39+
os.makedirs(model_path, exist_ok=True)
40+
41+
model_name = os.path.join(model_path, model_name + ".onnx")
42+
43+
torch.onnx.export(
44+
original_model, torch.Tensor(input_image), model_name, verbose=True,
45+
)
46+
print("ONNX model was successfully generated: {} \n".format(model_name))
47+
48+
return model_name
49+
50+
51+
def get_preprocessed_image(image_name):
52+
# read image
53+
original_image = cv2.imread(image_name)
54+
55+
# convert original image to RGB format
56+
image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
57+
58+
# transform input image:
59+
# 1. resize the image
60+
# 2. crop the image
61+
# 3. normalize: subtract mean and divide by standard deviation
62+
transform = Compose(
63+
[
64+
Resize(height=256, width=256),
65+
CenterCrop(224, 224),
66+
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
67+
],
68+
)
69+
image = transform(image=image)["image"]
70+
71+
# change the order of channels
72+
image = image.transpose(2, 0, 1)
73+
return np.expand_dims(image, axis=0)
74+
75+
76+
def get_predicted_class(pytorch_preds):
77+
# read ImageNet class id to name mapping
78+
with open("imagenet_classes.txt") as f:
79+
labels = [line.strip() for line in f.readlines()]
80+
81+
# find the class with the maximum score
82+
pytorch_class_idx = np.argmax(pytorch_preds, axis=1)
83+
predicted_pytorch_label = labels[pytorch_class_idx[0]]
84+
85+
# print top predicted class
86+
print("Predicted class by PyTorch model: ", predicted_pytorch_label)
87+
88+
89+
def get_execution_arguments():
90+
parser = argparse.ArgumentParser()
91+
parser.add_argument(
92+
"--input_image",
93+
type=str,
94+
help="Define the full input image path, including its name",
95+
default="images/coffee.jpg",
96+
)
97+
return parser.parse_args()
98+
99+
100+
if __name__ == "__main__":
101+
# get the test case parameters
102+
args = get_execution_arguments()
103+
104+
# read and process the input image
105+
image = get_preprocessed_image(image_name=args.input_image)
106+
107+
# obtain original model
108+
pytorch_model = models.mobilenet_v2(pretrained=True)
109+
110+
# provide inference of the original PyTorch model
111+
pytorch_model.eval()
112+
pytorch_predictions = pytorch_model(torch.Tensor(image)).detach().numpy()
113+
114+
# obtain OpenCV generated ONNX model
115+
onnx_model_path = get_onnx_model(original_model=pytorch_model, input_image=image)
116+
117+
# check if conversion succeeded
118+
onnx_model = onnx.load(onnx_model_path)
119+
onnx.checker.check_model(onnx_model)
120+
121+
# check onnx model output
122+
compare_pytorch_onnx(
123+
pytorch_predictions, onnx_model_path, image,
124+
)
125+
126+
# decode classification results
127+
get_predicted_class(pytorch_preds=pytorch_predictions)

0 commit comments

Comments
 (0)