1
+ import org .opencv .core .Core ;
2
+ import org .opencv .core .Mat ;
3
+ import org .opencv .core .Rect ;
4
+ import org .opencv .core .Scalar ;
5
+ import org .opencv .core .Size ;
6
+ import org .opencv .dnn .Net ;
7
+ import org .opencv .dnn .Dnn ;
8
+ import org .opencv .imgproc .Imgproc ;
9
+ import org .opencv .imgcodecs .Imgcodecs ;
10
+
11
+ import java .io .IOException ;
12
+ import java .util .ArrayList ;
13
+ import java .util .stream .Collectors ;
14
+ import java .util .stream .Stream ;
15
+ import java .nio .file .Files ;
16
+ import java .nio .file .Paths ;
17
+
18
+ import org .opencv .core .CvType ;
19
+
20
+
21
+ public class DnnOpenCV {
22
+ private static final int TARGET_IMG_WIDTH = 224 ;
23
+ private static final int TARGET_IMG_HEIGHT = 224 ;
24
+
25
+ private static final double SCALE_FACTOR = 1 / 255.0 ;
26
+
27
+ private static final String IMAGENET_CLASSES = "imagenet_classes.txt" ;
28
+ private static final String MODEL_PATH = "models/pytorch_mobilenet.onnx" ;
29
+
30
+ private static final Scalar MEAN = new Scalar (0.485 , 0.456 , 0.406 );
31
+ private static final Scalar STD = new Scalar (0.229 , 0.224 , 0.225 );
32
+
33
+ public static ArrayList <String > getImgLabels (String imgLabelsFilePath ) throws IOException {
34
+ ArrayList <String > imgLabels ;
35
+ try (Stream <String > lines = Files .lines (Paths .get (imgLabelsFilePath ))) {
36
+ imgLabels = lines .collect (Collectors .toCollection (ArrayList ::new ));
37
+ }
38
+ return imgLabels ;
39
+ }
40
+
41
+ public static Mat centerCrop (Mat inputImage ) {
42
+ int y1 = Math .round ((inputImage .rows () - TARGET_IMG_HEIGHT ) / 2 );
43
+ int y2 = Math .round (y1 + TARGET_IMG_HEIGHT );
44
+ int x1 = Math .round ((inputImage .cols () - TARGET_IMG_WIDTH ) / 2 );
45
+ int x2 = Math .round (x1 + TARGET_IMG_WIDTH );
46
+
47
+ Rect centerRect = new Rect (x1 , y1 , (x2 - x1 ), (y2 - y1 ));
48
+ Mat croppedImage = new Mat (inputImage , centerRect );
49
+
50
+ return croppedImage ;
51
+ }
52
+
53
+ public static Mat getPreprocessedImage (String imagePath ) {
54
+ // get the image from the internal resource folder
55
+ Mat image = Imgcodecs .imread (imagePath );
56
+
57
+ // resize input image
58
+ Imgproc .resize (image , image , new Size (256 , 256 ));
59
+
60
+ // create empty Mat images for float conversions
61
+ Mat imgFloat = new Mat (image .rows (), image .cols (), CvType .CV_32FC3 );
62
+
63
+ // convert input image to float type
64
+ image .convertTo (imgFloat , CvType .CV_32FC3 , SCALE_FACTOR );
65
+
66
+ // crop input image
67
+ imgFloat = centerCrop (imgFloat );
68
+
69
+ // prepare DNN input
70
+ Mat blob = Dnn .blobFromImage (
71
+ imgFloat ,
72
+ 1.0 , /* default scalefactor */
73
+ new Size (TARGET_IMG_WIDTH , TARGET_IMG_HEIGHT ), /* target size */
74
+ MEAN , /* mean */
75
+ true , /* swapRB */
76
+ false /* crop */
77
+ );
78
+
79
+ // divide on std
80
+ Core .divide (blob , STD , blob );
81
+
82
+ return blob ;
83
+ }
84
+
85
+ public static String getPredictedClass (Mat classificationResult ) {
86
+ ArrayList <String > imgLabels = new ArrayList <String >();
87
+ try {
88
+ imgLabels = getImgLabels (IMAGENET_CLASSES );
89
+ } catch (IOException ex ) {
90
+ System .out .printf ("Could not read %s file:%n" , IMAGENET_CLASSES );
91
+ ex .printStackTrace ();
92
+ }
93
+ if (imgLabels .isEmpty ()) {
94
+ return "" ;
95
+ }
96
+ // obtain max prediction result
97
+ Core .MinMaxLocResult mm = Core .minMaxLoc (classificationResult );
98
+ double maxValIndex = mm .maxLoc .x ;
99
+ return imgLabels .get ((int ) maxValIndex );
100
+ }
101
+
102
+ public static void main (String [] args ) {
103
+ String imageLocation = "images/coffee.jpg" ;
104
+
105
+ // load the OpenCV native library
106
+ System .loadLibrary (Core .NATIVE_LIBRARY_NAME );
107
+
108
+ // read and process the input image
109
+ Mat inputBlob = DnnOpenCV .getPreprocessedImage (imageLocation );
110
+
111
+ // read generated ONNX model into org.opencv.dnn.Net object
112
+ Net dnnNet = Dnn .readNetFromONNX (DnnOpenCV .MODEL_PATH );
113
+ System .out .println ("DNN from ONNX was successfully loaded!" );
114
+
115
+ // set OpenCV model input
116
+ dnnNet .setInput (inputBlob );
117
+
118
+ // provide inference
119
+ Mat classification = dnnNet .forward ();
120
+
121
+ // decode classification results
122
+ String label = DnnOpenCV .getPredictedClass (classification );
123
+ System .out .println ("Predicted Class: " + label );
124
+ }
125
+ }
0 commit comments