Skip to content

Commit 662e73f

Browse files
[FLINK-38581][model] Support surfacing error message
1 parent b4e212e commit 662e73f

File tree

7 files changed

+266
-60
lines changed

7 files changed

+266
-60
lines changed

docs/content.zh/docs/connectors/models/openai.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,12 @@ FROM ML_PREDICT(
168168
<ul>
169169
<li><code>retry</code>: 重试发送请求。重试行为受 retry-num、retry-fallback-strategy、retry-backoff-strategy 和 retry-backoff-base-interval 限制。</li>
170170
<li><code>failover</code>: 抛出异常并使 Flink 作业失败。</li>
171-
<li><code>ignore</code>: 忽略导致错误的输入并继续。错误本身将记录在日志中。</li>
171+
<li><code>ignore</code>: 忽略导致错误的输入并继续执行。错误本身将被记录在日志中。您还可以指定以下元数据列,以便在输出流中显示错误信息。
172+
<ul>
173+
<li><code>error-string</code>: 与错误相关的消息</li>
174+
<li><code>http-status-code</code>: HTTP状态码</li>
175+
</ul>
176+
</li>
172177
</ul>
173178
</td>
174179
</tr>

docs/content/docs/connectors/models/openai.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,12 @@ FROM ML_PREDICT(
168168
<ul>
169169
<li><code>retry</code>: Retry sending the request. The retrying behavior is limited by retry-num, retry-fallback-strategy, retry-backoff-strategy and retry-backoff-base-interval.</li>
170170
<li><code>failover</code>: Throw exceptions and fail the Flink job.</li>
171-
<li><code>ignore</code>: Ignore the input that caused the error and continue. The error itself would be recorded in log.</li>
171+
<li><code>ignore</code>: Ignore the input that caused the error and continue. The error itself would be recorded in log. You can also specify the following metadata columns to surface the error message in the output stream.
172+
<ul>
173+
<li><code>error-string</code>: A message associated with the error</li>
174+
<li><code>http-status-code</code>: The HTTP status code</li>
175+
</ul>
176+
</li>
172177
</ul>
173178
</td>
174179
</tr>

flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java

Lines changed: 134 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,17 @@
1919
package org.apache.flink.model.openai;
2020

2121
import org.apache.flink.configuration.ReadableConfig;
22+
import org.apache.flink.table.api.DataTypes;
2223
import org.apache.flink.table.api.config.ExecutionConfigOptions;
2324
import org.apache.flink.table.catalog.Column;
2425
import org.apache.flink.table.catalog.ResolvedSchema;
26+
import org.apache.flink.table.data.GenericRowData;
2527
import org.apache.flink.table.data.RowData;
28+
import org.apache.flink.table.data.binary.BinaryStringData;
2629
import org.apache.flink.table.factories.ModelProviderFactory;
2730
import org.apache.flink.table.functions.AsyncPredictFunction;
2831
import org.apache.flink.table.functions.FunctionContext;
32+
import org.apache.flink.table.types.DataType;
2933
import org.apache.flink.table.types.logical.LogicalType;
3034
import org.apache.flink.table.types.logical.VarCharType;
3135
import org.apache.flink.util.ExceptionUtils;
@@ -41,6 +45,7 @@
4145

4246
import java.io.IOException;
4347
import java.time.Duration;
48+
import java.util.Arrays;
4449
import java.util.Collection;
4550
import java.util.Collections;
4651
import java.util.HashSet;
@@ -78,6 +83,7 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction {
7883
private final String model;
7984
@Nullable private final Integer maxContextSize;
8085
private final ContextOverflowAction contextOverflowAction;
86+
protected final List<String> outputColumnNames;
8187

8288
public AbstractOpenAIModelFunction(
8389
ModelProviderFactory.Context factoryContext, ReadableConfig config) {
@@ -140,6 +146,9 @@ public AbstractOpenAIModelFunction(
140146
factoryContext.getCatalogModel().getResolvedInputSchema(),
141147
new VarCharType(VarCharType.MAX_LENGTH),
142148
"input");
149+
150+
this.outputColumnNames =
151+
factoryContext.getCatalogModel().getResolvedOutputSchema().getColumnNames();
143152
}
144153

145154
@Override
@@ -184,23 +193,19 @@ public void close() throws Exception {
184193
protected void validateSingleColumnSchema(
185194
ResolvedSchema schema, LogicalType expectedType, String inputOrOutput) {
186195
List<Column> columns = schema.getColumns();
187-
if (columns.size() != 1) {
196+
List<String> physicalColumnNames =
197+
columns.stream()
198+
.filter(Column::isPhysical)
199+
.map(Column::getName)
200+
.collect(Collectors.toList());
201+
if (physicalColumnNames.size() != 1) {
188202
throw new IllegalArgumentException(
189203
String.format(
190-
"Model should have exactly one %s column, but actually has %s columns: %s",
191-
inputOrOutput,
192-
columns.size(),
193-
columns.stream().map(Column::getName).collect(Collectors.toList())));
194-
}
195-
196-
Column column = columns.get(0);
197-
if (!column.isPhysical()) {
198-
throw new IllegalArgumentException(
199-
String.format(
200-
"%s column %s should be a physical column, but is a %s.",
201-
inputOrOutput, column.getName(), column.getClass()));
204+
"Model should have exactly one %s physical column, but actually has %s physical columns: %s",
205+
inputOrOutput, physicalColumnNames.size(), physicalColumnNames));
202206
}
203207

208+
Column column = schema.getColumn(physicalColumnNames.get(0)).get();
204209
if (!expectedType.equals(column.getDataType().getLogicalType())) {
205210
throw new IllegalArgumentException(
206211
String.format(
@@ -210,6 +215,33 @@ protected void validateSingleColumnSchema(
210215
expectedType,
211216
column.getDataType().getLogicalType()));
212217
}
218+
219+
List<Column> metadataColumns =
220+
columns.stream()
221+
.filter(x -> x instanceof Column.MetadataColumn)
222+
.collect(Collectors.toList());
223+
if (!metadataColumns.isEmpty()) {
224+
Preconditions.checkArgument(
225+
"output".equals(inputOrOutput), "Only output schema supports metadata column");
226+
227+
for (Column metadataColumn : metadataColumns) {
228+
ErrorMessageMetadata errorMessageMetadata =
229+
ErrorMessageMetadata.get(metadataColumn.getName());
230+
Preconditions.checkNotNull(
231+
errorMessageMetadata,
232+
String.format(
233+
"Unexpected metadata column %s. Supported metadata columns:\n%s",
234+
metadataColumn.getName(),
235+
ErrorMessageMetadata.getAllKeysAndDescriptions()));
236+
Preconditions.checkArgument(
237+
errorMessageMetadata.dataType.equals(metadataColumn.getDataType()),
238+
String.format(
239+
"Expected metadata column %s to be of type %s, but is of type %s",
240+
metadataColumn.getName(),
241+
errorMessageMetadata.dataType,
242+
metadataColumn.getDataType()));
243+
}
244+
}
213245
}
214246

215247
/**
@@ -223,30 +255,52 @@ protected void validateSingleColumnSchema(
223255
* appropriate retry and error handling applied, or a null value if the request failed in
224256
* the middle and the failure should be ignored.
225257
*/
226-
protected <T> CompletableFuture<T> sendAsyncOpenAIRequest(
227-
Supplier<CompletableFuture<T>> requestSender) {
258+
protected <T> CompletableFuture<Collection<RowData>> sendAsyncOpenAIRequest(
259+
Supplier<CompletableFuture<T>> requestSender,
260+
Function<T, Collection<RowData>> converter) {
228261
CompletableFuture<T> result =
229262
retryAsync(
230263
requestSender,
231264
numRetry,
232265
retryBackoffBaseIntervalMs,
233266
retryBackoffStrategy,
234267
null);
235-
ErrorHandlingStrategy finalErrorHandlingStrategy =
236-
this.errorHandlingStrategy == ErrorHandlingStrategy.RETRY
237-
? this.retryFallbackStrategy
238-
: this.errorHandlingStrategy;
239-
if (finalErrorHandlingStrategy == ErrorHandlingStrategy.IGNORE) {
240-
result =
241-
result.exceptionally(
242-
(e) -> {
243-
LOG.warn(
244-
"The input row data failed to acquire a valid response. Ignoring the input.",
245-
e);
246-
return null;
247-
});
268+
return result.handle((x, throwable) -> this.convertToRowData(x, throwable, converter));
269+
}
270+
271+
private <T> Collection<RowData> convertToRowData(
272+
@Nullable T t,
273+
@Nullable Throwable throwable,
274+
Function<T, Collection<RowData>> converter) {
275+
if (throwable != null) {
276+
ErrorHandlingStrategy finalErrorHandlingStrategy =
277+
this.errorHandlingStrategy == ErrorHandlingStrategy.RETRY
278+
? this.retryFallbackStrategy
279+
: this.errorHandlingStrategy;
280+
if (finalErrorHandlingStrategy == ErrorHandlingStrategy.FAILOVER) {
281+
throw new RuntimeException(throwable);
282+
} else {
283+
LOG.warn(
284+
"The input row data failed to acquire a valid response. Ignoring the input.",
285+
throwable);
286+
GenericRowData rowData = new GenericRowData(this.outputColumnNames.size());
287+
boolean isMetadataSet = false;
288+
for (int i = 0; i < this.outputColumnNames.size(); i++) {
289+
String columnName = this.outputColumnNames.get(i);
290+
ErrorMessageMetadata errorMessageMetadata =
291+
ErrorMessageMetadata.get(columnName);
292+
if (errorMessageMetadata != null) {
293+
rowData.setField(i, errorMessageMetadata.converter.apply(throwable));
294+
isMetadataSet = true;
295+
}
296+
}
297+
return isMetadataSet ? Collections.singletonList(rowData) : Collections.emptyList();
298+
}
299+
} else if (t == null) {
300+
return Collections.emptyList();
301+
} else {
302+
return converter.apply(t);
248303
}
249-
return result;
250304
}
251305

252306
private <T> CompletableFuture<T> retryAsync(
@@ -348,4 +402,55 @@ public long getMinRetryTotalTime(long baseRetryInterval, int numRetry) {
348402

349403
public abstract long getMinRetryTotalTime(long baseRetryInterval, int numRetry);
350404
}
405+
406+
/**
407+
* Metadata that can be read from the output row about error messages. Referenced from Flink
408+
* HTTP Connector's ReadableMetadata.
409+
*/
410+
protected enum ErrorMessageMetadata {
411+
ERROR_STRING(
412+
"error-string",
413+
DataTypes.STRING(),
414+
x -> BinaryStringData.fromString(x.getMessage()),
415+
"A message associated with the error"),
416+
HTTP_STATUS_CODE(
417+
"http-status-code",
418+
DataTypes.INT(),
419+
e ->
420+
ExceptionUtils.findThrowable(e, OpenAIServiceException.class)
421+
.map(OpenAIServiceException::statusCode)
422+
.orElse(null),
423+
"The HTTP status code");
424+
425+
final String key;
426+
final DataType dataType;
427+
final Function<Throwable, Object> converter;
428+
final String description;
429+
430+
ErrorMessageMetadata(
431+
String key,
432+
DataType dataType,
433+
Function<Throwable, Object> converter,
434+
String description) {
435+
this.key = key;
436+
this.dataType = dataType;
437+
this.converter = converter;
438+
this.description = description;
439+
}
440+
441+
static @Nullable ErrorMessageMetadata get(String key) {
442+
for (ErrorMessageMetadata value : values()) {
443+
if (value.key.equals(key)) {
444+
return value;
445+
}
446+
}
447+
return null;
448+
}
449+
450+
static String getAllKeysAndDescriptions() {
451+
return Arrays.stream(values())
452+
.map(value -> value.key + ":\t" + value.description)
453+
.collect(Collectors.joining("\n"));
454+
}
455+
}
351456
}

flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/OpenAIChatModelFunction.java

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,8 @@
3333
import com.openai.models.chat.completions.ChatCompletionCreateParams.ResponseFormat;
3434
import com.openai.services.async.chat.ChatCompletionServiceAsync;
3535

36-
import javax.annotation.Nullable;
37-
3836
import java.util.Arrays;
3937
import java.util.Collection;
40-
import java.util.Collections;
4138
import java.util.List;
4239
import java.util.concurrent.CompletableFuture;
4340
import java.util.stream.Collectors;
@@ -53,6 +50,7 @@ public class OpenAIChatModelFunction extends AbstractOpenAIModelFunction {
5350
private final String model;
5451
private final String systemPrompt;
5552
private final Configuration config;
53+
private final int outputColumnIndex;
5654

5755
public OpenAIChatModelFunction(
5856
ModelProviderFactory.Context factoryContext, ReadableConfig config) {
@@ -64,6 +62,21 @@ public OpenAIChatModelFunction(
6462
factoryContext.getCatalogModel().getResolvedOutputSchema(),
6563
new VarCharType(VarCharType.MAX_LENGTH),
6664
"output");
65+
this.outputColumnIndex = getOutputColumnIndex();
66+
}
67+
68+
private int getOutputColumnIndex() {
69+
for (int i = 0; i < this.outputColumnNames.size(); i++) {
70+
String columnName = this.outputColumnNames.get(i);
71+
if (ErrorMessageMetadata.get(columnName) == null) {
72+
// Prior checks have guaranteed that there is one and only one physical output
73+
// column.
74+
return i;
75+
}
76+
}
77+
throw new IllegalArgumentException(
78+
"There should be one and only one physical output column. Actual columns: "
79+
+ this.outputColumnNames);
6780
}
6881

6982
@Override
@@ -93,21 +106,21 @@ public CompletableFuture<Collection<RowData>> asyncPredictInternal(String input)
93106

94107
ChatCompletionCreateParams params = builder.build();
95108
ChatCompletionServiceAsync serviceAsync = client.chat().completions();
96-
return sendAsyncOpenAIRequest(() -> serviceAsync.create(params))
97-
.thenApply(this::convertToRowData);
109+
return sendAsyncOpenAIRequest(() -> serviceAsync.create(params), this::convertToRowData);
98110
}
99111

100-
private List<RowData> convertToRowData(@Nullable ChatCompletion chatCompletion) {
101-
if (chatCompletion == null) {
102-
return Collections.emptyList();
103-
}
104-
112+
private List<RowData> convertToRowData(ChatCompletion chatCompletion) {
105113
return chatCompletion.choices().stream()
106114
.map(
107-
choice ->
108-
GenericRowData.of(
109-
BinaryStringData.fromString(
110-
choice.message().content().orElse(""))))
115+
choice -> {
116+
GenericRowData rowData =
117+
new GenericRowData(this.outputColumnNames.size());
118+
rowData.setField(
119+
this.outputColumnIndex,
120+
BinaryStringData.fromString(
121+
choice.message().content().orElse("")));
122+
return rowData;
123+
})
111124
.collect(Collectors.toList());
112125
}
113126

0 commit comments

Comments
 (0)