1717
1818import android .content .Context ;
1919import android .os .ParcelFileDescriptor ;
20+ import com .google .auto .value .AutoValue ;
2021import java .io .File ;
2122import java .io .IOException ;
2223import java .nio .ByteBuffer ;
2324import java .util .List ;
25+ import org .tensorflow .lite .task .core .BaseOptions ;
2426import org .tensorflow .lite .task .core .BaseTaskApi ;
2527import org .tensorflow .lite .task .core .TaskJniUtils ;
2628import org .tensorflow .lite .task .core .TaskJniUtils .EmptyHandleProvider ;
29+ import org .tensorflow .lite .task .core .TaskJniUtils .FdAndOptionsHandleProvider ;
2730import org .tensorflow .lite .task .core .TaskJniUtils .MultipleBuffersHandleProvider ;
2831
2932/**
@@ -60,22 +63,62 @@ public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswere
6063 */
6164 public static BertQuestionAnswerer createFromFile (Context context , String modelPath )
6265 throws IOException {
66+ return createFromFileAndOptions (
67+ context , modelPath , BertQuestionAnswererOptions .builder ().build ());
68+ }
69+
70+ /**
71+ * Creates a {@link BertQuestionAnswerer} instance from the default {@link
72+ * BertQuestionAnswererOptions}.
73+ *
74+ * @param modelFile a {@link File} object of the model
75+ * @return a {@link BertQuestionAnswerer} instance
76+ * @throws IOException if model file fails to load
77+ * @throws IllegalArgumentException if an argument is invalid
78+ * @throws IllegalStateException if there is an internal error
79+ * @throws RuntimeException if there is an otherwise unspecified error
80+ */
81+ public static BertQuestionAnswerer createFromFile (File modelFile ) throws IOException {
82+ return createFromFileAndOptions (modelFile , BertQuestionAnswererOptions .builder ().build ());
83+ }
84+
85+ /**
86+ * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
87+ *
88+ * @param context android context
89+ * @param modelPath file path to the model with metadata. Note: The model should not be compressed
90+ * @return a {@link BertQuestionAnswerer} instance
91+ * @throws IOException if model file fails to load
92+ * @throws IllegalArgumentException if an argument is invalid
93+ * @throws IllegalStateException if there is an internal error
94+ * @throws RuntimeException if there is an otherwise unspecified error
95+ */
96+ public static BertQuestionAnswerer createFromFileAndOptions (
97+ Context context , String modelPath , BertQuestionAnswererOptions options ) throws IOException {
6398 return new BertQuestionAnswerer (
64- TaskJniUtils .createHandleWithMultipleAssetFilesFromLibrary (
99+ TaskJniUtils .createHandleFromFdAndOptions (
65100 context ,
66- new MultipleBuffersHandleProvider () {
101+ new FdAndOptionsHandleProvider < BertQuestionAnswererOptions > () {
67102 @ Override
68- public long createHandle (ByteBuffer ... buffers ) {
69- return BertQuestionAnswerer .initJniWithModelWithMetadataByteBuffers (buffers );
103+ public long createHandle (
104+ int fileDescriptor ,
105+ long fileDescriptorLength ,
106+ long fileDescriptorOffset ,
107+ BertQuestionAnswererOptions options ) {
108+ return initJniWithFileDescriptor (
109+ fileDescriptor ,
110+ fileDescriptorLength ,
111+ fileDescriptorOffset ,
112+ TaskJniUtils .createProtoBaseOptionsHandle (options .getBaseOptions ()));
70113 }
71114 },
72115 BERT_QUESTION_ANSWERER_NATIVE_LIBNAME ,
73- modelPath ));
116+ modelPath ,
117+ options ));
74118 }
75119
76120 /**
77- * Creates a {@link BertQuestionAnswerer} instance from the default {@link
78- * BertQuestionAnswererOptions}.
121+ * Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
79122 *
80123 * @param modelFile a {@link File} object of the model
81124 * @return a {@link BertQuestionAnswerer} instance
@@ -84,7 +127,8 @@ public long createHandle(ByteBuffer... buffers) {
84127 * @throws IllegalStateException if there is an internal error
85128 * @throws RuntimeException if there is an otherwise unspecified error
86129 */
87- public static BertQuestionAnswerer createFromFile (File modelFile ) throws IOException {
130+ public static BertQuestionAnswerer createFromFileAndOptions (
131+ File modelFile , final BertQuestionAnswererOptions options ) throws IOException {
88132 try (ParcelFileDescriptor descriptor =
89133 ParcelFileDescriptor .open (modelFile , ParcelFileDescriptor .MODE_READ_ONLY )) {
90134 return new BertQuestionAnswerer (
@@ -95,7 +139,8 @@ public long createHandle() {
95139 return initJniWithFileDescriptor (
96140 /*fileDescriptor=*/ descriptor .getFd (),
97141 /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH ,
98- /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET );
142+ /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET ,
143+ TaskJniUtils .createProtoBaseOptionsHandle (options .getBaseOptions ()));
99144 }
100145 },
101146 BERT_QUESTION_ANSWERER_NATIVE_LIBNAME ));
@@ -164,6 +209,26 @@ public long createHandle(ByteBuffer... buffers) {
164209 sentencePieceModelPath ));
165210 }
166211
212+ /** Options for setting up a {@link BertQuestionAnswerer}. */
213+ @ AutoValue
214+ public abstract static class BertQuestionAnswererOptions {
215+ abstract BaseOptions getBaseOptions ();
216+
217+ public static Builder builder () {
218+ return new AutoValue_BertQuestionAnswerer_BertQuestionAnswererOptions .Builder ()
219+ .setBaseOptions (BaseOptions .builder ().build ());
220+ }
221+
222+ /** Builder for {@link BertQuestionAnswererOptions}. */
223+ @ AutoValue .Builder
224+ public abstract static class Builder {
225+ /** Sets the general options to configure Task APIs, such as accelerators. */
226+ public abstract Builder setBaseOptions (BaseOptions baseOptions );
227+
228+ public abstract BertQuestionAnswererOptions build ();
229+ }
230+ }
231+
167232 @ Override
168233 public List <QaAnswer > answer (String context , String question ) {
169234 checkNotClosed ();
@@ -181,11 +246,11 @@ private BertQuestionAnswerer(long nativeHandle) {
181246 // buffer.
182247 private static native long initJniWithAlbertByteBuffers (ByteBuffer ... modelBuffers );
183248
184- // modelBuffers[0] is tflite model file buffer with metadata to specify which tokenizer to use.
185- private static native long initJniWithModelWithMetadataByteBuffers (ByteBuffer ... modelBuffers );
186-
187249 private static native long initJniWithFileDescriptor (
188- int fileDescriptor , long fileDescriptorLength , long fileDescriptorOffset );
250+ int fileDescriptor ,
251+ long fileDescriptorLength ,
252+ long fileDescriptorOffset ,
253+ long baseOptionsHandle );
189254
190255 private static native List <QaAnswer > answerNative (
191256 long nativeHandle , String context , String question );
0 commit comments