Skip to content

Commit 6fda959

Browse files
author
BitsAdmin
committed
Merge branch 'feat/ark/add_embedding_merge' into 'integration_2024-06-06_285058385410'
feat: [development task] ark-runtime-java (695113) See merge request iaasng/volcengine-java-sdk!189
2 parents f53b7f9 + f06d782 commit 6fda959

File tree

13 files changed

+515
-64
lines changed

13 files changed

+515
-64
lines changed

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/interceptor/EndpointStsAuthenticationInterceptor.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public class EndpointStsAuthenticationInterceptor implements Interceptor {
3131
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
3232
private final ArkApi volcClient;
3333

34-
public EndpointStsAuthenticationInterceptor(String ak, String sk) {
34+
public EndpointStsAuthenticationInterceptor(String ak, String sk, String region) {
3535
Objects.requireNonNull(ak, "Ak token required");
3636
Objects.requireNonNull(sk, "Sk token required");
3737
this.ak = ak;
@@ -40,9 +40,8 @@ public EndpointStsAuthenticationInterceptor(String ak, String sk) {
4040

4141
ApiClient apiClient = new ApiClient()
4242
.setCredentials(Credentials.getCredentials(ak,sk))
43-
.setRegion("cn-beijing");
44-
ArkApi arkApi = new ArkApi(apiClient);
45-
this.volcClient = arkApi;
43+
.setRegion(region);
44+
this.volcClient = new ArkApi(apiClient);
4645
}
4746

4847
@Override
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package com.volcengine.ark.runtime.interceptor;
2+
3+
4+
import okhttp3.Interceptor;
5+
import okhttp3.Request;
6+
import okhttp3.Response;
7+
8+
import java.io.IOException;
9+
import java.io.InterruptedIOException;
10+
11+
import static java.lang.Math.random;
12+
13+
public class RetryInterceptor implements Interceptor {
14+
15+
private final int retryTimes;
16+
private final double INITIAL_RETRY_DELAY = 0.5;
17+
private final double MAX_RETRY_DELAY = 8.0;
18+
19+
public RetryInterceptor(int retryTimes) {
20+
this.retryTimes = retryTimes;
21+
}
22+
23+
@Override
24+
public Response intercept(Chain chain) throws IOException {
25+
Request request = chain.request();
26+
27+
// try the request
28+
Response response = chain.proceed(request);
29+
30+
int tryCount = 0;
31+
while (response.code() >= 500 && tryCount < retryTimes) {
32+
tryCount++;
33+
34+
// retry the request
35+
response.close();
36+
37+
try {
38+
double interval = retryInterval(retryTimes, retryTimes - tryCount) * 1000;
39+
Thread.sleep(Math.round(interval));
40+
} catch (InterruptedException e) {
41+
Thread.currentThread().interrupt();
42+
throw new InterruptedIOException();
43+
}
44+
response = chain.proceed(request);
45+
}
46+
47+
return response;
48+
}
49+
50+
public double retryInterval(int max, int remain) {
51+
int nbRetries = max - remain;
52+
double sleepSeconds = Math.min(INITIAL_RETRY_DELAY * Math.pow(2.0, nbRetries), MAX_RETRY_DELAY);
53+
double jitter = 1 - 0.25 * random();
54+
return sleepSeconds * jitter;
55+
}
56+
}
57+

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/model/completion/chat/ChatFunctionCall.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ public class ChatFunctionCall {
1212
/**
1313
* The arguments of the call produced by the model, represented as a JsonNode for easy manipulation.
1414
*/
15-
JsonNode arguments;
15+
String arguments;
1616

17-
public ChatFunctionCall(String name, JsonNode arguments) {
17+
public ChatFunctionCall(String name, String arguments) {
1818
this.name = name;
1919
this.arguments = arguments;
2020
}
@@ -29,11 +29,11 @@ public void setName(String name) {
2929
this.name = name;
3030
}
3131

32-
public JsonNode getArguments() {
32+
public String getArguments() {
3333
return arguments;
3434
}
3535

36-
public void setArguments(JsonNode arguments) {
36+
public void setArguments(String arguments) {
3737
this.arguments = arguments;
3838
}
3939

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/model/completion/chat/ChatToolCall.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ public class ChatToolCall {
1616
/**
1717
* The function that the model called
1818
*/
19-
ChatFunction function;
19+
ChatFunctionCall function;
2020

21-
public ChatToolCall(String id, String type, ChatFunction function) {
21+
public ChatToolCall(String id, String type, ChatFunctionCall function) {
2222
this.id = id;
2323
this.type = type;
2424
this.function = function;
@@ -42,11 +42,11 @@ public void setType(String type) {
4242
this.type = type;
4343
}
4444

45-
public ChatFunction getFunction() {
45+
public ChatFunctionCall getFunction() {
4646
return function;
4747
}
4848

49-
public void setFunction(ChatFunction function) {
49+
public void setFunction(ChatFunctionCall function) {
5050
this.function = function;
5151
}
5252

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package com.volcengine.ark.runtime.model.embeddings;
2+
3+
import java.util.List;
4+
5+
public class Embedding {
6+
7+
/**
8+
* The type of object returned, should be "embedding"
9+
*/
10+
String object;
11+
12+
/**
13+
* The embedding vector
14+
*/
15+
List<Double> embedding;
16+
17+
/**
18+
* The position of this embedding in the list
19+
*/
20+
Integer index;
21+
22+
public String getObject() {
23+
return object;
24+
}
25+
26+
public void setObject(String object) {
27+
this.object = object;
28+
}
29+
30+
public List<Double> getEmbedding() {
31+
return embedding;
32+
}
33+
34+
public void setEmbedding(List<Double> embedding) {
35+
this.embedding = embedding;
36+
}
37+
38+
public Integer getIndex() {
39+
return index;
40+
}
41+
42+
public void setIndex(Integer index) {
43+
this.index = index;
44+
}
45+
46+
@Override
47+
public String toString() {
48+
return "Embedding{" +
49+
"object='" + object + '\'' +
50+
", embedding=" + embedding +
51+
", index=" + index +
52+
'}';
53+
}
54+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package com.volcengine.ark.runtime.model.embeddings;
2+
3+
4+
import java.util.List;
5+
6+
7+
public class EmbeddingRequest {
8+
9+
/**
10+
* The name of the model to use.
11+
* Required if using the new v1/embeddings endpoint.
12+
*/
13+
String model;
14+
15+
/**
16+
* Input text to get embeddings for, encoded as a string or array of tokens.
17+
* To get embeddings for multiple inputs in a single request, pass an array of strings or array of token arrays.
18+
* Each input must not exceed 2048 tokens in length.
19+
* <p>
20+
* Unless you are embedding code, we suggest replacing newlines (\n) in your input with a single space,
21+
* as we have observed inferior results when newlines are present.
22+
*/
23+
List<String> input;
24+
25+
/**
26+
* A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
27+
*/
28+
String user;
29+
30+
public EmbeddingRequest() {
31+
}
32+
33+
public String getModel() {
34+
return model;
35+
}
36+
37+
public void setModel(String model) {
38+
this.model = model;
39+
}
40+
41+
public List<String> getInput() {
42+
return input;
43+
}
44+
45+
public void setInput(List<String> input) {
46+
this.input = input;
47+
}
48+
49+
public String getUser() {
50+
return user;
51+
}
52+
53+
public void setUser(String user) {
54+
this.user = user;
55+
}
56+
57+
public static EmbeddingRequest.Builder builder() {
58+
return new Builder();
59+
}
60+
61+
public static final class Builder {
62+
private String model;
63+
private List<String> input;
64+
private String user;
65+
66+
private Builder() {
67+
}
68+
69+
public Builder model(String model) {
70+
this.model = model;
71+
return this;
72+
}
73+
74+
public Builder input(List<String> input) {
75+
this.input = input;
76+
return this;
77+
}
78+
79+
public Builder user(String user) {
80+
this.user = user;
81+
return this;
82+
}
83+
84+
public EmbeddingRequest build() {
85+
EmbeddingRequest embeddingRequest = new EmbeddingRequest();
86+
embeddingRequest.setModel(model);
87+
embeddingRequest.setInput(input);
88+
embeddingRequest.setUser(user);
89+
return embeddingRequest;
90+
}
91+
}
92+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package com.volcengine.ark.runtime.model.embeddings;
2+
3+
4+
import com.volcengine.ark.runtime.model.Usage;
5+
6+
import java.util.List;
7+
8+
public class EmbeddingResult {
9+
10+
/**
11+
* The GPTmodel used for generating embeddings
12+
*/
13+
String model;
14+
15+
/**
16+
* The type of object returned, should be "list"
17+
*/
18+
String object;
19+
20+
/**
21+
* A list of the calculated embeddings
22+
*/
23+
List<Embedding> data;
24+
25+
/**
26+
* The API usage for this request
27+
*/
28+
Usage usage;
29+
30+
public String getModel() {
31+
return model;
32+
}
33+
34+
public void setModel(String model) {
35+
this.model = model;
36+
}
37+
38+
public String getObject() {
39+
return object;
40+
}
41+
42+
public void setObject(String object) {
43+
this.object = object;
44+
}
45+
46+
public List<Embedding> getData() {
47+
return data;
48+
}
49+
50+
public void setData(List<Embedding> data) {
51+
this.data = data;
52+
}
53+
54+
public Usage getUsage() {
55+
return usage;
56+
}
57+
58+
public void setUsage(Usage usage) {
59+
this.usage = usage;
60+
}
61+
62+
@Override
63+
public String toString() {
64+
return "EmbeddingResult{" +
65+
"model='" + model + '\'' +
66+
", object='" + object + '\'' +
67+
", data=" + data +
68+
", usage=" + usage +
69+
'}';
70+
}
71+
}

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/service/ArkApi.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,24 @@
33
import com.volcengine.ark.runtime.Const;
44
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionRequest;
55
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionResult;
6+
import com.volcengine.ark.runtime.model.embeddings.EmbeddingRequest;
7+
import com.volcengine.ark.runtime.model.embeddings.EmbeddingResult;
68
import okhttp3.ResponseBody;
7-
import retrofit2.http.Body;
8-
import retrofit2.http.Header;
9-
import retrofit2.http.POST;
10-
import retrofit2.http.Streaming;
9+
import retrofit2.http.*;
1110
import retrofit2.Call;
1211
import io.reactivex.Single;
1312

13+
import java.util.Map;
14+
1415
public interface ArkApi {
1516

1617
@POST("/api/v3/chat/completions")
17-
Single<ChatCompletionResult> createChatCompletion(@Body ChatCompletionRequest request, @Header(Const.REQUEST_MODEL) String model);
18+
Single<ChatCompletionResult> createChatCompletion(@Body ChatCompletionRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
1819

1920
@Streaming
2021
@POST("/api/v3/chat/completions")
21-
Call<ResponseBody> createChatCompletionStream(@Body ChatCompletionRequest request, @Header(Const.REQUEST_MODEL) String model);
22+
Call<ResponseBody> createChatCompletionStream(@Body ChatCompletionRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
2223

24+
@POST("/api/v3/embeddings")
25+
Single<EmbeddingResult> createEmbeddings(@Body EmbeddingRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
2326
}

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/service/ArkBaseService.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
public abstract class ArkBaseService {
1111

1212
static final String BASE_URL = "https://ark.cn-beijing.volces.com";
13+
static final String BASE_REGION = "cn-beijing";
1314
static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60);
1415
String apiKey = "";
1516
String ak = "";

0 commit comments

Comments
 (0)