Skip to content

Commit e30d0f4

Browse files
author
Persie
committed
changed dependencies
1 parent 2c0b843 commit e30d0f4

File tree

2 files changed

+78
-77
lines changed

2 files changed

+78
-77
lines changed

android/build.gradle

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,7 @@ android {
5454
}
5555
dependencies{
5656
implementation 'com.github.vladiH:opencv-android:v1.0.0'
57-
implementation 'org.tensorflow:tensorflow-lite:2.10.0'
58-
implementation 'org.tensorflow:tensorflow-lite-api:2.10.0'
59-
implementation 'org.tensorflow:tensorflow-lite-gpu:2.10.0'
60-
implementation 'org.tensorflow:tensorflow-lite-gpu-api:2.10.0'
61-
implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.3'
62-
implementation 'org.tensorflow:tensorflow-lite-support:0.4.3'
63-
implementation 'org.tensorflow:tensorflow-lite-metadata:0.4.3'
64-
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.11.0'
57+
implementation 'com.google.android.gms:play-services-tflite-java:16.1.0'
58+
implementation 'com.google.android.gms:play-services-tflite-support:16.1.0'
59+
implementation 'com.google.android.gms:play-services-tflite-gpu:16.2.0'
6560
}

android/src/main/java/com/vladih/computer_vision/flutter_vision/models/Yolo.java

Lines changed: 75 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,15 @@
77
import android.content.res.AssetManager;
88
import android.util.Log;
99

10-
import org.tensorflow.lite.Interpreter;
10+
import com.google.android.gms.tasks.Task;
11+
import com.google.android.gms.tasks.Tasks;
12+
import com.google.android.gms.tflite.TfLite;
13+
import com.google.android.gms.tflite.TfLiteInitializationOptions;
14+
import com.google.android.gms.tflite.TfLiteGpu;
15+
16+
import org.tensorflow.lite.InterpreterApi;
17+
import org.tensorflow.lite.InterpreterApi.Options.TfLiteRuntime;
1118
import org.tensorflow.lite.Tensor;
12-
import org.tensorflow.lite.gpu.CompatibilityList;
13-
import org.tensorflow.lite.gpu.GpuDelegate;
1419
import org.tensorflow.lite.gpu.GpuDelegateFactory;
1520

1621
import java.io.BufferedReader;
@@ -30,8 +35,8 @@
3035

3136
public class Yolo {
3237
protected float[][][] output;
33-
protected Interpreter interpreter;
34-
protected Vector<String> labels;
38+
protected InterpreterApi interpreter;
39+
protected Vector < String > labels;
3540
protected final Context context;
3641
protected final String model_path;
3742
protected final boolean is_assets;
@@ -64,7 +69,6 @@ public Tensor getInputTensor() {
6469
return this.interpreter.getInputTensor(0);
6570
}
6671

67-
6872
public void initialize_model() throws Exception {
6973
AssetManager asset_manager = null;
7074
MappedByteBuffer buffer;
@@ -89,24 +93,38 @@ public void initialize_model() throws Exception {
8993
buffer = file_channel.map(FileChannel.MapMode.READ_ONLY, 0, file_channel.size());
9094
}
9195

92-
Interpreter.Options interpreterOptions = new Interpreter.Options();
93-
CompatibilityList compatibilityList = new CompatibilityList();
96+
// Initialize LiteRT via Google Play services.
97+
Task < Void > initializeTask;
98+
if (use_gpu) {
99+
TfLiteInitializationOptions initOptions = TfLiteInitializationOptions.builder()
100+
.setEnableGpuDelegateSupport(true)
101+
.build();
102+
initializeTask = TfLite.initialize(context, initOptions);
103+
} else {
104+
initializeTask = TfLite.initialize(context);
105+
}
106+
// Ensure initialization completes (do not call await on the UI thread).
107+
Tasks.await(initializeTask);
94108

95-
if (use_gpu && compatibilityList.isDelegateSupportedOnThisDevice()) {
96-
try {
97-
GpuDelegateFactory.Options delegateOptions = compatibilityList.getBestOptionsForThisDevice();
98-
GpuDelegate gpuDelegate = new GpuDelegate(delegateOptions.setQuantizedModelsAllowed(this.quantization));
99-
interpreterOptions.addDelegate(gpuDelegate);
100-
} catch (Exception e) {
101-
Log.e("Yolo", "GPU delegate failed, falling back to CPU", e);
102-
interpreterOptions = new Interpreter.Options();
103-
interpreterOptions.setNumThreads(num_threads);
109+
// Set up interpreter options for LiteRT.
110+
InterpreterApi.Options interpreterOptions = new InterpreterApi.Options()
111+
.setRuntime(TfLiteRuntime.FROM_SYSTEM_ONLY);
112+
if (use_gpu) {
113+
// Check for GPU delegate availability.
114+
boolean gpuAvailable = TfLiteGpu.isGpuDelegateAvailable(context);
115+
if (gpuAvailable) {
116+
try {
117+
interpreterOptions.addDelegateFactory(new GpuDelegateFactory());
118+
} catch (Exception e) {
119+
Log.e("Yolo", "GPU delegate failed, falling back to CPU", e);
120+
}
104121
}
105122
} else {
106123
interpreterOptions.setNumThreads(num_threads);
107124
}
108125

109-
this.interpreter = new Interpreter(buffer, interpreterOptions);
126+
// Create the interpreter using the LiteRT Interpreter API.
127+
this.interpreter = InterpreterApi.create(buffer, interpreterOptions);
110128
this.interpreter.allocateTensors();
111129
this.labels = load_labels(asset_manager, label_path);
112130
int[] shape = interpreter.getOutputTensor(0).shape();
@@ -127,7 +145,7 @@ public void initialize_model() throws Exception {
127145
}
128146
}
129147

130-
protected Vector<String> load_labels(AssetManager asset_manager, String label_path) throws Exception {
148+
protected Vector < String > load_labels(AssetManager asset_manager, String label_path) throws Exception {
131149
if (label_path == null || label_path.isEmpty()) {
132150
throw new Exception("Invalid label path");
133151
}
@@ -139,8 +157,7 @@ protected Vector<String> load_labels(AssetManager asset_manager, String label_pa
139157
} else {
140158
br = new BufferedReader(new InputStreamReader(new FileInputStream(label_path)));
141159
}
142-
143-
Vector<String> labels = new Vector<>();
160+
Vector < String > labels = new Vector < > ();
144161
String line;
145162
while ((line = br.readLine()) != null) {
146163
labels.add(line);
@@ -157,19 +174,18 @@ protected Vector<String> load_labels(AssetManager asset_manager, String label_pa
157174
}
158175
}
159176

160-
public List<Map<String, Object>> detect_task(ByteBuffer byteBuffer,
161-
int source_height,
162-
int source_width,
163-
float iou_threshold,
164-
float conf_threshold, float class_threshold) throws Exception {
177+
public List < Map < String, Object >> detect_task(ByteBuffer byteBuffer,
178+
int source_height,
179+
int source_width,
180+
float iou_threshold,
181+
float conf_threshold, float class_threshold) throws Exception {
165182
if (interpreter == null) {
166183
throw new Exception("Interpreter not initialized");
167184
}
168-
169185
try {
170186
int[] input_shape = this.interpreter.getInputTensor(0).shape();
171187
this.interpreter.run(byteBuffer, this.output);
172-
List<float[]> boxes = filter_box(this.output, iou_threshold, conf_threshold,
188+
List < float[] > boxes = filter_box(this.output, iou_threshold, conf_threshold,
173189
class_threshold, input_shape[1], input_shape[2]);
174190
boxes = restore_size(boxes, input_shape[1], input_shape[2], source_width, source_height);
175191
return out(boxes, this.labels);
@@ -180,11 +196,10 @@ public List<Map<String, Object>> detect_task(ByteBuffer byteBuffer,
180196
}
181197
}
182198

183-
protected List<float[]> filter_box(float[][][] model_outputs, float iou_threshold,
184-
float conf_threshold, float class_threshold, float input_width, float input_height) {
199+
protected List < float[] > filter_box(float[][][] model_outputs, float iou_threshold,
200+
float conf_threshold, float class_threshold, float input_width, float input_height) {
185201
try {
186-
//model_outputs = [1,box+model_conf+class,detected_box]
187-
List<float[]> pre_box = new ArrayList<>();
202+
List < float[] > pre_box = new ArrayList < > ();
188203
int conf_index = 4;
189204
int class_index = 5;
190205
int dimension = model_outputs[0][0].length;
@@ -193,50 +208,44 @@ protected List<float[]> filter_box(float[][][] model_outputs, float iou_threshol
193208
int max_index;
194209
float max;
195210
for (int i = 0; i < rows; i++) {
196-
//convert xywh to xyxy
197-
x1 = (model_outputs[0][i][0] - model_outputs[0][i][2] / 2f) * input_width;
198-
y1 = (model_outputs[0][i][1] - model_outputs[0][i][3] / 2f) * input_height;
199-
x2 = (model_outputs[0][i][0] + model_outputs[0][i][2] / 2f) * input_width;
200-
y2 = (model_outputs[0][i][1] + model_outputs[0][i][3] / 2f) * input_height;
211+
x1 = (model_outputs[0][i][0] - model_outputs[0][i][2] / 2 f) * input_width;
212+
y1 = (model_outputs[0][i][1] - model_outputs[0][i][3] / 2 f) * input_height;
213+
x2 = (model_outputs[0][i][0] + model_outputs[0][i][2] / 2 f) * input_width;
214+
y2 = (model_outputs[0][i][1] + model_outputs[0][i][3] / 2 f) * input_height;
201215
conf = model_outputs[0][i][conf_index];
202216
if (conf < conf_threshold) continue;
203-
204217
max_index = class_index;
205218
max = model_outputs[0][i][max_index];
206-
207219
for (int j = class_index + 1; j < dimension; j++) {
208220
float current = model_outputs[0][i][j];
209221
if (current > max) {
210222
max = current;
211223
max_index = j;
212224
}
213225
}
214-
if (max > class_threshold){
226+
if (max > class_threshold) {
215227
float[] tmp = new float[6];
216228
tmp[0] = x1;
217229
tmp[1] = y1;
218230
tmp[2] = x2;
219231
tmp[3] = y2;
220232
tmp[4] = model_outputs[0][i][max_index];
221-
tmp[5] = (max_index - class_index) * 1f;
233+
tmp[5] = (max_index - class_index) * 1 f;
222234
pre_box.add(tmp);
223235
}
224236
}
225-
if (pre_box.isEmpty()) return new ArrayList<>();
226-
//for reverse orden, insteand of using .reversed method
227-
Comparator<float[]> compareValues = (v1, v2) -> Float.compare(v2[4], v1[4]);
228-
//Collections.sort(pre_box,compareValues.reversed());
237+
if (pre_box.isEmpty()) return new ArrayList < > ();
238+
Comparator < float[] > compareValues = (v1, v2) - > Float.compare(v2[4], v1[4]);
229239
Collections.sort(pre_box, compareValues);
230240
return nms(pre_box, iou_threshold);
231241
} catch (Exception e) {
232242
throw e;
233243
}
234244
}
235245

236-
protected static List<float[]> nms(List<float[]> boxes, float iou_threshold) {
246+
protected static List < float[] > nms(List < float[] > boxes, float iou_threshold) {
237247
try {
238-
List<float[]> filteredBoxes = new ArrayList<>(boxes); // Create a copy of the input list
239-
248+
List < float[] > filteredBoxes = new ArrayList < > (boxes);
240249
for (int i = 0; i < filteredBoxes.size(); i++) {
241250
float[] box = filteredBoxes.get(i);
242251
for (int j = i + 1; j < filteredBoxes.size(); j++) {
@@ -245,13 +254,11 @@ protected static List<float[]> nms(List<float[]> boxes, float iou_threshold) {
245254
float y1 = Math.max(next_box[1], box[1]);
246255
float x2 = Math.min(next_box[2], box[2]);
247256
float y2 = Math.min(next_box[3], box[3]);
248-
249257
float width = Math.max(0, x2 - x1);
250258
float height = Math.max(0, y2 - y1);
251-
252259
float intersection = width * height;
253-
float union = (next_box[2] - next_box[0]) * (next_box[3] - next_box[1])
254-
+ (box[2] - box[0]) * (box[3] - box[1]) - intersection;
260+
float union = (next_box[2] - next_box[0]) * (next_box[3] - next_box[1]) +
261+
(box[2] - box[0]) * (box[3] - box[1]) - intersection;
255262
float iou = intersection / union;
256263
if (iou > iou_threshold) {
257264
filteredBoxes.remove(j);
@@ -270,14 +277,12 @@ public boolean isInitialized() {
270277
return interpreter != null;
271278
}
272279

273-
274-
protected List<float[]> restore_size(List<float[]> nms,
275-
int input_width,
276-
int input_height,
277-
int src_width,
278-
int src_height) {
280+
protected List < float[] > restore_size(List < float[] > nms,
281+
int input_width,
282+
int input_height,
283+
int src_width,
284+
int src_height) {
279285
try {
280-
//restore size after scaling, larger images
281286
if (src_width > input_width || src_height > input_height) {
282287
float gainx = src_width / (float) input_width;
283288
float gainy = src_height / (float) input_height;
@@ -287,10 +292,9 @@ protected List<float[]> restore_size(List<float[]> nms,
287292
nms.get(i)[2] = min(src_width, Math.max(nms.get(i)[2] * gainx, 0));
288293
nms.get(i)[3] = min(src_height, Math.max(nms.get(i)[3] * gainy, 0));
289294
}
290-
//restore size after padding, smaller images
291295
} else {
292-
float padx = (src_width - input_width) / 2f;
293-
float pady = (src_height - input_height) / 2f;
296+
float padx = (src_width - input_width) / 2 f;
297+
float pady = (src_height - input_height) / 2 f;
294298
for (int i = 0; i < nms.size(); i++) {
295299
nms.get(i)[0] = min(src_width, Math.max(nms.get(i)[0] + padx, 0));
296300
nms.get(i)[1] = min(src_height, Math.max(nms.get(i)[1] + pady, 0));
@@ -303,13 +307,15 @@ protected List<float[]> restore_size(List<float[]> nms,
303307
throw new RuntimeException(e.getMessage());
304308
}
305309
}
306-
protected List<Map<String, Object>> out(List<float[]> yolo_result, Vector<String> labels) {
310+
311+
protected List < Map < String, Object >> out(List < float[] > yolo_result, Vector < String > labels) {
307312
try {
308-
List<Map<String, Object>> result = new ArrayList<>();
309-
//utils.getScreenshotBmp(bitmap, "current");
310-
for (float[] box : yolo_result) {
311-
Map<String, Object> output = new HashMap<>();
312-
output.put("box", new float[]{box[0], box[1], box[2], box[3], box[4]}); //x1,y1,x2,y2,conf_class
313+
List < Map < String, Object >> result = new ArrayList < > ();
314+
for (float[] box: yolo_result) {
315+
Map < String, Object > output = new HashMap < > ();
316+
output.put("box", new float[] {
317+
box[0], box[1], box[2], box[3], box[4]
318+
});
313319
output.put("tag", labels.get((int) box[5]));
314320
result.add(output);
315321
}
@@ -329,4 +335,4 @@ public void close() {
329335
Log.e("Yolo", "Interpreter close error", e);
330336
}
331337
}
332-
}
338+
}

0 commit comments

Comments
 (0)