3434import org .tensorflow .lite .support .audio .TensorAudio ;
3535import org .tensorflow .lite .support .audio .TensorAudio .TensorAudioFormat ;
3636import org .tensorflow .lite .support .tensorbuffer .TensorBuffer ;
37+ import org .tensorflow .lite .task .core .BaseOptions ;
3738import org .tensorflow .lite .task .core .BaseTaskApi ;
3839import org .tensorflow .lite .task .core .TaskJniUtils ;
3940import org .tensorflow .lite .task .core .TaskJniUtils .EmptyHandleProvider ;
6465 * use index as label in the result.
6566 * </ul>
6667 * </ul>
68+ *
69+ * See <a href="https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1">an example</a>
70+ * of such model, and <a
71+ * href="https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/examples/task/audio/desktop">a
72+ * CLI demo tool</a> for easily trying out this API.
6773 */
68- // TODO(b/182535933): Add a model example and demo comments here.
6974public final class AudioClassifier extends BaseTaskApi {
7075
7176 private static final String AUDIO_CLASSIFIER_NATIVE_LIB = "task_audio_jni" ;
@@ -133,7 +138,11 @@ public long createHandle(
133138 long fileDescriptorOffset ,
134139 AudioClassifierOptions options ) {
135140 return initJniWithModelFdAndOptions (
136- fileDescriptor , fileDescriptorLength , fileDescriptorOffset , options );
141+ fileDescriptor ,
142+ fileDescriptorLength ,
143+ fileDescriptorOffset ,
144+ options ,
145+ TaskJniUtils .createProtoBaseOptionsHandle (options .getBaseOptions ()));
137146 }
138147 },
139148 AUDIO_CLASSIFIER_NATIVE_LIB ,
@@ -162,7 +171,8 @@ public long createHandle() {
162171 descriptor .getFd (),
163172 /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH ,
164173 /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET ,
165- options );
174+ options ,
175+ TaskJniUtils .createProtoBaseOptionsHandle (options .getBaseOptions ()));
166176 }
167177 },
168178 AUDIO_CLASSIFIER_NATIVE_LIB ));
@@ -191,7 +201,10 @@ public static AudioClassifier createFromBufferAndOptions(
191201 new EmptyHandleProvider () {
192202 @ Override
193203 public long createHandle () {
194- return initJniWithByteBuffer (modelBuffer , options );
204+ return initJniWithByteBuffer (
205+ modelBuffer ,
206+ options ,
207+ TaskJniUtils .createProtoBaseOptionsHandle (options .getBaseOptions ()));
195208 }
196209 },
197210 AUDIO_CLASSIFIER_NATIVE_LIB ));
@@ -215,6 +228,7 @@ public static class AudioClassifierOptions {
215228 // 1. java.util.Optional require Java 8 while we need to support Java 7.
216229 // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
217230 // comments for labelAllowList.
231+ private final BaseOptions baseOptions ;
218232 private final String displayNamesLocale ;
219233 private final int maxResults ;
220234 private final float scoreThreshold ;
@@ -233,6 +247,7 @@ public static Builder builder() {
233247
234248 /** A builder that helps to configure an instance of AudioClassifierOptions. */
235249 public static class Builder {
250+ private BaseOptions baseOptions = BaseOptions .builder ().build ();
236251 private String displayNamesLocale = "en" ;
237252 private int maxResults = -1 ;
238253 private float scoreThreshold ;
@@ -242,6 +257,12 @@ public static class Builder {
242257
243258 private Builder () {}
244259
260+ /** Sets the general options to configure Task APIs, such as accelerators. */
261+ public Builder setBaseOptions (BaseOptions baseOptions ) {
262+ this .baseOptions = baseOptions ;
263+ return this ;
264+ }
265+
245266 /**
246267 * Sets the locale to use for display names specified through the TFLite Model Metadata, if
247268 * any.
@@ -339,13 +360,18 @@ public List<String> getLabelDenyList() {
339360 return new ArrayList <>(labelDenyList );
340361 }
341362
363+ public BaseOptions getBaseOptions () {
364+ return baseOptions ;
365+ }
366+
342367 private AudioClassifierOptions (Builder builder ) {
343368 displayNamesLocale = builder .displayNamesLocale ;
344369 maxResults = builder .maxResults ;
345370 scoreThreshold = builder .scoreThreshold ;
346371 isScoreThresholdSet = builder .isScoreThresholdSet ;
347372 labelAllowList = builder .labelAllowList ;
348373 labelDenyList = builder .labelDenyList ;
374+ baseOptions = builder .baseOptions ;
349375 }
350376 }
351377
@@ -485,10 +511,11 @@ private static native long initJniWithModelFdAndOptions(
485511 int fileDescriptor ,
486512 long fileDescriptorLength ,
487513 long fileDescriptorOffset ,
488- AudioClassifierOptions options );
514+ AudioClassifierOptions options ,
515+ long baseOptionsHandle );
489516
490517 private static native long initJniWithByteBuffer (
491- ByteBuffer modelBuffer , AudioClassifierOptions options );
518+ ByteBuffer modelBuffer , AudioClassifierOptions options , long baseOptionsHandle );
492519
493520 /**
494521 * Releases memory pointed by {@code nativeHandle}, namely a C++ `AudioClassifier` instance.
0 commit comments