Skip to content

Commit 196269e

Browse files
lu-wang-gtflite-support-robot
authored andcommitted
Support Multi-head audio classifier model in Task Java library
PiperOrigin-RevId: 370173607
1 parent 028105d commit 196269e

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/Classifications.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,16 @@
3333
public abstract class Classifications {
3434

3535
@UsedByReflection("audio_classifier_jni.cc")
36-
static Classifications create(List<Category> categories, int headIndex) {
36+
static Classifications create(List<Category> categories, int headIndex, String headName) {
3737
return new AutoValue_Classifications(
38-
Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex);
38+
Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex, headName);
3939
}
4040

4141
// Same reason for not using ImmutableList as stated in
4242
// {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
4343
public abstract List<Category> getCategories();
4444

4545
public abstract int getHeadIndex();
46+
47+
public abstract String getHeadName();
4648
}

tensorflow_lite_support/java/src/native/task/audio/classifier/audio_classifier_jni.cc

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ jobject ConvertToClassificationResults(JNIEnv* env,
7777
// jclass and init of Classifications.
7878
jclass classifications_class = env->FindClass(
7979
"org/tensorflow/lite/task/audio/classifier/Classifications");
80-
jmethodID classifications_create =
81-
env->GetStaticMethodID(classifications_class, "create",
82-
"(Ljava/util/List;I)Lorg/tensorflow/lite/"
83-
"task/audio/classifier/Classifications;");
80+
jmethodID classifications_create = env->GetStaticMethodID(
81+
classifications_class, "create",
82+
"(Ljava/util/List;ILjava/lang/String;)Lorg/tensorflow/lite/"
83+
"task/audio/classifier/Classifications;");
8484

8585
// jclass, init, and add of ArrayList.
8686
jclass array_list_class = env->FindClass("java/util/ArrayList");
@@ -102,12 +102,20 @@ jobject ConvertToClassificationResults(JNIEnv* env,
102102

103103
env->DeleteLocalRef(jcategory);
104104
}
105+
106+
std::string head_name_string =
107+
classifications.has_head_name()
108+
? classifications.head_name()
109+
: std::to_string(classifications.head_index());
110+
jstring head_name = env->NewStringUTF(head_name_string.c_str());
111+
105112
jobject jclassifications = env->CallStaticObjectMethod(
106113
classifications_class, classifications_create, jcategory_list,
107-
classifications.head_index());
114+
classifications.head_index(), head_name);
108115
env->CallBooleanMethod(classifications_list, array_list_add_method,
109116
jclassifications);
110117

118+
env->DeleteLocalRef(head_name);
111119
env->DeleteLocalRef(jcategory_list);
112120
env->DeleteLocalRef(jclassifications);
113121
}

0 commit comments

Comments
 (0)