Skip to content

Commit c3bb363

Browse files
[FLINK-38549][model] Support limiting context window size
1 parent bced0b4 commit c3bb363

File tree

12 files changed

+468
-2
lines changed

12 files changed

+468
-2
lines changed

docs/content.zh/docs/connectors/models/openai.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,33 @@ FROM ML_PREDICT(
130130
<td>String</td>
131131
<td>模型名称,例如:<code>gpt-3.5-turbo</code>, <code>text-embedding-ada-002</code>。</td>
132132
</tr>
133+
<tr>
134+
<td>
135+
<h5>max-context-size</h5>
136+
</td>
137+
<td>可选</td>
138+
<td style="word-wrap: break-word;">(none)</td>
139+
<td>Integer</td>
140+
<td>单个请求的最大上下文长度,单位为Token数量。当长度超过该值时,将使用context-overflow-action指定的溢出行为。</td>
141+
</tr>
142+
<tr>
143+
<td>
144+
<h5>context-overflow-action</h5>
145+
</td>
146+
<td>可选</td>
147+
<td style="word-wrap: break-word;">truncated-tail</td>
148+
<td>String</td>
149+
<td>处理上下文溢出的操作。支持的操作:
150+
<ul>
151+
<li><code>truncated-tail</code>(默认): 从上下文尾部截断超出的token。</li>
152+
<li><code>truncated-tail-log</code>: 从上下文尾部截断超出的token。记录截断日志。</li>
153+
<li><code>truncated-head</code>: 从上下文头部截断超出的token。</li>
154+
<li><code>truncated-head-log</code>: 从上下文头部截断超出的token。记录截断日志。</li>
155+
<li><code>skipped</code>: 跳过输入行。</li>
156+
<li><code>skipped-log</code>: 跳过输入行。记录跳过日志。</li>
157+
</ul>
158+
</td>
159+
</tr>
133160
</tbody>
134161
</table>
135162

docs/content/docs/connectors/models/openai.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,33 @@ FROM ML_PREDICT(
130130
<td>String</td>
131131
<td>Model name, e.g. <code>gpt-3.5-turbo</code>, <code>text-embedding-ada-002</code>.</td>
132132
</tr>
133+
<tr>
134+
<td>
135+
<h5>max-context-size</h5>
136+
</td>
137+
<td>optional</td>
138+
<td style="word-wrap: break-word;">(none)</td>
139+
<td>Integer</td>
140+
<td>Maximum number of tokens allowed in the context of a single request. If this threshold is exceeded, the behavior specified by <code>context-overflow-action</code> is triggered.</td>
141+
</tr>
142+
<tr>
143+
<td>
144+
<h5>context-overflow-action</h5>
145+
</td>
146+
<td>optional</td>
147+
<td style="word-wrap: break-word;">truncated-tail</td>
148+
<td>String</td>
149+
<td>Action to handle context overflows. Supported actions:
150+
<ul>
151+
<li><code>truncated-tail</code> (default): Truncates exceeded tokens from the tail of the context.</li>
152+
<li><code>truncated-tail-log</code>: Truncates exceeded tokens from the tail of the context. Records the truncation log.</li>
153+
<li><code>truncated-head</code>: Truncates exceeded tokens from the head of the context.</li>
154+
<li><code>truncated-head-log</code>: Truncates exceeded tokens from the head of the context. Records the truncation log.</li>
155+
<li><code>skipped</code>: Skips the input row.</li>
156+
<li><code>skipped-log</code>: Skips the input row. Records the skipping log.</li>
157+
</ul>
158+
</td>
159+
</tr>
133160
</tbody>
134161
</table>
135162

34.8 MB
Binary file not shown.

flink-models/flink-model-openai/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ under the License.
7272
<optional>${flink.markBundledAsOptional}</optional>
7373
</dependency>
7474

75+
<dependency>
76+
<groupId>com.knuddels</groupId>
77+
<artifactId>jtokkit</artifactId>
78+
<version>1.1.0</version>
79+
</dependency>
80+
7581
<!-- Core dependencies -->
7682
<dependency>
7783
<groupId>org.apache.flink</groupId>

flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
import org.slf4j.Logger;
3636
import org.slf4j.LoggerFactory;
3737

38+
import javax.annotation.Nullable;
39+
3840
import java.util.List;
3941
import java.util.stream.Collectors;
4042

@@ -73,11 +75,32 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction {
7375
code("gpt-3.5-turbo"), code("text-embedding-ada-002"))
7476
.build());
7577

78+
public static final ConfigOption<Integer> MAX_CONTEXT_SIZE =
79+
ConfigOptions.key("max-context-size")
80+
.intType()
81+
.noDefaultValue()
82+
.withDescription(
83+
"Max number of tokens for context. context-overflow-action would be triggered if this threshold is exceeded.");
84+
85+
public static final ConfigOption<ContextOverflowAction> CONTEXT_OVERFLOW_ACTION =
86+
ConfigOptions.key("context-overflow-action")
87+
.enumType(ContextOverflowAction.class)
88+
.defaultValue(ContextOverflowAction.TRUNCATED_TAIL)
89+
.withDescription(
90+
Description.builder()
91+
.text("Action to handle context overflows. Supported actions:")
92+
.linebreak()
93+
.text(ContextOverflowAction.getAllValuesAndDescriptions())
94+
.build());
95+
7696
protected transient OpenAIClientAsync client;
7797

7898
private final int numRetry;
7999
private final String baseUrl;
80100
private final String apiKey;
101+
private final String model;
102+
@Nullable private final Integer maxContextSize;
103+
private final ContextOverflowAction contextOverflowAction;
81104

82105
public AbstractOpenAIModelFunction(
83106
ModelProviderFactory.Context factoryContext, ReadableConfig config) {
@@ -94,6 +117,9 @@ public AbstractOpenAIModelFunction(
94117
// resilience while maintaining throughput efficiency.
95118
this.numRetry =
96119
config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_LOOKUP_BUFFER_CAPACITY) * 10;
120+
this.model = config.get(MODEL);
121+
this.maxContextSize = config.get(MAX_CONTEXT_SIZE);
122+
this.contextOverflowAction = config.get(CONTEXT_OVERFLOW_ACTION);
97123

98124
validateSingleColumnSchema(
99125
factoryContext.getCatalogModel().getResolvedInputSchema(),
@@ -106,6 +132,7 @@ public void open(FunctionContext context) throws Exception {
106132
super.open(context);
107133
LOG.debug("Creating an OpenAI client.");
108134
this.client = OpenAIUtils.createAsyncClient(baseUrl, apiKey, numRetry);
135+
this.contextOverflowAction.initializeEncodingForContextLimit(model, maxContextSize);
109136
}
110137

111138
@Override
@@ -120,6 +147,15 @@ public void close() throws Exception {
120147

121148
protected abstract String getEndpointSuffix();
122149

150+
/**
151+
* Preprocesses the input string to meet the context limit.
152+
*
153+
* @return The processed input string, or null if the input is too long and should be skipped.
154+
*/
155+
protected @Nullable String preprocessTextWithTokenLimit(String input) {
156+
return contextOverflowAction.processTokensWithLimit(model, input, maxContextSize);
157+
}
158+
123159
protected void validateSingleColumnSchema(
124160
ResolvedSchema schema, LogicalType expectedType, String inputOrOutput) {
125161
List<Column> columns = schema.getColumns();

0 commit comments

Comments
 (0)