7
7
import android .content .res .AssetManager ;
8
8
import android .util .Log ;
9
9
10
- import org .tensorflow .lite .Interpreter ;
10
+ import com .google .android .gms .tasks .Task ;
11
+ import com .google .android .gms .tasks .Tasks ;
12
+ import com .google .android .gms .tflite .TfLite ;
13
+ import com .google .android .gms .tflite .TfLiteInitializationOptions ;
14
+ import com .google .android .gms .tflite .TfLiteGpu ;
15
+
16
+ import org .tensorflow .lite .InterpreterApi ;
17
+ import org .tensorflow .lite .InterpreterApi .Options .TfLiteRuntime ;
11
18
import org .tensorflow .lite .Tensor ;
12
- import org .tensorflow .lite .gpu .CompatibilityList ;
13
- import org .tensorflow .lite .gpu .GpuDelegate ;
14
19
import org .tensorflow .lite .gpu .GpuDelegateFactory ;
15
20
16
21
import java .io .BufferedReader ;
30
35
31
36
public class Yolo {
32
37
protected float [][][] output ;
33
- protected Interpreter interpreter ;
34
- protected Vector < String > labels ;
38
+ protected InterpreterApi interpreter ;
39
+ protected Vector < String > labels ;
35
40
protected final Context context ;
36
41
protected final String model_path ;
37
42
protected final boolean is_assets ;
@@ -64,7 +69,6 @@ public Tensor getInputTensor() {
64
69
return this .interpreter .getInputTensor (0 );
65
70
}
66
71
67
-
68
72
public void initialize_model () throws Exception {
69
73
AssetManager asset_manager = null ;
70
74
MappedByteBuffer buffer ;
@@ -89,24 +93,38 @@ public void initialize_model() throws Exception {
89
93
buffer = file_channel .map (FileChannel .MapMode .READ_ONLY , 0 , file_channel .size ());
90
94
}
91
95
92
- Interpreter .Options interpreterOptions = new Interpreter .Options ();
93
- CompatibilityList compatibilityList = new CompatibilityList ();
96
+ // Initialize LiteRT via Google Play services.
97
+ Task < Void > initializeTask ;
98
+ if (use_gpu ) {
99
+ TfLiteInitializationOptions initOptions = TfLiteInitializationOptions .builder ()
100
+ .setEnableGpuDelegateSupport (true )
101
+ .build ();
102
+ initializeTask = TfLite .initialize (context , initOptions );
103
+ } else {
104
+ initializeTask = TfLite .initialize (context );
105
+ }
106
+ // Ensure initialization completes (do not call await on the UI thread).
107
+ Tasks .await (initializeTask );
94
108
95
- if (use_gpu && compatibilityList .isDelegateSupportedOnThisDevice ()) {
96
- try {
97
- GpuDelegateFactory .Options delegateOptions = compatibilityList .getBestOptionsForThisDevice ();
98
- GpuDelegate gpuDelegate = new GpuDelegate (delegateOptions .setQuantizedModelsAllowed (this .quantization ));
99
- interpreterOptions .addDelegate (gpuDelegate );
100
- } catch (Exception e ) {
101
- Log .e ("Yolo" , "GPU delegate failed, falling back to CPU" , e );
102
- interpreterOptions = new Interpreter .Options ();
103
- interpreterOptions .setNumThreads (num_threads );
109
+ // Set up interpreter options for LiteRT.
110
+ InterpreterApi .Options interpreterOptions = new InterpreterApi .Options ()
111
+ .setRuntime (TfLiteRuntime .FROM_SYSTEM_ONLY );
112
+ if (use_gpu ) {
113
+ // Check for GPU delegate availability.
114
+ boolean gpuAvailable = TfLiteGpu .isGpuDelegateAvailable (context );
115
+ if (gpuAvailable ) {
116
+ try {
117
+ interpreterOptions .addDelegateFactory (new GpuDelegateFactory ());
118
+ } catch (Exception e ) {
119
+ Log .e ("Yolo" , "GPU delegate failed, falling back to CPU" , e );
120
+ }
104
121
}
105
122
} else {
106
123
interpreterOptions .setNumThreads (num_threads );
107
124
}
108
125
109
- this .interpreter = new Interpreter (buffer , interpreterOptions );
126
+ // Create the interpreter using the LiteRT Interpreter API.
127
+ this .interpreter = InterpreterApi .create (buffer , interpreterOptions );
110
128
this .interpreter .allocateTensors ();
111
129
this .labels = load_labels (asset_manager , label_path );
112
130
int [] shape = interpreter .getOutputTensor (0 ).shape ();
@@ -127,7 +145,7 @@ public void initialize_model() throws Exception {
127
145
}
128
146
}
129
147
130
- protected Vector < String > load_labels (AssetManager asset_manager , String label_path ) throws Exception {
148
+ protected Vector < String > load_labels (AssetManager asset_manager , String label_path ) throws Exception {
131
149
if (label_path == null || label_path .isEmpty ()) {
132
150
throw new Exception ("Invalid label path" );
133
151
}
@@ -139,8 +157,7 @@ protected Vector<String> load_labels(AssetManager asset_manager, String label_pa
139
157
} else {
140
158
br = new BufferedReader (new InputStreamReader (new FileInputStream (label_path )));
141
159
}
142
-
143
- Vector <String > labels = new Vector <>();
160
+ Vector < String > labels = new Vector < > ();
144
161
String line ;
145
162
while ((line = br .readLine ()) != null ) {
146
163
labels .add (line );
@@ -157,19 +174,18 @@ protected Vector<String> load_labels(AssetManager asset_manager, String label_pa
157
174
}
158
175
}
159
176
160
- public List < Map < String , Object >> detect_task (ByteBuffer byteBuffer ,
161
- int source_height ,
162
- int source_width ,
163
- float iou_threshold ,
164
- float conf_threshold , float class_threshold ) throws Exception {
177
+ public List < Map < String , Object >> detect_task (ByteBuffer byteBuffer ,
178
+ int source_height ,
179
+ int source_width ,
180
+ float iou_threshold ,
181
+ float conf_threshold , float class_threshold ) throws Exception {
165
182
if (interpreter == null ) {
166
183
throw new Exception ("Interpreter not initialized" );
167
184
}
168
-
169
185
try {
170
186
int [] input_shape = this .interpreter .getInputTensor (0 ).shape ();
171
187
this .interpreter .run (byteBuffer , this .output );
172
- List < float []> boxes = filter_box (this .output , iou_threshold , conf_threshold ,
188
+ List < float [] > boxes = filter_box (this .output , iou_threshold , conf_threshold ,
173
189
class_threshold , input_shape [1 ], input_shape [2 ]);
174
190
boxes = restore_size (boxes , input_shape [1 ], input_shape [2 ], source_width , source_height );
175
191
return out (boxes , this .labels );
@@ -180,11 +196,10 @@ public List<Map<String, Object>> detect_task(ByteBuffer byteBuffer,
180
196
}
181
197
}
182
198
183
- protected List < float []> filter_box (float [][][] model_outputs , float iou_threshold ,
184
- float conf_threshold , float class_threshold , float input_width , float input_height ) {
199
+ protected List < float [] > filter_box (float [][][] model_outputs , float iou_threshold ,
200
+ float conf_threshold , float class_threshold , float input_width , float input_height ) {
185
201
try {
186
- //model_outputs = [1,box+model_conf+class,detected_box]
187
- List <float []> pre_box = new ArrayList <>();
202
+ List < float [] > pre_box = new ArrayList < > ();
188
203
int conf_index = 4 ;
189
204
int class_index = 5 ;
190
205
int dimension = model_outputs [0 ][0 ].length ;
@@ -193,50 +208,44 @@ protected List<float[]> filter_box(float[][][] model_outputs, float iou_threshol
193
208
int max_index ;
194
209
float max ;
195
210
for (int i = 0 ; i < rows ; i ++) {
196
- //convert xywh to xyxy
197
- x1 = (model_outputs [0 ][i ][0 ] - model_outputs [0 ][i ][2 ] / 2f ) * input_width ;
198
- y1 = (model_outputs [0 ][i ][1 ] - model_outputs [0 ][i ][3 ] / 2f ) * input_height ;
199
- x2 = (model_outputs [0 ][i ][0 ] + model_outputs [0 ][i ][2 ] / 2f ) * input_width ;
200
- y2 = (model_outputs [0 ][i ][1 ] + model_outputs [0 ][i ][3 ] / 2f ) * input_height ;
211
+ x1 = (model_outputs [0 ][i ][0 ] - model_outputs [0 ][i ][2 ] / 2 f ) * input_width ;
212
+ y1 = (model_outputs [0 ][i ][1 ] - model_outputs [0 ][i ][3 ] / 2 f ) * input_height ;
213
+ x2 = (model_outputs [0 ][i ][0 ] + model_outputs [0 ][i ][2 ] / 2 f ) * input_width ;
214
+ y2 = (model_outputs [0 ][i ][1 ] + model_outputs [0 ][i ][3 ] / 2 f ) * input_height ;
201
215
conf = model_outputs [0 ][i ][conf_index ];
202
216
if (conf < conf_threshold ) continue ;
203
-
204
217
max_index = class_index ;
205
218
max = model_outputs [0 ][i ][max_index ];
206
-
207
219
for (int j = class_index + 1 ; j < dimension ; j ++) {
208
220
float current = model_outputs [0 ][i ][j ];
209
221
if (current > max ) {
210
222
max = current ;
211
223
max_index = j ;
212
224
}
213
225
}
214
- if (max > class_threshold ){
226
+ if (max > class_threshold ) {
215
227
float [] tmp = new float [6 ];
216
228
tmp [0 ] = x1 ;
217
229
tmp [1 ] = y1 ;
218
230
tmp [2 ] = x2 ;
219
231
tmp [3 ] = y2 ;
220
232
tmp [4 ] = model_outputs [0 ][i ][max_index ];
221
- tmp [5 ] = (max_index - class_index ) * 1f ;
233
+ tmp [5 ] = (max_index - class_index ) * 1 f ;
222
234
pre_box .add (tmp );
223
235
}
224
236
}
225
- if (pre_box .isEmpty ()) return new ArrayList <>();
226
- //for reverse orden, insteand of using .reversed method
227
- Comparator <float []> compareValues = (v1 , v2 ) -> Float .compare (v2 [4 ], v1 [4 ]);
228
- //Collections.sort(pre_box,compareValues.reversed());
237
+ if (pre_box .isEmpty ()) return new ArrayList < > ();
238
+ Comparator < float [] > compareValues = (v1 , v2 ) - > Float .compare (v2 [4 ], v1 [4 ]);
229
239
Collections .sort (pre_box , compareValues );
230
240
return nms (pre_box , iou_threshold );
231
241
} catch (Exception e ) {
232
242
throw e ;
233
243
}
234
244
}
235
245
236
- protected static List < float []> nms (List < float []> boxes , float iou_threshold ) {
246
+ protected static List < float [] > nms (List < float [] > boxes , float iou_threshold ) {
237
247
try {
238
- List <float []> filteredBoxes = new ArrayList <>(boxes ); // Create a copy of the input list
239
-
248
+ List < float [] > filteredBoxes = new ArrayList < > (boxes );
240
249
for (int i = 0 ; i < filteredBoxes .size (); i ++) {
241
250
float [] box = filteredBoxes .get (i );
242
251
for (int j = i + 1 ; j < filteredBoxes .size (); j ++) {
@@ -245,13 +254,11 @@ protected static List<float[]> nms(List<float[]> boxes, float iou_threshold) {
245
254
float y1 = Math .max (next_box [1 ], box [1 ]);
246
255
float x2 = Math .min (next_box [2 ], box [2 ]);
247
256
float y2 = Math .min (next_box [3 ], box [3 ]);
248
-
249
257
float width = Math .max (0 , x2 - x1 );
250
258
float height = Math .max (0 , y2 - y1 );
251
-
252
259
float intersection = width * height ;
253
- float union = (next_box [2 ] - next_box [0 ]) * (next_box [3 ] - next_box [1 ])
254
- + (box [2 ] - box [0 ]) * (box [3 ] - box [1 ]) - intersection ;
260
+ float union = (next_box [2 ] - next_box [0 ]) * (next_box [3 ] - next_box [1 ]) +
261
+ (box [2 ] - box [0 ]) * (box [3 ] - box [1 ]) - intersection ;
255
262
float iou = intersection / union ;
256
263
if (iou > iou_threshold ) {
257
264
filteredBoxes .remove (j );
@@ -270,14 +277,12 @@ public boolean isInitialized() {
270
277
return interpreter != null ;
271
278
}
272
279
273
-
274
- protected List <float []> restore_size (List <float []> nms ,
275
- int input_width ,
276
- int input_height ,
277
- int src_width ,
278
- int src_height ) {
280
+ protected List < float [] > restore_size (List < float [] > nms ,
281
+ int input_width ,
282
+ int input_height ,
283
+ int src_width ,
284
+ int src_height ) {
279
285
try {
280
- //restore size after scaling, larger images
281
286
if (src_width > input_width || src_height > input_height ) {
282
287
float gainx = src_width / (float ) input_width ;
283
288
float gainy = src_height / (float ) input_height ;
@@ -287,10 +292,9 @@ protected List<float[]> restore_size(List<float[]> nms,
287
292
nms .get (i )[2 ] = min (src_width , Math .max (nms .get (i )[2 ] * gainx , 0 ));
288
293
nms .get (i )[3 ] = min (src_height , Math .max (nms .get (i )[3 ] * gainy , 0 ));
289
294
}
290
- //restore size after padding, smaller images
291
295
} else {
292
- float padx = (src_width - input_width ) / 2f ;
293
- float pady = (src_height - input_height ) / 2f ;
296
+ float padx = (src_width - input_width ) / 2 f ;
297
+ float pady = (src_height - input_height ) / 2 f ;
294
298
for (int i = 0 ; i < nms .size (); i ++) {
295
299
nms .get (i )[0 ] = min (src_width , Math .max (nms .get (i )[0 ] + padx , 0 ));
296
300
nms .get (i )[1 ] = min (src_height , Math .max (nms .get (i )[1 ] + pady , 0 ));
@@ -303,13 +307,15 @@ protected List<float[]> restore_size(List<float[]> nms,
303
307
throw new RuntimeException (e .getMessage ());
304
308
}
305
309
}
306
- protected List <Map <String , Object >> out (List <float []> yolo_result , Vector <String > labels ) {
310
+
311
+ protected List < Map < String , Object >> out (List < float [] > yolo_result , Vector < String > labels ) {
307
312
try {
308
- List <Map <String , Object >> result = new ArrayList <>();
309
- //utils.getScreenshotBmp(bitmap, "current");
310
- for (float [] box : yolo_result ) {
311
- Map <String , Object > output = new HashMap <>();
312
- output .put ("box" , new float []{box [0 ], box [1 ], box [2 ], box [3 ], box [4 ]}); //x1,y1,x2,y2,conf_class
313
+ List < Map < String , Object >> result = new ArrayList < > ();
314
+ for (float [] box : yolo_result ) {
315
+ Map < String , Object > output = new HashMap < > ();
316
+ output .put ("box" , new float [] {
317
+ box [0 ], box [1 ], box [2 ], box [3 ], box [4 ]
318
+ });
313
319
output .put ("tag" , labels .get ((int ) box [5 ]));
314
320
result .add (output );
315
321
}
@@ -329,4 +335,4 @@ public void close() {
329
335
Log .e ("Yolo" , "Interpreter close error" , e );
330
336
}
331
337
}
332
- }
338
+ }
0 commit comments