Skip to content

Commit 5433358

Browse files
committed
fix: migrate to native java http client
1 parent 18af927 commit 5433358

File tree

1 file changed

+92
-90
lines changed

1 file changed

+92
-90
lines changed

examples/clients/client-java/app/src/main/java/com/example/VespaClient.java

Lines changed: 92 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,18 @@
33
import java.io.FileInputStream;
44
import java.io.IOException;
55
import java.net.URI;
6+
import java.net.URLEncoder;
7+
import java.net.http.HttpClient;
8+
import java.net.http.HttpRequest;
9+
import java.net.http.HttpResponse;
10+
import java.net.http.HttpResponse.BodyHandler;
11+
import java.nio.charset.StandardCharsets;
612
import java.nio.file.Path;
713
import java.time.Duration;
8-
import java.util.Optional;
9-
import java.util.concurrent.ExecutorService;
10-
import java.util.concurrent.Executors;
11-
import java.util.concurrent.TimeUnit;
14+
import java.util.ArrayList;
15+
import java.util.List;
16+
import java.util.concurrent.CompletableFuture;
17+
import java.util.concurrent.CountDownLatch;
1218
import java.util.concurrent.atomic.AtomicLong;
1319
import java.util.logging.Logger;
1420

@@ -27,11 +33,7 @@
2733
import ai.vespa.feed.client.Result;
2834
import nl.altindag.ssl.SSLFactory;
2935
import nl.altindag.ssl.pem.util.PemUtils;
30-
import okhttp3.ConnectionPool;
31-
import okhttp3.HttpUrl;
32-
import okhttp3.OkHttpClient;
33-
import okhttp3.Request;
34-
import okhttp3.Response;
36+
3537

3638
public class VespaClient {
3739
private final static Logger log = Logger.getLogger(VespaClient.class.getName());
@@ -40,20 +42,26 @@ private enum AuthMethod {
4042
MTLS, // mTLS: Recommended for Vespa Cloud
4143
TOKEN, // Token-based authentication
4244
NONE // E.g. if self-hosting.
43-
};
45+
}
4446

4547
private static final AuthMethod AUTH_METHOD = AuthMethod.MTLS;
4648

47-
private static final String ENDPOINT = "";
48-
// Auth method mTLS
49-
private static final String PUBLIC_CERT = "";
50-
private static final String PRIVATE_KEY = "";
49+
private static final String ENDPOINT = "YOUR_ENDPOINT";
50+
// Auth method: mTLS
51+
private static final String PUBLIC_CERT = "/path/to/public-cert.pem";
52+
private static final String PRIVATE_KEY = "/peth/to/private-key.pem";
5153

52-
// Auth method token.
53-
private static final String TOKEN = "";
54+
// Auth method: token.
55+
private static final String TOKEN = "YOUR_TOKEN";
5456

55-
private static final int LOAD_CONCURRENCY = 400;
56-
private static final int LOAD_NUM_QUERIES = 50000;
57+
// Number of concurrent in-flight HTTP/2 streams across all connections.
58+
private static final int LOAD_POOL_SIZE = 800;
59+
private static final int LOAD_NUM_QUERIES = 1000000;
60+
// Each HttpClient opens its own connection. Multiple connections spread load
61+
// across container nodes via the load balancer.
62+
private static final int NUM_CONNECTIONS = 16;
63+
private static final String LOAD_TEST_YQL = "select * from sources * where userQuery()";
64+
private static final String LOAD_TEST_QUERY = "guinness world record";
5765

5866
public static void main(String[] args) throws Exception {
5967
Options options = new Options();
@@ -74,8 +82,8 @@ public static void main(String[] args) throws Exception {
7482
} else if (cmd.hasOption("q")) {
7583
String query = cmd.getOptionValue("q");
7684
try {
77-
String result = runSingleQuery(createHttpClient(), "select * from sources * where userQuery()", query).get();
78-
log.info(result);
85+
HttpResponse<String> response = runSingleQuery(createHttpClient(), "select * from sources * where userQuery()", query, HttpResponse.BodyHandlers.ofString()).get();
86+
log.info(response.body());
7987
} catch (Exception e) {
8088
log.severe("Query failed with message: " + e.getMessage());
8189
}
@@ -98,39 +106,16 @@ static SSLFactory getSSLFactory() {
98106
return sslFactory;
99107
}
100108

101-
/**
102-
* Create a {@link OkHttpClient} for querying, with settings based on {@link VespaClient#AUTH_METHOD}.
103-
*/
104-
static OkHttpClient createHttpClient() {
105-
var builder = new OkHttpClient.Builder()
106-
.connectionPool(new ConnectionPool(LOAD_CONCURRENCY, 5, TimeUnit.MINUTES))
107-
.connectTimeout(5, TimeUnit.SECONDS)
108-
.readTimeout(2, TimeUnit.SECONDS);
109+
static HttpClient createHttpClient() {
110+
var clientBuilder = HttpClient.newBuilder()
111+
.version(HttpClient.Version.HTTP_2)
112+
.connectTimeout(Duration.ofSeconds(5));
109113

110-
switch (AUTH_METHOD) {
111-
case MTLS:
112-
{
113-
var sslFactory = getSSLFactory();
114-
builder.sslSocketFactory(sslFactory.getSslSocketFactory(), sslFactory.getTrustManager().orElseThrow());
115-
}
116-
break;
117-
case TOKEN:
118-
{
119-
builder.addInterceptor(chain -> {
120-
return chain.proceed(
121-
chain.request()
122-
.newBuilder()
123-
.header("Authorization", "Bearer " + TOKEN)
124-
.build()
125-
);
126-
});
127-
}
128-
break;
129-
case NONE:
130-
break;
114+
if (AUTH_METHOD == AuthMethod.MTLS) {
115+
clientBuilder.sslContext(getSSLFactory().getSslContext());
131116
}
132117

133-
return builder.build();
118+
return clientBuilder.build();
134119
}
135120

136121
/**
@@ -154,55 +139,54 @@ static JsonFeeder createFeeder() {
154139
.build();
155140
}
156141

157-
static Optional<String> runSingleQuery(OkHttpClient client, String yql, String query) throws IOException {
158-
HttpUrl url = HttpUrl.parse(ENDPOINT + "search/")
159-
.newBuilder()
160-
.addQueryParameter("yql", yql)
161-
.addQueryParameter("query", query)
162-
.build();
142+
static <T> CompletableFuture<HttpResponse<T>> runSingleQuery(HttpClient client, String yql, String query, BodyHandler<T> handler) {
143+
String base = ENDPOINT.endsWith("/") ? ENDPOINT : ENDPOINT + "/";
144+
URI uri = URI.create(String.format("%ssearch/?yql=%s&query=%s",
145+
base,
146+
URLEncoder.encode(yql, StandardCharsets.UTF_8),
147+
URLEncoder.encode(query, StandardCharsets.UTF_8)));
163148

164-
Request request = new Request.Builder()
165-
.url(url)
166-
.build();
149+
var reqBuilder = HttpRequest.newBuilder()
150+
.uri(uri)
151+
.GET()
152+
.timeout(Duration.ofSeconds(5));
167153

168-
try (Response response = client.newCall(request).execute()) {
169-
if (response.code() != 200) {
170-
throw new IOException("Error code " + response.code());
171-
}
172-
if (response.body() != null) {
173-
// consume
174-
return Optional.of(response.body().string());
175-
}
154+
if (AUTH_METHOD == AuthMethod.TOKEN) {
155+
reqBuilder.header("Authorization", "Bearer " + TOKEN);
176156
}
177-
return Optional.empty();
157+
158+
return client.sendAsync(reqBuilder.build(), handler);
178159
}
179160

180161
static void loadTest() throws Exception {
181-
var client = createHttpClient();
162+
List<HttpClient> clients = new ArrayList<>(NUM_CONNECTIONS);
163+
for (int i = 0; i < NUM_CONNECTIONS; i++) {
164+
clients.add(createHttpClient());
165+
}
166+
167+
log.info("Warmup: 100 synchronous queries");
168+
for (int i = 0; i < 100; ++i) {
169+
try {
170+
runSingleQuery(clients.get(i % NUM_CONNECTIONS), LOAD_TEST_YQL, LOAD_TEST_QUERY, HttpResponse.BodyHandlers.discarding()).get();
171+
} catch (Exception e) {
172+
log.severe("Warmup query failed: " + e.getMessage());
173+
}
174+
}
182175

183-
ExecutorService executor = Executors.newFixedThreadPool(LOAD_CONCURRENCY);
184-
185-
AtomicLong resultsReceived = new AtomicLong(0);
186-
AtomicLong errorsReceived = new AtomicLong(0);
176+
log.info("Performing " + LOAD_NUM_QUERIES + " queries with " + LOAD_POOL_SIZE + " concurrent requests across " + NUM_CONNECTIONS + " connections");
187177

188-
log.info("Performing " + LOAD_NUM_QUERIES + " queries with concurrency: " + LOAD_CONCURRENCY);
178+
var remaining = new AtomicLong(LOAD_NUM_QUERIES);
179+
var resultsReceived = new AtomicLong(0);
180+
var errorsReceived = new AtomicLong(0);
181+
var latch = new CountDownLatch(LOAD_POOL_SIZE);
189182

190183
long startTimeMillis = System.currentTimeMillis();
191184

192-
for (int i = 0; i < LOAD_NUM_QUERIES; ++i) {
193-
executor.submit(() -> {
194-
try {
195-
runSingleQuery(client, "select * from sources * where userQuery()", "guinness world record");
196-
} catch (Exception e) {
197-
log.severe("Query iteration failed with: " + e.getMessage());
198-
errorsReceived.incrementAndGet();
199-
} finally {
200-
resultsReceived.incrementAndGet();
201-
}
202-
});
185+
for (int i = 0; i < LOAD_POOL_SIZE; i++) {
186+
sendNext(clients.get(i % NUM_CONNECTIONS), remaining, resultsReceived, errorsReceived, latch);
203187
}
204-
executor.shutdown();
205-
executor.awaitTermination(1, TimeUnit.HOURS);
188+
189+
latch.await();
206190

207191
long timeSpentMillis = System.currentTimeMillis() - startTimeMillis;
208192
double qps = (double)(resultsReceived.get() - errorsReceived.get()) / (timeSpentMillis / 1000.0);
@@ -212,12 +196,31 @@ static void loadTest() throws Exception {
212196
log.info("QPS: " + qps);
213197
}
214198

199+
static void sendNext(HttpClient client, AtomicLong remaining,
200+
AtomicLong resultsReceived, AtomicLong errorsReceived,
201+
CountDownLatch latch) {
202+
if (remaining.decrementAndGet() < 0) {
203+
latch.countDown();
204+
return;
205+
}
206+
runSingleQuery(client, "select * from sources * where userQuery()",
207+
"guinness world record", HttpResponse.BodyHandlers.discarding())
208+
.whenComplete((resp, ex) -> {
209+
if (ex != null) {
210+
log.severe("Query failed: " + ex.getMessage());
211+
errorsReceived.incrementAndGet();
212+
}
213+
resultsReceived.incrementAndGet();
214+
sendNext(client, remaining, resultsReceived, errorsReceived, latch);
215+
});
216+
}
217+
215218
/**
216219
* Feed documents from a .jsonl file given by {@code filePath}.
217220
*/
218221
static void feedFromFile(String filePath) {
219-
try (FileInputStream jsonStream = new FileInputStream(filePath)) {
220-
JsonFeeder feeder = createFeeder();
222+
try (FileInputStream jsonStream = new FileInputStream(filePath);
223+
JsonFeeder feeder = createFeeder()) {
221224
log.info("Starting feed");
222225

223226
AtomicLong resultsReceived = new AtomicLong(0);
@@ -245,7 +248,6 @@ public void onError(FeedException error) {
245248
});
246249

247250
promise.join();
248-
feeder.close();
249251

250252
long timeSpentMillis = (System.currentTimeMillis() - startTimeMillis);
251253
double okRatePerSec = (double)(resultsReceived.get() - errorsReceived.get()) / (timeSpentMillis / 1000.0);

0 commit comments

Comments
 (0)