Skip to content

Commit 26b871c

Browse files
committed
Merge branch 'feat/context-api-ga' into 'master'
feat(ark-runtime): add context chat api supports See merge request iaasng/volcengine-java-sdk!296
2 parents ef6070b + a5b9e74 commit 26b871c

File tree

15 files changed

+761
-16
lines changed

15 files changed

+761
-16
lines changed

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/Const.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,9 @@ public class Const {
1515

1616
public static final String RESOURCE_TYPE_BOT = "bot";
1717
public static final String RESOURCE_TYPE_ENDPOINT = "endpoint";
18+
19+
public static final String CONTEXT_MODE_SESSION = "session";
20+
public static final String CONTEXT_MODE_COMMON_PREFIX = "common_prefix";
21+
public static final String TRUNCATION_STRATEGY_TYPE_LAST_HISTORY_TOKENS = "last_history_tokens";
22+
public static final String TRUNCATION_STRATEGY_TYPE_ROLLING_TOKENS = "rolling_tokens";
1823
}

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/exception/ArkHttpException.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.google.gson.Gson;
44

55
public class ArkHttpException extends RuntimeException {
6+
public static Integer INTERNAL_SERVICE_CODE = 500;
67

78
public final int statusCode;
89

@@ -31,6 +32,7 @@ public String getMessage() {
3132
public String toString() {
3233
return "ArkHttpException{" +
3334
"statusCode=" + statusCode +
35+
", message='" + super.getMessage() + '\'' +
3436
", code='" + code + '\'' +
3537
", param='" + param + '\'' +
3638
", type='" + type + '\'' +

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/interceptor/RequestIdInterceptor.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package com.volcengine.ark.runtime.interceptor;
22

33
import com.volcengine.ark.runtime.Const;
4+
import com.volcengine.ark.runtime.exception.ArkAPIError;
5+
import com.volcengine.ark.runtime.exception.ArkException;
6+
import com.volcengine.ark.runtime.exception.ArkHttpException;
47
import com.volcengine.version.Version;
58
import okhttp3.Interceptor;
69
import okhttp3.Request;
@@ -26,7 +29,14 @@ public Response intercept(Chain chain) throws IOException {
2629
requestBuilder.header("User-Agent", getUserAgent());
2730

2831
Request request = requestBuilder.build();
29-
return chain.proceed(request);
32+
33+
try {
34+
return chain.proceed(request);
35+
} catch (Exception e) {
36+
String requestId = request.header(Const.CLIENT_REQUEST_HEADER);
37+
ArkAPIError arkAPIError = new ArkAPIError(new ArkAPIError.ArkErrorDetails(e.getMessage(), "", "", ""));
38+
throw new ArkHttpException(arkAPIError, e, ArkHttpException.INTERNAL_SERVICE_CODE, requestId);
39+
}
3040
}
3141

3242
private String genRequestId() {

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/interceptor/RetryInterceptor.java

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,31 @@ public RetryInterceptor(int retryTimes) {
2121
}
2222

2323
@Override
24-
public Response intercept(Chain chain) throws IOException {
24+
public Response intercept(Chain chain) throws RuntimeException, InterruptedIOException {
2525
Request request = chain.request();
2626

27-
// try the request
28-
Response response = chain.proceed(request);
29-
27+
Response response = null;
3028
int tryCount = 0;
31-
while ((response.code() >= 500 || response.code() == 429) && tryCount < retryTimes) {
32-
tryCount++;
29+
boolean shouldRetry;
30+
Exception exception;
31+
do {
32+
if (response != null) {
33+
response.close();
34+
}
35+
exception = null;
36+
37+
try {
38+
response = chain.proceed(request);
39+
shouldRetry = response.code() >= 500 || response.code() == 429;
40+
} catch (Exception e) {
41+
shouldRetry = true;
42+
exception = e;
43+
}
3344

34-
// retry the request
35-
response.close();
45+
tryCount++;
46+
if (!(shouldRetry && tryCount <= retryTimes)) {
47+
break;
48+
}
3649

3750
try {
3851
double interval = retryInterval(retryTimes, retryTimes - tryCount) * 1000;
@@ -41,10 +54,12 @@ public Response intercept(Chain chain) throws IOException {
4154
Thread.currentThread().interrupt();
4255
throw new InterruptedIOException();
4356
}
44-
response = chain.proceed(request);
45-
}
57+
} while (true);
4658

47-
return response;
59+
if (response != null) {
60+
return response;
61+
}
62+
throw new RuntimeException(exception);
4863
}
4964

5065
public double retryInterval(int max, int remain) {
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package com.volcengine.ark.runtime.model;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
6+
@JsonIgnoreProperties(ignoreUnknown = true)
7+
public class PromptTokensDetails {
8+
9+
@JsonProperty("cached_tokens")
10+
private Integer cachedTokens;
11+
12+
public Integer getCachedTokens() {
13+
return cachedTokens;
14+
}
15+
16+
public void setCachedTokens(Integer cachedTokens) {
17+
this.cachedTokens = cachedTokens;
18+
}
19+
20+
@Override
21+
public String toString() {
22+
return "PromptTokensDetails{" +
23+
"cachedTokens=" + cachedTokens +
24+
'}';
25+
}
26+
}

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/model/Usage.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,22 @@ public class Usage {
2323
@JsonProperty("total_tokens")
2424
long totalTokens;
2525

26+
@JsonProperty("prompt_tokens_details")
27+
private PromptTokensDetails promptTokensDetails;
28+
2629
public Usage(long promptTokens, long completionTokens, long totalTokens) {
2730
this.promptTokens = promptTokens;
2831
this.completionTokens = completionTokens;
2932
this.totalTokens = totalTokens;
3033
}
3134

35+
public Usage(long promptTokens, long completionTokens, long totalTokens, PromptTokensDetails promptTokensDetails) {
36+
this.promptTokens = promptTokens;
37+
this.completionTokens = completionTokens;
38+
this.totalTokens = totalTokens;
39+
this.promptTokensDetails = promptTokensDetails;
40+
}
41+
3242
public Usage() {}
3343

3444
public long getPromptTokens() {
@@ -55,12 +65,21 @@ public void setTotalTokens(long totalTokens) {
5565
this.totalTokens = totalTokens;
5666
}
5767

68+
public PromptTokensDetails getPromptTokensDetails() {
69+
return promptTokensDetails;
70+
}
71+
72+
public void setPromptTokensDetails(PromptTokensDetails promptTokensDetails) {
73+
this.promptTokensDetails = promptTokensDetails;
74+
}
75+
5876
@Override
5977
public String toString() {
6078
return "Usage{" +
6179
"promptTokens=" + promptTokens +
6280
", completionTokens=" + completionTokens +
6381
", totalTokens=" + totalTokens +
82+
", promptTokensDetails=" + promptTokensDetails +
6483
'}';
6584
}
6685
}
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package com.volcengine.ark.runtime.model.context;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import com.volcengine.ark.runtime.model.completion.chat.ChatMessage;
6+
7+
import java.util.List;
8+
9+
@JsonIgnoreProperties(ignoreUnknown = true)
10+
public class CreateContextRequest {
11+
12+
@JsonProperty("model")
13+
private String model;
14+
15+
@JsonProperty("mode")
16+
private String mode;
17+
18+
@JsonProperty("messages")
19+
private List<ChatMessage> messages;
20+
21+
@JsonProperty("ttl")
22+
private Integer ttl;
23+
24+
@JsonProperty("truncation_strategy")
25+
private TruncationStrategy truncationStrategy;
26+
27+
public CreateContextRequest() {
28+
}
29+
30+
public CreateContextRequest(String model, String mode, List<ChatMessage> messages, Integer ttl, TruncationStrategy truncationStrategy) {
31+
this.model = model;
32+
this.mode = mode;
33+
this.messages = messages;
34+
this.ttl = ttl;
35+
this.truncationStrategy = truncationStrategy;
36+
}
37+
38+
public String getModel() {
39+
return model;
40+
}
41+
42+
public void setModel(String model) {
43+
this.model = model;
44+
}
45+
46+
public String getMode() {
47+
return mode;
48+
}
49+
50+
public void setMode(String mode) {
51+
this.mode = mode;
52+
}
53+
54+
public List<ChatMessage> getMessages() {
55+
return messages;
56+
}
57+
58+
public void setMessages(List<ChatMessage> messages) {
59+
this.messages = messages;
60+
}
61+
62+
public Integer getTtl() {
63+
return ttl;
64+
}
65+
66+
public void setTtl(Integer ttl) {
67+
this.ttl = ttl;
68+
}
69+
70+
public TruncationStrategy getTruncationStrategy() {
71+
return truncationStrategy;
72+
}
73+
74+
public void setTruncationStrategy(TruncationStrategy truncationStrategy) {
75+
this.truncationStrategy = truncationStrategy;
76+
}
77+
78+
@Override
79+
public String toString() {
80+
return "CreateContextRequest{" +
81+
"model='" + model + '\'' +
82+
", mode='" + mode + '\'' +
83+
", messages=" + messages +
84+
", ttl=" + ttl +
85+
", truncationStrategy=" + truncationStrategy +
86+
'}';
87+
}
88+
89+
public static CreateContextRequest.Builder builder() {
90+
return new Builder();
91+
}
92+
93+
public static class Builder {
94+
private String model;
95+
private String mode;
96+
private List<ChatMessage> messages;
97+
private Integer ttl;
98+
private TruncationStrategy truncationStrategy;
99+
100+
private Builder() {
101+
}
102+
103+
public Builder model(String model) {
104+
this.model = model;
105+
return this;
106+
}
107+
108+
public Builder mode(String mode) {
109+
this.mode = mode;
110+
return this;
111+
}
112+
113+
public Builder messages(List<ChatMessage> messages) {
114+
this.messages = messages;
115+
return this;
116+
}
117+
118+
public Builder ttl(Integer ttl) {
119+
this.ttl = ttl;
120+
return this;
121+
}
122+
123+
public Builder truncationStrategy(TruncationStrategy truncationStrategy) {
124+
this.truncationStrategy = truncationStrategy;
125+
return this;
126+
}
127+
128+
public CreateContextRequest build() {
129+
CreateContextRequest createContextRequest = new CreateContextRequest();
130+
createContextRequest.setModel(model);
131+
createContextRequest.setMode(mode);
132+
createContextRequest.setMessages(messages);
133+
createContextRequest.setTtl(ttl);
134+
createContextRequest.setTruncationStrategy(truncationStrategy);
135+
return createContextRequest;
136+
}
137+
}
138+
}

0 commit comments

Comments
 (0)