|
11 | 11 | import org.tensorflow.lite.Tensor;
|
12 | 12 | import org.tensorflow.lite.gpu.CompatibilityList;
|
13 | 13 | import org.tensorflow.lite.gpu.GpuDelegate;
|
| 14 | +import org.tensorflow.lite.gpu.GpuDelegateFactory; |
14 | 15 |
|
15 | 16 | import java.io.BufferedReader;
|
16 | 17 | import java.io.FileInputStream;
|
@@ -88,24 +89,25 @@ public void initialize_model() throws Exception {
|
88 | 89 | buffer = file_channel.map(FileChannel.MapMode.READ_ONLY, 0, file_channel.size());
|
89 | 90 | }
|
90 | 91 |
|
91 |
| - // Initialize interpreter with GPU delegate |
92 | 92 | Interpreter.Options interpreterOptions = new Interpreter.Options();
|
93 |
| - CompatibilityList compatList = new CompatibilityList(); |
94 |
| - |
95 |
| - if (use_gpu && compatList.isDelegateSupportedOnThisDevice()) { |
96 |
| - try { |
97 |
| - GpuDelegate.Options delegateOptions = compatList.getBestOptionsForThisDevice(); |
98 |
| - GpuDelegate gpuDelegate = new GpuDelegate(delegateOptions); |
| 93 | + try { |
| 94 | + // Check if GPU support is available |
| 95 | + CompatibilityList compatibilityList = new CompatibilityList(); |
| 96 | + if (use_gpu && compatibilityList.isDelegateSupportedOnThisDevice()) { |
| 97 | + GpuDelegateFactory.Options delegateOptions = compatibilityList.getBestOptionsForThisDevice(); |
| 98 | + GpuDelegate gpuDelegate = new GpuDelegate(delegateOptions.setQuantizedModelsAllowed(this.quantization)); |
99 | 99 | interpreterOptions.addDelegate(gpuDelegate);
|
100 |
| - } catch (Exception e) { |
101 |
| - Log.e("Yolo", "GPU delegate failed, falling back to CPU", e); |
| 100 | + } else { |
102 | 101 | interpreterOptions.setNumThreads(num_threads);
|
103 | 102 | }
|
104 |
| - } else { |
| 103 | + // Create the interpreter |
| 104 | + this.interpreter = new Interpreter(buffer, interpreterOptions); |
| 105 | + } catch (Exception e) { |
| 106 | + interpreterOptions = new Interpreter.Options(); |
105 | 107 | interpreterOptions.setNumThreads(num_threads);
|
| 108 | + // Create the interpreter |
| 109 | + this.interpreter = new Interpreter(buffer, interpreterOptions); |
106 | 110 | }
|
107 |
| - |
108 |
| - this.interpreter = new Interpreter(buffer, interpreterOptions); |
109 | 111 | this.interpreter.allocateTensors();
|
110 | 112 | this.labels = load_labels(asset_manager, label_path);
|
111 | 113 | int[] shape = interpreter.getOutputTensor(0).shape();
|
|
0 commit comments