Skip to content

Commit 7e0dfda

Browse files
author
BitsAdmin
committed
Merge branch 'feat/ark/batch/embedding' into 'integration_2025-08-21_1037883052802'
feat: [development task] ark runtime (1576787) See merge request iaasng/volcengine-java-sdk!622
2 parents 33ad963 + 0020a4a commit 7e0dfda

File tree

8 files changed

+175
-4
lines changed

8 files changed

+175
-4
lines changed

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/Const.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@ public class Const {
2222
public static final String TRUNCATION_STRATEGY_TYPE_LAST_HISTORY_TOKENS = "last_history_tokens";
2323
public static final String TRUNCATION_STRATEGY_TYPE_ROLLING_TOKENS = "rolling_tokens";
2424

25-
public static final String BATCH_CHAT_PATH = "/api/v3/batch/chat/completions";
25+
public static final String BATCH_PATH_PREFIX = "/api/v3/batch";
2626
public static final int MAX_RETRY_TIMES = 259200;
2727
}

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/interceptor/BatchInterceptor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public BatchInterceptor() {
2121
public Response intercept(Chain chain) throws IOException {
2222
Request request = chain.request();
2323
HttpUrl url = request.url();
24-
if (!url.encodedPath().equals(BATCH_CHAT_PATH)) {
24+
if (!url.encodedPath().startsWith(BATCH_PATH_PREFIX)) {
2525
return chain.proceed(request);
2626
}
2727
String endpoint = request.header(REQUEST_MODEL);

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/interceptor/RetryInterceptor.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import okhttp3.Request;
55
import okhttp3.Response;
66

7-
import java.io.IOException;
87
import java.io.InterruptedIOException;
98

109
import static com.volcengine.ark.runtime.Const.*;
@@ -73,7 +72,7 @@ public double retryInterval(int max, int remain) {
7372

7473
public int getRetryTimes(Request request) {
7574
String path = request.url().encodedPath();
76-
if (path.equals(BATCH_CHAT_PATH)) {
75+
if (path.startsWith(BATCH_PATH_PREFIX)) {
7776
return MAX_RETRY_TIMES;
7877
}
7978
return retryTimes;

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/service/ArkApi.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,15 @@ public interface ArkApi {
6161
/**
 * Issues a POST to {@code /api/v3/embeddings} with {@code request} as the body.
 *
 * @param request       payload sent as the HTTP request body
 * @param model         value sent in the {@code Const.REQUEST_MODEL} header
 * @param customHeaders extra headers added to the request
 * @return a {@code Single} yielding the {@code EmbeddingResult}
 */
@POST("/api/v3/embeddings")
6262
Single<EmbeddingResult> createEmbeddings(@Body EmbeddingRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
6363

64+
/**
 * Issues a POST to {@code /api/v3/batch/embeddings} with {@code request} as the body.
 *
 * <p>The path starts with {@code Const.BATCH_PATH_PREFIX}, which {@code BatchInterceptor} and
 * {@code RetryInterceptor} test with {@code startsWith}; for such paths
 * {@code RetryInterceptor.getRetryTimes} returns {@code MAX_RETRY_TIMES}.
 *
 * @param request       payload sent as the HTTP request body
 * @param model         value sent in the {@code Const.REQUEST_MODEL} header
 * @param customHeaders extra headers added to the request
 * @return a {@code Single} yielding the {@code EmbeddingResult}
 */
@POST("/api/v3/batch/embeddings")
65+
Single<EmbeddingResult> createBatchEmbeddings(@Body EmbeddingRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
66+
6467
/**
 * Issues a POST to {@code /api/v3/embeddings/multimodal} with {@code request} as the body.
 *
 * @param request       payload sent as the HTTP request body
 * @param model         value sent in the {@code Const.REQUEST_MODEL} header
 * @param customHeaders extra headers added to the request
 * @return a {@code Single} yielding the {@code MultimodalEmbeddingResult}
 */
@POST("/api/v3/embeddings/multimodal")
6568
Single<MultimodalEmbeddingResult> createMultiModalEmbeddings(@Body MultimodalEmbeddingRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
6669

70+
/**
 * Issues a POST to {@code /api/v3/batch/embeddings/multimodal} with {@code request} as the body.
 *
 * <p>The path starts with {@code Const.BATCH_PATH_PREFIX}, which {@code BatchInterceptor} and
 * {@code RetryInterceptor} test with {@code startsWith}; for such paths
 * {@code RetryInterceptor.getRetryTimes} returns {@code MAX_RETRY_TIMES}.
 *
 * @param request       payload sent as the HTTP request body
 * @param model         value sent in the {@code Const.REQUEST_MODEL} header
 * @param customHeaders extra headers added to the request
 * @return a {@code Single} yielding the {@code MultimodalEmbeddingResult}
 */
@POST("/api/v3/batch/embeddings/multimodal")
71+
Single<MultimodalEmbeddingResult> createBatchMultiModalEmbeddings(@Body MultimodalEmbeddingRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
72+
6773
/**
 * Issues a POST to {@code /api/v3/tokenization} with {@code request} as the body.
 *
 * @param request       payload sent as the HTTP request body
 * @param model         value sent in the {@code Const.REQUEST_MODEL} header
 * @param customHeaders extra headers added to the request
 * @return a {@code Single} yielding the {@code TokenizationResult}
 */
@POST("/api/v3/tokenization")
6874
Single<TokenizationResult> createTokenization(@Body TokenizationRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
6975

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/service/ArkBaseServiceImpl.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,12 @@ public interface ArkBaseServiceImpl {
4343

4444
/**
 * Produces an {@code EmbeddingResult} for {@code request}.
 * NOTE(review): presumably backed by {@code POST /api/v3/embeddings}; the single-argument
 * implementation is not visible here, so confirm against the implementing class.
 */
EmbeddingResult createEmbeddings(EmbeddingRequest request);
4545

46+
/**
 * Produces an {@code EmbeddingResult} for {@code request} through the batch variant of the call.
 * The same-named {@code ArkService} method forwards to {@code api.createBatchEmbeddings} with the
 * request's model and an empty header map.
 * NOTE(review): blocking and retry semantics are decided by the implementation; confirm there.
 */
EmbeddingResult createBatchEmbeddings(EmbeddingRequest request);
47+
4648
/**
 * Produces a {@code MultimodalEmbeddingResult} for {@code request}.
 * The same-named {@code ArkService} method forwards to {@code api.createMultiModalEmbeddings}
 * with the request's model and an empty header map.
 */
MultimodalEmbeddingResult createMultiModalEmbeddings(MultimodalEmbeddingRequest request);
4749

50+
/**
 * Produces a {@code MultimodalEmbeddingResult} for {@code request} through the batch variant of
 * the call. The same-named {@code ArkService} method forwards to
 * {@code api.createBatchMultiModalEmbeddings} with the request's model and an empty header map.
 * NOTE(review): blocking and retry semantics are decided by the implementation; confirm there.
 */
MultimodalEmbeddingResult createBatchMultiModalEmbeddings(MultimodalEmbeddingRequest request);
51+
4852
/**
 * Produces an {@code ImagesResponse} for {@code request}.
 * NOTE(review): the implementation is not visible here; confirm the endpoint and error
 * behaviour against the implementing class.
 */
ImagesResponse generateImages(GenerateImagesRequest request);
4953

5054
/**
 * Produces a {@code CreateContentGenerationTaskResult} for {@code request}.
 * NOTE(review): the implementation is not visible here; confirm the endpoint and error
 * behaviour against the implementing class.
 */
CreateContentGenerationTaskResult createContentGenerationTask(CreateContentGenerationTaskRequest request);

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/service/ArkService.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,14 @@ public EmbeddingResult createEmbeddings(EmbeddingRequest request, Map<String, St
203203
return execute(api.createEmbeddings(request, request.getModel(), customHeaders));
204204
}
205205

206+
public EmbeddingResult createBatchEmbeddings(EmbeddingRequest request) {
207+
return execute(api.createBatchEmbeddings(request, request.getModel(), new HashMap<>()));
208+
}
209+
210+
public EmbeddingResult createBatchEmbeddings(EmbeddingRequest request, Map<String, String> customHeaders) {
211+
return execute(api.createBatchEmbeddings(request, request.getModel(), customHeaders));
212+
}
213+
206214
public MultimodalEmbeddingResult createMultiModalEmbeddings(MultimodalEmbeddingRequest request) {
207215
return execute(api.createMultiModalEmbeddings(request, request.getModel(), new HashMap<>()));
208216
}
@@ -211,6 +219,14 @@ public MultimodalEmbeddingResult createMultiModalEmbeddings(MultimodalEmbeddingR
211219
return execute(api.createMultiModalEmbeddings(request, request.getModel(), customHeaders));
212220
}
213221

222+
public MultimodalEmbeddingResult createBatchMultiModalEmbeddings(MultimodalEmbeddingRequest request) {
223+
return execute(api.createBatchMultiModalEmbeddings(request, request.getModel(), new HashMap<>()));
224+
}
225+
226+
public MultimodalEmbeddingResult createBatchMultiModalEmbeddings(MultimodalEmbeddingRequest request, Map<String, String> customHeaders) {
227+
return execute(api.createBatchMultiModalEmbeddings(request, request.getModel(), customHeaders));
228+
}
229+
214230
@Override
215231
public BotChatCompletionResult createBotChatCompletion(BotChatCompletionRequest request) {
216232
return execute(api.createBotChatCompletion(request, request.getModel(), new HashMap<>()));
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package com.volcengine.ark.runtime;
2+
3+
import com.volcengine.ark.runtime.model.embeddings.EmbeddingRequest;
4+
import com.volcengine.ark.runtime.service.ArkService;
5+
import okhttp3.ConnectionPool;
6+
import okhttp3.Dispatcher;
7+
8+
import java.time.Duration;
9+
import java.util.ArrayList;
10+
import java.util.List;
11+
import java.util.concurrent.CountDownLatch;
12+
import java.util.concurrent.ExecutorService;
13+
import java.util.concurrent.Executors;
14+
import java.util.concurrent.TimeUnit;
15+
16+
public class BatchEmbeddingsExample {
17+
18+
public static void main(String[] args) {
19+
// 为batch embeddings设置一个较大的超时时间,最小不小于10分钟
20+
Duration timeout = Duration.ofHours(1);
21+
int maxConcurrency = 5000;
22+
int taskNumPerWorker = 5;
23+
String apiKey = System.getenv("ARK_API_KEY");
24+
ConnectionPool connectionPool = new ConnectionPool(maxConcurrency, 10, TimeUnit.MINUTES);
25+
Dispatcher dispatcher = new Dispatcher();
26+
// 设置最大并发数
27+
dispatcher.setMaxRequests(maxConcurrency);
28+
dispatcher.setMaxRequestsPerHost(maxConcurrency);
29+
// 请单独为batch embeddings单独初始化一个service实例,且多个Endpoint间也不要复用同一个service实例,避免互相影响。单个service会根据最大并发数启动对应的线程池,会占用一定的资源
30+
ArkService service = ArkService.builder().dispatcher(dispatcher).timeout(timeout).connectionPool(connectionPool).apiKey(apiKey).build();
31+
32+
ExecutorService executorService = Executors.newFixedThreadPool(maxConcurrency);
33+
CountDownLatch latch = new CountDownLatch(maxConcurrency);
34+
Runnable batchEmbeddingTask = () -> {
35+
System.out.println("Executing task in " + Thread.currentThread().getName());
36+
for (int i = 0; i < taskNumPerWorker; i++) {
37+
// 每个线程执行的任务逻辑
38+
try {
39+
final List<String> input = new ArrayList<>();
40+
input.add("你是豆包,是由字节跳动开发的 AI 人工智能助手");
41+
input.add("常见的十字花科植物有哪些?");
42+
43+
EmbeddingRequest batchEmbeddingsRequest = EmbeddingRequest.builder()
44+
.model("${YOUR_ENDPOINT_ID}")
45+
.input(input)
46+
.build();
47+
48+
service.createBatchEmbeddings(batchEmbeddingsRequest);
49+
System.out.println(Thread.currentThread().getName() + ": request " + i + "succeed");
50+
} catch (Exception e) {
51+
System.out.println(Thread.currentThread().getName() + ": request " + i + " failed " + e.getMessage());
52+
}
53+
}
54+
System.out.println(Thread.currentThread().getName() + " done");
55+
latch.countDown();
56+
};
57+
for (int i = 0; i < maxConcurrency; i++) {
58+
executorService.submit(batchEmbeddingTask);
59+
}
60+
try {
61+
latch.await();
62+
} catch (InterruptedException ignored) {
63+
}
64+
System.out.println("所有线程已退出");
65+
executorService.shutdown();
66+
System.out.println("线程池已退出");
67+
// shutdown service after all requests is finished
68+
service.shutdownExecutor();
69+
}
70+
71+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package com.volcengine.ark.runtime;
2+
3+
import com.volcengine.ark.runtime.model.multimodalembeddings.MultimodalEmbeddingInput;
4+
import com.volcengine.ark.runtime.model.multimodalembeddings.MultimodalEmbeddingRequest;
5+
import com.volcengine.ark.runtime.service.ArkService;
6+
import okhttp3.ConnectionPool;
7+
import okhttp3.Dispatcher;
8+
9+
import java.time.Duration;
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
import java.util.concurrent.CountDownLatch;
13+
import java.util.concurrent.ExecutorService;
14+
import java.util.concurrent.Executors;
15+
import java.util.concurrent.TimeUnit;
16+
17+
public class BatchMultiModalEmbeddingsExample {
18+
19+
public static void main(String[] args) {
20+
// 为batch multimodal embeddings设置一个较大的超时时间,最小不小于10分钟
21+
Duration timeout = Duration.ofHours(1);
22+
int maxConcurrency = 5000;
23+
int taskNumPerWorker = 5;
24+
String apiKey = System.getenv("ARK_API_KEY");
25+
ConnectionPool connectionPool = new ConnectionPool(maxConcurrency, 10, TimeUnit.MINUTES);
26+
Dispatcher dispatcher = new Dispatcher();
27+
// 设置最大并发数
28+
dispatcher.setMaxRequests(maxConcurrency);
29+
dispatcher.setMaxRequestsPerHost(maxConcurrency);
30+
// 请单独为batch multimodal embeddings单独初始化一个service实例,且多个Endpoint间也不要复用同一个service实例,避免互相影响。单个service会根据最大并发数启动对应的线程池,会占用一定的资源
31+
ArkService service = ArkService.builder().dispatcher(dispatcher).timeout(timeout).connectionPool(connectionPool).apiKey(apiKey).build();
32+
33+
ExecutorService executorService = Executors.newFixedThreadPool(maxConcurrency);
34+
CountDownLatch latch = new CountDownLatch(maxConcurrency);
35+
Runnable batchMultimodalEmbeddingTask = () -> {
36+
System.out.println("Executing task in " + Thread.currentThread().getName());
37+
for (int i = 0; i < taskNumPerWorker; i++) {
38+
// 每个线程执行的任务逻辑
39+
try {
40+
final List<MultimodalEmbeddingInput> input = new ArrayList<>();
41+
input.add(MultimodalEmbeddingInput.builder().type("text").text("把图中的蓝天换成白云").build());
42+
input.add(MultimodalEmbeddingInput.builder()
43+
.type("image_url")
44+
.imageUrl(new MultimodalEmbeddingInput.MultiModalEmbeddingContentPartImageURL("https://ark-project.tos-cn-beijing.ivolces.com/images/view.jpeg"))
45+
.build());
46+
47+
MultimodalEmbeddingRequest batchMultimodalEmbeddingRequest = MultimodalEmbeddingRequest.builder()
48+
.model("${YOUR_ENDPOINT_ID}")
49+
.input(input)
50+
.build();
51+
52+
service.createBatchMultiModalEmbeddings(batchMultimodalEmbeddingRequest);
53+
System.out.println(Thread.currentThread().getName() + ": request " + i + "succeed");
54+
} catch (Exception e) {
55+
System.out.println(Thread.currentThread().getName() + ": request " + i + " failed " + e.getMessage());
56+
}
57+
}
58+
System.out.println(Thread.currentThread().getName() + " done");
59+
latch.countDown();
60+
};
61+
for (int i = 0; i < maxConcurrency; i++) {
62+
executorService.submit(batchMultimodalEmbeddingTask);
63+
}
64+
try {
65+
latch.await();
66+
} catch (InterruptedException ignored) {
67+
}
68+
System.out.println("所有线程已退出");
69+
executorService.shutdown();
70+
System.out.println("线程池已退出");
71+
// shutdown service after all requests is finished
72+
service.shutdownExecutor();
73+
}
74+
75+
}

0 commit comments

Comments
 (0)