1919package org .apache .flink .model .openai ;
2020
2121import org .apache .flink .configuration .ReadableConfig ;
22+ import org .apache .flink .table .api .DataTypes ;
2223import org .apache .flink .table .api .config .ExecutionConfigOptions ;
2324import org .apache .flink .table .catalog .Column ;
2425import org .apache .flink .table .catalog .ResolvedSchema ;
26+ import org .apache .flink .table .data .GenericRowData ;
2527import org .apache .flink .table .data .RowData ;
28+ import org .apache .flink .table .data .binary .BinaryStringData ;
2629import org .apache .flink .table .factories .ModelProviderFactory ;
2730import org .apache .flink .table .functions .AsyncPredictFunction ;
2831import org .apache .flink .table .functions .FunctionContext ;
32+ import org .apache .flink .table .types .DataType ;
2933import org .apache .flink .table .types .logical .LogicalType ;
3034import org .apache .flink .table .types .logical .VarCharType ;
3135import org .apache .flink .util .ExceptionUtils ;
4145
4246import java .io .IOException ;
4347import java .time .Duration ;
48+ import java .util .Arrays ;
4449import java .util .Collection ;
4550import java .util .Collections ;
4651import java .util .HashSet ;
@@ -78,6 +83,7 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction {
7883 private final String model ;
7984 @ Nullable private final Integer maxContextSize ;
8085 private final ContextOverflowAction contextOverflowAction ;
86+ protected final List <String > outputColumnNames ;
8187
8288 public AbstractOpenAIModelFunction (
8389 ModelProviderFactory .Context factoryContext , ReadableConfig config ) {
@@ -140,6 +146,9 @@ public AbstractOpenAIModelFunction(
140146 factoryContext .getCatalogModel ().getResolvedInputSchema (),
141147 new VarCharType (VarCharType .MAX_LENGTH ),
142148 "input" );
149+
150+ this .outputColumnNames =
151+ factoryContext .getCatalogModel ().getResolvedOutputSchema ().getColumnNames ();
143152 }
144153
145154 @ Override
@@ -184,23 +193,19 @@ public void close() throws Exception {
184193 protected void validateSingleColumnSchema (
185194 ResolvedSchema schema , LogicalType expectedType , String inputOrOutput ) {
186195 List <Column > columns = schema .getColumns ();
187- if (columns .size () != 1 ) {
196+ List <String > physicalColumnNames =
197+ columns .stream ()
198+ .filter (Column ::isPhysical )
199+ .map (Column ::getName )
200+ .collect (Collectors .toList ());
201+ if (physicalColumnNames .size () != 1 ) {
188202 throw new IllegalArgumentException (
189203 String .format (
190- "Model should have exactly one %s column, but actually has %s columns: %s" ,
191- inputOrOutput ,
192- columns .size (),
193- columns .stream ().map (Column ::getName ).collect (Collectors .toList ())));
194- }
195-
196- Column column = columns .get (0 );
197- if (!column .isPhysical ()) {
198- throw new IllegalArgumentException (
199- String .format (
200- "%s column %s should be a physical column, but is a %s." ,
201- inputOrOutput , column .getName (), column .getClass ()));
204+ "Model should have exactly one %s physical column, but actually has %s physical columns: %s" ,
205+ inputOrOutput , physicalColumnNames .size (), physicalColumnNames ));
202206 }
203207
208+ Column column = schema .getColumn (physicalColumnNames .get (0 )).get ();
204209 if (!expectedType .equals (column .getDataType ().getLogicalType ())) {
205210 throw new IllegalArgumentException (
206211 String .format (
@@ -210,6 +215,33 @@ protected void validateSingleColumnSchema(
210215 expectedType ,
211216 column .getDataType ().getLogicalType ()));
212217 }
218+
219+ List <Column > metadataColumns =
220+ columns .stream ()
221+ .filter (x -> x instanceof Column .MetadataColumn )
222+ .collect (Collectors .toList ());
223+ if (!metadataColumns .isEmpty ()) {
224+ Preconditions .checkArgument (
225+ "output" .equals (inputOrOutput ), "Only output schema supports metadata column" );
226+
227+ for (Column metadataColumn : metadataColumns ) {
228+ ErrorMessageMetadata errorMessageMetadata =
229+ ErrorMessageMetadata .get (metadataColumn .getName ());
230+ Preconditions .checkNotNull (
231+ errorMessageMetadata ,
232+ String .format (
233+ "Unexpected metadata column %s. Supported metadata columns:\n %s" ,
234+ metadataColumn .getName (),
235+ ErrorMessageMetadata .getAllKeysAndDescriptions ()));
236+ Preconditions .checkArgument (
237+ errorMessageMetadata .dataType .equals (metadataColumn .getDataType ()),
238+ String .format (
239+ "Expected metadata column %s to be of type %s, but is of type %s" ,
240+ metadataColumn .getName (),
241+ errorMessageMetadata .dataType ,
242+ metadataColumn .getDataType ()));
243+ }
244+ }
213245 }
214246
215247 /**
@@ -223,30 +255,52 @@ protected void validateSingleColumnSchema(
223255 * appropriate retry and error handling applied, or a null value if the request failed in
224256 * the middle and the failure should be ignored.
225257 */
226- protected <T > CompletableFuture <T > sendAsyncOpenAIRequest (
227- Supplier <CompletableFuture <T >> requestSender ) {
258+ protected <T > CompletableFuture <Collection <RowData >> sendAsyncOpenAIRequest (
259+ Supplier <CompletableFuture <T >> requestSender ,
260+ Function <T , Collection <RowData >> converter ) {
228261 CompletableFuture <T > result =
229262 retryAsync (
230263 requestSender ,
231264 numRetry ,
232265 retryBackoffBaseIntervalMs ,
233266 retryBackoffStrategy ,
234267 null );
235- ErrorHandlingStrategy finalErrorHandlingStrategy =
236- this .errorHandlingStrategy == ErrorHandlingStrategy .RETRY
237- ? this .retryFallbackStrategy
238- : this .errorHandlingStrategy ;
239- if (finalErrorHandlingStrategy == ErrorHandlingStrategy .IGNORE ) {
240- result =
241- result .exceptionally (
242- (e ) -> {
243- LOG .warn (
244- "The input row data failed to acquire a valid response. Ignoring the input." ,
245- e );
246- return null ;
247- });
268+ return result .handle ((x , throwable ) -> this .convertToRowData (x , throwable , converter ));
269+ }
270+
271+ private <T > Collection <RowData > convertToRowData (
272+ @ Nullable T t ,
273+ @ Nullable Throwable throwable ,
274+ Function <T , Collection <RowData >> converter ) {
275+ if (throwable != null ) {
276+ ErrorHandlingStrategy finalErrorHandlingStrategy =
277+ this .errorHandlingStrategy == ErrorHandlingStrategy .RETRY
278+ ? this .retryFallbackStrategy
279+ : this .errorHandlingStrategy ;
280+ if (finalErrorHandlingStrategy == ErrorHandlingStrategy .FAILOVER ) {
281+ throw new RuntimeException (throwable );
282+ } else {
283+ LOG .warn (
284+ "The input row data failed to acquire a valid response. Ignoring the input." ,
285+ throwable );
286+ GenericRowData rowData = new GenericRowData (this .outputColumnNames .size ());
287+ boolean isMetadataSet = false ;
288+ for (int i = 0 ; i < this .outputColumnNames .size (); i ++) {
289+ String columnName = this .outputColumnNames .get (i );
290+ ErrorMessageMetadata errorMessageMetadata =
291+ ErrorMessageMetadata .get (columnName );
292+ if (errorMessageMetadata != null ) {
293+ rowData .setField (i , errorMessageMetadata .converter .apply (throwable ));
294+ isMetadataSet = true ;
295+ }
296+ }
297+ return isMetadataSet ? Collections .singletonList (rowData ) : Collections .emptyList ();
298+ }
299+ } else if (t == null ) {
300+ return Collections .emptyList ();
301+ } else {
302+ return converter .apply (t );
248303 }
249- return result ;
250304 }
251305
252306 private <T > CompletableFuture <T > retryAsync (
@@ -348,4 +402,55 @@ public long getMinRetryTotalTime(long baseRetryInterval, int numRetry) {
348402
349403 public abstract long getMinRetryTotalTime (long baseRetryInterval , int numRetry );
350404 }
405+
406+ /**
407+ * Metadata that can be read from the output row about error messages. Referenced from Flink
408+ * HTTP Connector's ReadableMetadata.
409+ */
410+ protected enum ErrorMessageMetadata {
411+ ERROR_STRING (
412+ "error-string" ,
413+ DataTypes .STRING (),
414+ x -> BinaryStringData .fromString (x .getMessage ()),
415+ "A message associated with the error" ),
416+ HTTP_STATUS_CODE (
417+ "http-status-code" ,
418+ DataTypes .INT (),
419+ e ->
420+ ExceptionUtils .findThrowable (e , OpenAIServiceException .class )
421+ .map (OpenAIServiceException ::statusCode )
422+ .orElse (null ),
423+ "The HTTP status code" );
424+
425+ final String key ;
426+ final DataType dataType ;
427+ final Function <Throwable , Object > converter ;
428+ final String description ;
429+
430+ ErrorMessageMetadata (
431+ String key ,
432+ DataType dataType ,
433+ Function <Throwable , Object > converter ,
434+ String description ) {
435+ this .key = key ;
436+ this .dataType = dataType ;
437+ this .converter = converter ;
438+ this .description = description ;
439+ }
440+
441+ static @ Nullable ErrorMessageMetadata get (String key ) {
442+ for (ErrorMessageMetadata value : values ()) {
443+ if (value .key .equals (key )) {
444+ return value ;
445+ }
446+ }
447+ return null ;
448+ }
449+
450+ static String getAllKeysAndDescriptions () {
451+ return Arrays .stream (values ())
452+ .map (value -> value .key + ":\t " + value .description )
453+ .collect (Collectors .joining ("\n " ));
454+ }
455+ }
351456}
0 commit comments