Skip to content

Commit 293f2b9

Browse files
lu-wang-gtflite-support-robot
authored and committed
Task delegate 9: Expose BaseOptions in Java AudioClassifier
PiperOrigin-RevId: 400353942
1 parent c1d709a commit 293f2b9

File tree

5 files changed

+60
-12
lines changed

5 files changed

+60
-12
lines changed

tensorflow_lite_support/cc/task/audio/audio_classifier.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ namespace audio {
5252
// `AudioClassifierOptions` used at creation time ("en" by default, i.e.
5353
// English). If none of these are available, only the `index` field of the
5454
// results will be filled.
55+
//
56+
// An example of such a model can be found at:
57+
// https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1
58+
59+
// A CLI demo tool is available for easily trying out this API, and provides
60+
// example usage. See:
61+
// https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/examples/task/audio/desktop
5562
class AudioClassifier
5663
: public tflite::task::core::BaseTaskApi<ClassificationResult,
5764
const AudioBuffer&> {

tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.tensorflow.lite.support.audio.TensorAudio;
3535
import org.tensorflow.lite.support.audio.TensorAudio.TensorAudioFormat;
3636
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
37+
import org.tensorflow.lite.task.core.BaseOptions;
3738
import org.tensorflow.lite.task.core.BaseTaskApi;
3839
import org.tensorflow.lite.task.core.TaskJniUtils;
3940
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
@@ -64,8 +65,12 @@
6465
* use index as label in the result.
6566
* </ul>
6667
* </ul>
68+
*
69+
* See <a href="https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1">an example</a>
70+
* of such a model, and <a
71+
* href="https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/examples/task/audio/desktop">a
72+
* CLI demo tool</a> for easily trying out this API.
6773
*/
68-
// TODO(b/182535933): Add a model example and demo comments here.
6974
public final class AudioClassifier extends BaseTaskApi {
7075

7176
private static final String AUDIO_CLASSIFIER_NATIVE_LIB = "task_audio_jni";
@@ -133,7 +138,11 @@ public long createHandle(
133138
long fileDescriptorOffset,
134139
AudioClassifierOptions options) {
135140
return initJniWithModelFdAndOptions(
136-
fileDescriptor, fileDescriptorLength, fileDescriptorOffset, options);
141+
fileDescriptor,
142+
fileDescriptorLength,
143+
fileDescriptorOffset,
144+
options,
145+
TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
137146
}
138147
},
139148
AUDIO_CLASSIFIER_NATIVE_LIB,
@@ -162,7 +171,8 @@ public long createHandle() {
162171
descriptor.getFd(),
163172
/*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
164173
/*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
165-
options);
174+
options,
175+
TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
166176
}
167177
},
168178
AUDIO_CLASSIFIER_NATIVE_LIB));
@@ -191,7 +201,10 @@ public static AudioClassifier createFromBufferAndOptions(
191201
new EmptyHandleProvider() {
192202
@Override
193203
public long createHandle() {
194-
return initJniWithByteBuffer(modelBuffer, options);
204+
return initJniWithByteBuffer(
205+
modelBuffer,
206+
options,
207+
TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
195208
}
196209
},
197210
AUDIO_CLASSIFIER_NATIVE_LIB));
@@ -215,6 +228,7 @@ public static class AudioClassifierOptions {
215228
// 1. java.util.Optional require Java 8 while we need to support Java 7.
216229
// 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
217230
// comments for labelAllowList.
231+
private final BaseOptions baseOptions;
218232
private final String displayNamesLocale;
219233
private final int maxResults;
220234
private final float scoreThreshold;
@@ -233,6 +247,7 @@ public static Builder builder() {
233247

234248
/** A builder that helps to configure an instance of AudioClassifierOptions. */
235249
public static class Builder {
250+
private BaseOptions baseOptions = BaseOptions.builder().build();
236251
private String displayNamesLocale = "en";
237252
private int maxResults = -1;
238253
private float scoreThreshold;
@@ -242,6 +257,12 @@ public static class Builder {
242257

243258
private Builder() {}
244259

260+
/** Sets the general options to configure Task APIs, such as accelerators. */
261+
public Builder setBaseOptions(BaseOptions baseOptions) {
262+
this.baseOptions = baseOptions;
263+
return this;
264+
}
265+
245266
/**
246267
* Sets the locale to use for display names specified through the TFLite Model Metadata, if
247268
* any.
@@ -339,13 +360,18 @@ public List<String> getLabelDenyList() {
339360
return new ArrayList<>(labelDenyList);
340361
}
341362

363+
public BaseOptions getBaseOptions() {
364+
return baseOptions;
365+
}
366+
342367
private AudioClassifierOptions(Builder builder) {
343368
displayNamesLocale = builder.displayNamesLocale;
344369
maxResults = builder.maxResults;
345370
scoreThreshold = builder.scoreThreshold;
346371
isScoreThresholdSet = builder.isScoreThresholdSet;
347372
labelAllowList = builder.labelAllowList;
348373
labelDenyList = builder.labelDenyList;
374+
baseOptions = builder.baseOptions;
349375
}
350376
}
351377

@@ -485,10 +511,11 @@ private static native long initJniWithModelFdAndOptions(
485511
int fileDescriptor,
486512
long fileDescriptorLength,
487513
long fileDescriptorOffset,
488-
AudioClassifierOptions options);
514+
AudioClassifierOptions options,
515+
long baseOptionsHandle);
489516

490517
private static native long initJniWithByteBuffer(
491-
ByteBuffer modelBuffer, AudioClassifierOptions options);
518+
ByteBuffer modelBuffer, AudioClassifierOptions options, long baseOptionsHandle);
492519

493520
/**
494521
* Releases memory pointed to by {@code nativeHandle}, namely a C++ `AudioClassifier` instance.

tensorflow_lite_support/java/src/native/task/audio/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ tflite_jni_binary(
1616
name = "libtask_audio_jni.so",
1717
srcs = [
1818
"//tensorflow_lite_support/java/src/native/task/audio/classifier:audio_classifier_jni.cc",
19+
"//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc",
1920
],
2021
linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
2122
deps = [
@@ -25,6 +26,7 @@ tflite_jni_binary(
2526
"//tensorflow_lite_support/cc/task/audio/proto:audio_classifier_options_cc_proto",
2627
"//tensorflow_lite_support/cc/task/audio/proto:class_proto_inc",
2728
"//tensorflow_lite_support/cc/task/audio/proto:classifications_proto_inc",
29+
"//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc",
2830
"//tensorflow_lite_support/cc/utils:jni_utils",
2931
"//tensorflow_lite_support/java/jni",
3032
],

tensorflow_lite_support/java/src/native/task/audio/classifier/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ tflite_jni_binary(
1818
name = "libtask_audio_jni.so",
1919
srcs = [
2020
"audio_classifier_jni.cc",
21+
"//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc",
2122
],
2223
linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
2324
deps = [
@@ -27,6 +28,7 @@ tflite_jni_binary(
2728
"//tensorflow_lite_support/cc/task/audio/proto:audio_classifier_options_cc_proto",
2829
"//tensorflow_lite_support/cc/task/audio/proto:class_proto_inc",
2930
"//tensorflow_lite_support/cc/task/audio/proto:classifications_proto_inc",
31+
"//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc",
3032
"//tensorflow_lite_support/cc/utils:jni_utils",
3133
"//tensorflow_lite_support/java/jni",
3234
],

tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "tensorflow_lite_support/cc/task/audio/proto/audio_classifier_options.pb.h"
2525
#include "tensorflow_lite_support/cc/task/audio/proto/class_proto_inc.h"
2626
#include "tensorflow_lite_support/cc/task/audio/proto/classifications_proto_inc.h"
27+
#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h"
2728
#include "tensorflow_lite_support/cc/utils/jni_utils.h"
2829

2930
namespace {
@@ -38,6 +39,7 @@ using ::tflite::task::audio::AudioClassifier;
3839
using ::tflite::task::audio::AudioClassifierOptions;
3940
using ::tflite::task::audio::Class;
4041
using ::tflite::task::audio::ClassificationResult;
42+
using ::tflite::task::core::BaseOptions;
4143

4244
// TODO(b/183343074): Share the code below with ImageClassifier.
4345

@@ -123,9 +125,16 @@ jobject ConvertToClassificationResults(JNIEnv* env,
123125
}
124126

125127
// Creates an AudioClassifierOptions proto based on the Java class.
126-
AudioClassifierOptions ConvertToProtoOptions(JNIEnv* env,
127-
jobject java_options) {
128+
AudioClassifierOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
129+
jlong base_options_handle) {
128130
AudioClassifierOptions proto_options;
131+
132+
if (base_options_handle != kInvalidPointer) {
133+
// proto_options will free the previous base_options and set the new one.
134+
proto_options.set_allocated_base_options(
135+
reinterpret_cast<BaseOptions*>(base_options_handle));
136+
}
137+
129138
jclass java_options_class = env->FindClass(
130139
"org/tensorflow/lite/task/audio/classifier/"
131140
"AudioClassifier$AudioClassifierOptions");
@@ -202,9 +211,9 @@ extern "C" JNIEXPORT jlong JNICALL
202211
Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithModelFdAndOptions(
203212
JNIEnv* env, jclass thiz, jint file_descriptor,
204213
jlong file_descriptor_length, jlong file_descriptor_offset,
205-
jobject java_options) {
214+
jobject java_options, jlong base_options_handle) {
206215
AudioClassifierOptions proto_options =
207-
ConvertToProtoOptions(env, java_options);
216+
ConvertToProtoOptions(env, java_options, base_options_handle);
208217
auto file_descriptor_meta = proto_options.mutable_base_options()
209218
->mutable_model_file()
210219
->mutable_file_descriptor_meta();
@@ -220,9 +229,10 @@ Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithModelF
220229

221230
extern "C" JNIEXPORT jlong JNICALL
222231
Java_org_tensorflow_lite_task_audio_classifier_AudioClassifier_initJniWithByteBuffer(
223-
JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options) {
232+
JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
233+
jlong base_options_handle) {
224234
AudioClassifierOptions proto_options =
225-
ConvertToProtoOptions(env, java_options);
235+
ConvertToProtoOptions(env, java_options, base_options_handle);
226236
// External proto generated header does not overload `set_file_content` with
227237
// string_view, therefore GetMappedFileBuffer does not apply here.
228238
// Creating a std::string will cause one extra copying of data. Thus, the

0 commit comments

Comments
 (0)