Skip to content

Commit 970fbf6

Browse files
author
Persie
committed
fix
1 parent dcc677a commit 970fbf6

File tree

1 file changed

+14
-12
lines changed
  • android/src/main/java/com/vladih/computer_vision/flutter_vision/models

1 file changed

+14
-12
lines changed

android/src/main/java/com/vladih/computer_vision/flutter_vision/models/Yolo.java

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.tensorflow.lite.Tensor;
1212
import org.tensorflow.lite.gpu.CompatibilityList;
1313
import org.tensorflow.lite.gpu.GpuDelegate;
14+
import org.tensorflow.lite.gpu.GpuDelegateFactory;
1415

1516
import java.io.BufferedReader;
1617
import java.io.FileInputStream;
@@ -88,24 +89,25 @@ public void initialize_model() throws Exception {
8889
buffer = file_channel.map(FileChannel.MapMode.READ_ONLY, 0, file_channel.size());
8990
}
9091

91-
// Initialize interpreter with GPU delegate
9292
Interpreter.Options interpreterOptions = new Interpreter.Options();
93-
CompatibilityList compatList = new CompatibilityList();
94-
95-
if (use_gpu && compatList.isDelegateSupportedOnThisDevice()) {
96-
try {
97-
GpuDelegate.Options delegateOptions = compatList.getBestOptionsForThisDevice();
98-
GpuDelegate gpuDelegate = new GpuDelegate(delegateOptions);
93+
try {
94+
// Check if GPU support is available
95+
CompatibilityList compatibilityList = new CompatibilityList();
96+
if (use_gpu && compatibilityList.isDelegateSupportedOnThisDevice()) {
97+
GpuDelegateFactory.Options delegateOptions = compatibilityList.getBestOptionsForThisDevice();
98+
GpuDelegate gpuDelegate = new GpuDelegate(delegateOptions.setQuantizedModelsAllowed(this.quantization));
9999
interpreterOptions.addDelegate(gpuDelegate);
100-
} catch (Exception e) {
101-
Log.e("Yolo", "GPU delegate failed, falling back to CPU", e);
100+
} else {
102101
interpreterOptions.setNumThreads(num_threads);
103102
}
104-
} else {
103+
// Create the interpreter
104+
this.interpreter = new Interpreter(buffer, interpreterOptions);
105+
} catch (Exception e) {
106+
interpreterOptions = new Interpreter.Options();
105107
interpreterOptions.setNumThreads(num_threads);
108+
// Create the interpreter
109+
this.interpreter = new Interpreter(buffer, interpreterOptions);
106110
}
107-
108-
this.interpreter = new Interpreter(buffer, interpreterOptions);
109111
this.interpreter.allocateTensors();
110112
this.labels = load_labels(asset_manager, label_path);
111113
int[] shape = interpreter.getOutputTensor(0).shape();

0 commit comments

Comments
 (0)