Skip to content

Commit fb58948

Browse files
committed
feat: arkruntime support batch embeddings
1 parent 9e8d3c7 commit fb58948

File tree

5 files changed

+171
-1
lines changed

5 files changed

+171
-1
lines changed

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/service/ArkApi.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,15 @@ public interface ArkApi {
6161
@POST("/api/v3/embeddings")
6262
Single<EmbeddingResult> createEmbeddings(@Body EmbeddingRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
6363

64+
@POST("/api/v3/batch/embeddings")
65+
Single<EmbeddingResult> createBatchEmbeddings(@Body EmbeddingRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
66+
6467
@POST("/api/v3/embeddings/multimodal")
6568
Single<MultimodalEmbeddingResult> createMultiModalEmbeddings(@Body MultimodalEmbeddingRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
6669

70+
@POST("/api/v3/batch/embeddings/multimodal")
71+
Single<MultimodalEmbeddingResult> createBatchMultiModalEmbeddings(@Body MultimodalEmbeddingRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
72+
6773
@POST("/api/v3/tokenization")
6874
Single<TokenizationResult> createTokenization(@Body TokenizationRequest request, @Header(Const.REQUEST_MODEL) String model, @HeaderMap Map<String, String> customHeaders);
6975

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/service/ArkBaseServiceImpl.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@ public interface ArkBaseServiceImpl {
4343

4444
EmbeddingResult createEmbeddings(EmbeddingRequest request);
4545

46+
EmbeddingResult createBatchEmbeddings(EmbeddingRequest request);
47+
4648
MultimodalEmbeddingResult createMultiModalEmbeddings(MultimodalEmbeddingRequest request);
4749

48-
ImagesResponse generateImages(GenerateImagesRequest request);
50+
MultimodalEmbeddingResult createBatchMultiModalEmbeddings(MultimodalEmbeddingRequest request);
4951

5052
CreateContentGenerationTaskResult createContentGenerationTask(CreateContentGenerationTaskRequest request);
5153

volcengine-java-sdk-ark-runtime/src/main/java/com/volcengine/ark/runtime/service/ArkService.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,14 @@ public EmbeddingResult createEmbeddings(EmbeddingRequest request, Map<String, St
203203
return execute(api.createEmbeddings(request, request.getModel(), customHeaders));
204204
}
205205

206+
public EmbeddingResult createBatchEmbeddings(EmbeddingRequest request) {
207+
return execute(api.createBatchEmbeddings(request, request.getModel(), new HashMap<>()));
208+
}
209+
210+
public EmbeddingResult createBatchEmbeddings(EmbeddingRequest request, Map<String, String> customHeaders) {
211+
return execute(api.createBatchEmbeddings(request, request.getModel(), customHeaders));
212+
}
213+
206214
public MultimodalEmbeddingResult createMultiModalEmbeddings(MultimodalEmbeddingRequest request) {
207215
return execute(api.createMultiModalEmbeddings(request, request.getModel(), new HashMap<>()));
208216
}
@@ -211,6 +219,14 @@ public MultimodalEmbeddingResult createMultiModalEmbeddings(MultimodalEmbeddingR
211219
return execute(api.createMultiModalEmbeddings(request, request.getModel(), customHeaders));
212220
}
213221

222+
public MultimodalEmbeddingResult createBatchMultiModalEmbeddings(MultimodalEmbeddingRequest request) {
223+
return execute(api.createBatchMultiModalEmbeddings(request, request.getModel(), new HashMap<>()));
224+
}
225+
226+
public MultimodalEmbeddingResult createBatchMultiModalEmbeddings(MultimodalEmbeddingRequest request, Map<String, String> customHeaders) {
227+
return execute(api.createBatchMultiModalEmbeddings(request, request.getModel(), customHeaders));
228+
}
229+
214230
@Override
215231
public BotChatCompletionResult createBotChatCompletion(BotChatCompletionRequest request) {
216232
return execute(api.createBotChatCompletion(request, request.getModel(), new HashMap<>()));
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package com.volcengine.ark.runtime;
2+
3+
import com.volcengine.ark.runtime.model.embeddings.EmbeddingRequest;
4+
import com.volcengine.ark.runtime.service.ArkService;
5+
import okhttp3.ConnectionPool;
6+
import okhttp3.Dispatcher;
7+
8+
import java.time.Duration;
9+
import java.util.ArrayList;
10+
import java.util.List;
11+
import java.util.concurrent.CountDownLatch;
12+
import java.util.concurrent.ExecutorService;
13+
import java.util.concurrent.Executors;
14+
import java.util.concurrent.TimeUnit;
15+
16+
public class BatchEmbeddingsExample {
17+
18+
public static void main(String[] args) {
19+
// 为batch embeddings设置一个较大的超时时间,最小不小于10分钟
20+
Duration timeout = Duration.ofHours(1);
21+
int maxConcurrency = 5000;
22+
int taskNumPerWorker = 5;
23+
String apiKey = System.getenv("ARK_API_KEY");
24+
ConnectionPool connectionPool = new ConnectionPool(maxConcurrency, 10, TimeUnit.MINUTES);
25+
Dispatcher dispatcher = new Dispatcher();
26+
// 设置最大并发数
27+
dispatcher.setMaxRequests(maxConcurrency);
28+
dispatcher.setMaxRequestsPerHost(maxConcurrency);
29+
// 请单独为batch embeddings单独初始化一个service实例,且多个Endpoint间也不要复用同一个service实例,避免互相影响。单个service会根据最大并发数启动对应的线程池,会占用一定的资源
30+
ArkService service = ArkService.builder().dispatcher(dispatcher).timeout(timeout).connectionPool(connectionPool).apiKey(apiKey).build();
31+
32+
ExecutorService executorService = Executors.newFixedThreadPool(maxConcurrency);
33+
CountDownLatch latch = new CountDownLatch(maxConcurrency);
34+
Runnable batchEmbeddingTask = () -> {
35+
System.out.println("Executing task in " + Thread.currentThread().getName());
36+
for (int i = 0; i < taskNumPerWorker; i++) {
37+
// 每个线程执行的任务逻辑
38+
try {
39+
final List<String> input = new ArrayList<>();
40+
input.add("你是豆包,是由字节跳动开发的 AI 人工智能助手");
41+
input.add("常见的十字花科植物有哪些?");
42+
43+
EmbeddingRequest batchEmbeddingsRequest = EmbeddingRequest.builder()
44+
.model("${YOUR_ENDPOINT_ID}")
45+
.input(input)
46+
.build();
47+
48+
service.createBatchEmbeddings(batchEmbeddingsRequest);
49+
System.out.println(Thread.currentThread().getName() + ": request " + i + "succeed");
50+
} catch (Exception e) {
51+
System.out.println(Thread.currentThread().getName() + ": request " + i + " failed " + e.getMessage());
52+
}
53+
}
54+
System.out.println(Thread.currentThread().getName() + " done");
55+
latch.countDown();
56+
};
57+
for (int i = 0; i < maxConcurrency; i++) {
58+
executorService.submit(batchEmbeddingTask);
59+
}
60+
try {
61+
latch.await();
62+
} catch (InterruptedException ignored) {
63+
}
64+
System.out.println("所有线程已退出");
65+
executorService.shutdown();
66+
System.out.println("线程池已退出");
67+
// shutdown service after all requests is finished
68+
service.shutdownExecutor();
69+
}
70+
71+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package com.volcengine.ark.runtime;
2+
3+
import com.volcengine.ark.runtime.model.multimodalembeddings.MultimodalEmbeddingInput;
4+
import com.volcengine.ark.runtime.model.multimodalembeddings.MultimodalEmbeddingRequest;
5+
import com.volcengine.ark.runtime.service.ArkService;
6+
import okhttp3.ConnectionPool;
7+
import okhttp3.Dispatcher;
8+
9+
import java.time.Duration;
10+
import java.util.ArrayList;
11+
import java.util.List;
12+
import java.util.concurrent.CountDownLatch;
13+
import java.util.concurrent.ExecutorService;
14+
import java.util.concurrent.Executors;
15+
import java.util.concurrent.TimeUnit;
16+
17+
public class BatchMultiModalEmbeddingsExample {
18+
19+
public static void main(String[] args) {
20+
// 为batch multimodal embeddings设置一个较大的超时时间,最小不小于10分钟
21+
Duration timeout = Duration.ofHours(1);
22+
int maxConcurrency = 5000;
23+
int taskNumPerWorker = 5;
24+
String apiKey = System.getenv("ARK_API_KEY");
25+
ConnectionPool connectionPool = new ConnectionPool(maxConcurrency, 10, TimeUnit.MINUTES);
26+
Dispatcher dispatcher = new Dispatcher();
27+
// 设置最大并发数
28+
dispatcher.setMaxRequests(maxConcurrency);
29+
dispatcher.setMaxRequestsPerHost(maxConcurrency);
30+
// 请单独为batch multimodal embeddings单独初始化一个service实例,且多个Endpoint间也不要复用同一个service实例,避免互相影响。单个service会根据最大并发数启动对应的线程池,会占用一定的资源
31+
ArkService service = ArkService.builder().dispatcher(dispatcher).timeout(timeout).connectionPool(connectionPool).apiKey(apiKey).build();
32+
33+
ExecutorService executorService = Executors.newFixedThreadPool(maxConcurrency);
34+
CountDownLatch latch = new CountDownLatch(maxConcurrency);
35+
Runnable batchMultimodalEmbeddingTask = () -> {
36+
System.out.println("Executing task in " + Thread.currentThread().getName());
37+
for (int i = 0; i < taskNumPerWorker; i++) {
38+
// 每个线程执行的任务逻辑
39+
try {
40+
final List<MultimodalEmbeddingInput> input = new ArrayList<>();
41+
input.add(MultimodalEmbeddingInput.builder().type("text").text("把图中的蓝天换成白云").build());
42+
input.add(MultimodalEmbeddingInput.builder()
43+
.type("image_url")
44+
.imageUrl(new MultimodalEmbeddingInput.MultiModalEmbeddingContentPartImageURL("https://ark-project.tos-cn-beijing.ivolces.com/images/view.jpeg"))
45+
.build());
46+
47+
MultimodalEmbeddingRequest batchMultimodalEmbeddingRequest = MultimodalEmbeddingRequest.builder()
48+
.model("${YOUR_ENDPOINT_ID}")
49+
.input(input)
50+
.build();
51+
52+
service.createBatchMultiModalEmbeddings(batchMultimodalEmbeddingRequest);
53+
System.out.println(Thread.currentThread().getName() + ": request " + i + "succeed");
54+
} catch (Exception e) {
55+
System.out.println(Thread.currentThread().getName() + ": request " + i + " failed " + e.getMessage());
56+
}
57+
}
58+
System.out.println(Thread.currentThread().getName() + " done");
59+
latch.countDown();
60+
};
61+
for (int i = 0; i < maxConcurrency; i++) {
62+
executorService.submit(batchMultimodalEmbeddingTask);
63+
}
64+
try {
65+
latch.await();
66+
} catch (InterruptedException ignored) {
67+
}
68+
System.out.println("所有线程已退出");
69+
executorService.shutdown();
70+
System.out.println("线程池已退出");
71+
// shutdown service after all requests is finished
72+
service.shutdownExecutor();
73+
}
74+
75+
}

0 commit comments

Comments
 (0)