33import java .io .FileInputStream ;
44import java .io .IOException ;
55import java .net .URI ;
6+ import java .net .URLEncoder ;
7+ import java .net .http .HttpClient ;
8+ import java .net .http .HttpRequest ;
9+ import java .net .http .HttpResponse ;
10+ import java .net .http .HttpResponse .BodyHandler ;
11+ import java .nio .charset .StandardCharsets ;
612import java .nio .file .Path ;
713import java .time .Duration ;
8- import java .util .Optional ;
9- import java .util .concurrent . ExecutorService ;
10- import java .util .concurrent .Executors ;
11- import java .util .concurrent .TimeUnit ;
14+ import java .util .ArrayList ;
15+ import java .util .List ;
16+ import java .util .concurrent .CompletableFuture ;
17+ import java .util .concurrent .CountDownLatch ;
1218import java .util .concurrent .atomic .AtomicLong ;
1319import java .util .logging .Logger ;
1420
2733import ai .vespa .feed .client .Result ;
2834import nl .altindag .ssl .SSLFactory ;
2935import nl .altindag .ssl .pem .util .PemUtils ;
30- import okhttp3 .ConnectionPool ;
31- import okhttp3 .HttpUrl ;
32- import okhttp3 .OkHttpClient ;
33- import okhttp3 .Request ;
34- import okhttp3 .Response ;
36+
3537
3638public class VespaClient {
3739 private final static Logger log = Logger .getLogger (VespaClient .class .getName ());
@@ -40,20 +42,26 @@ private enum AuthMethod {
4042 MTLS , // mTLS: Recommended for Vespa Cloud
4143 TOKEN , // Token-based authentication
4244 NONE // E.g. if self-hosting.
43- };
45+ }
4446
4547 private static final AuthMethod AUTH_METHOD = AuthMethod .MTLS ;
4648
47- private static final String ENDPOINT = "" ;
48- // Auth method mTLS
49- private static final String PUBLIC_CERT = "" ;
50- private static final String PRIVATE_KEY = "" ;
49+ private static final String ENDPOINT = "YOUR_ENDPOINT " ;
50+ // Auth method: mTLS
51+ private static final String PUBLIC_CERT = "/path/to/public-cert.pem " ;
52+ private static final String PRIVATE_KEY = "/peth/to/private-key.pem " ;
5153
52- // Auth method token.
53- private static final String TOKEN = "" ;
54+ // Auth method: token.
55+ private static final String TOKEN = "YOUR_TOKEN " ;
5456
55- private static final int LOAD_CONCURRENCY = 400 ;
56- private static final int LOAD_NUM_QUERIES = 50000 ;
57+ // Number of concurrent in-flight HTTP/2 streams across all connections.
58+ private static final int LOAD_POOL_SIZE = 800 ;
59+ private static final int LOAD_NUM_QUERIES = 1000000 ;
60+ // Each HttpClient opens its own connection. Multiple connections spread load
61+ // across container nodes via the load balancer.
62+ private static final int NUM_CONNECTIONS = 16 ;
63+ private static final String LOAD_TEST_YQL = "select * from sources * where userQuery()" ;
64+ private static final String LOAD_TEST_QUERY = "guinness world record" ;
5765
5866 public static void main (String [] args ) throws Exception {
5967 Options options = new Options ();
@@ -74,8 +82,8 @@ public static void main(String[] args) throws Exception {
7482 } else if (cmd .hasOption ("q" )) {
7583 String query = cmd .getOptionValue ("q" );
7684 try {
77- String result = runSingleQuery (createHttpClient (), "select * from sources * where userQuery()" , query ).get ();
78- log .info (result );
85+ HttpResponse < String > response = runSingleQuery (createHttpClient (), "select * from sources * where userQuery()" , query , HttpResponse . BodyHandlers . ofString () ).get ();
86+ log .info (response . body () );
7987 } catch (Exception e ) {
8088 log .severe ("Query failed with message: " + e .getMessage ());
8189 }
@@ -98,39 +106,16 @@ static SSLFactory getSSLFactory() {
98106 return sslFactory ;
99107 }
100108
101- /**
102- * Create a {@link OkHttpClient} for querying, with settings based on {@link VespaClient#AUTH_METHOD}.
103- */
104- static OkHttpClient createHttpClient () {
105- var builder = new OkHttpClient .Builder ()
106- .connectionPool (new ConnectionPool (LOAD_CONCURRENCY , 5 , TimeUnit .MINUTES ))
107- .connectTimeout (5 , TimeUnit .SECONDS )
108- .readTimeout (2 , TimeUnit .SECONDS );
109+ static HttpClient createHttpClient () {
110+ var clientBuilder = HttpClient .newBuilder ()
111+ .version (HttpClient .Version .HTTP_2 )
112+ .connectTimeout (Duration .ofSeconds (5 ));
109113
110- switch (AUTH_METHOD ) {
111- case MTLS :
112- {
113- var sslFactory = getSSLFactory ();
114- builder .sslSocketFactory (sslFactory .getSslSocketFactory (), sslFactory .getTrustManager ().orElseThrow ());
115- }
116- break ;
117- case TOKEN :
118- {
119- builder .addInterceptor (chain -> {
120- return chain .proceed (
121- chain .request ()
122- .newBuilder ()
123- .header ("Authorization" , "Bearer " + TOKEN )
124- .build ()
125- );
126- });
127- }
128- break ;
129- case NONE :
130- break ;
114+ if (AUTH_METHOD == AuthMethod .MTLS ) {
115+ clientBuilder .sslContext (getSSLFactory ().getSslContext ());
131116 }
132117
133- return builder .build ();
118+ return clientBuilder .build ();
134119 }
135120
136121 /**
@@ -154,55 +139,54 @@ static JsonFeeder createFeeder() {
154139 .build ();
155140 }
156141
157- static Optional < String > runSingleQuery (OkHttpClient client , String yql , String query ) throws IOException {
158- HttpUrl url = HttpUrl . parse ( ENDPOINT + "search/" )
159- . newBuilder ()
160- . addQueryParameter ( "yql" , yql )
161- . addQueryParameter ( "query" , query )
162- . build ( );
142+ static < T > CompletableFuture < HttpResponse < T >> runSingleQuery (HttpClient client , String yql , String query , BodyHandler < T > handler ) {
143+ String base = ENDPOINT . endsWith ( "/" ) ? ENDPOINT : ENDPOINT + "/" ;
144+ URI uri = URI . create ( String . format ( "%ssearch/?yql=%s&query=%s" ,
145+ base ,
146+ URLEncoder . encode ( yql , StandardCharsets . UTF_8 ),
147+ URLEncoder . encode ( query , StandardCharsets . UTF_8 )) );
163148
164- Request request = new Request .Builder ()
165- .url (url )
166- .build ();
149+ var reqBuilder = HttpRequest .newBuilder ()
150+ .uri (uri )
151+ .GET ()
152+ .timeout (Duration .ofSeconds (5 ));
167153
168- try (Response response = client .newCall (request ).execute ()) {
169- if (response .code () != 200 ) {
170- throw new IOException ("Error code " + response .code ());
171- }
172- if (response .body () != null ) {
173- // consume
174- return Optional .of (response .body ().string ());
175- }
154+ if (AUTH_METHOD == AuthMethod .TOKEN ) {
155+ reqBuilder .header ("Authorization" , "Bearer " + TOKEN );
176156 }
177- return Optional .empty ();
157+
158+ return client .sendAsync (reqBuilder .build (), handler );
178159 }
179160
180161 static void loadTest () throws Exception {
181- var client = createHttpClient ();
162+ List <HttpClient > clients = new ArrayList <>(NUM_CONNECTIONS );
163+ for (int i = 0 ; i < NUM_CONNECTIONS ; i ++) {
164+ clients .add (createHttpClient ());
165+ }
166+
167+ log .info ("Warmup: 100 synchronous queries" );
168+ for (int i = 0 ; i < 100 ; ++i ) {
169+ try {
170+ runSingleQuery (clients .get (i % NUM_CONNECTIONS ), LOAD_TEST_YQL , LOAD_TEST_QUERY , HttpResponse .BodyHandlers .discarding ()).get ();
171+ } catch (Exception e ) {
172+ log .severe ("Warmup query failed: " + e .getMessage ());
173+ }
174+ }
182175
183- ExecutorService executor = Executors .newFixedThreadPool (LOAD_CONCURRENCY );
184-
185- AtomicLong resultsReceived = new AtomicLong (0 );
186- AtomicLong errorsReceived = new AtomicLong (0 );
176+ log .info ("Performing " + LOAD_NUM_QUERIES + " queries with " + LOAD_POOL_SIZE + " concurrent requests across " + NUM_CONNECTIONS + " connections" );
187177
188- log .info ("Performing " + LOAD_NUM_QUERIES + " queries with concurrency: " + LOAD_CONCURRENCY );
178+ var remaining = new AtomicLong (LOAD_NUM_QUERIES );
179+ var resultsReceived = new AtomicLong (0 );
180+ var errorsReceived = new AtomicLong (0 );
181+ var latch = new CountDownLatch (LOAD_POOL_SIZE );
189182
190183 long startTimeMillis = System .currentTimeMillis ();
191184
192- for (int i = 0 ; i < LOAD_NUM_QUERIES ; ++i ) {
193- executor .submit (() -> {
194- try {
195- runSingleQuery (client , "select * from sources * where userQuery()" , "guinness world record" );
196- } catch (Exception e ) {
197- log .severe ("Query iteration failed with: " + e .getMessage ());
198- errorsReceived .incrementAndGet ();
199- } finally {
200- resultsReceived .incrementAndGet ();
201- }
202- });
185+ for (int i = 0 ; i < LOAD_POOL_SIZE ; i ++) {
186+ sendNext (clients .get (i % NUM_CONNECTIONS ), remaining , resultsReceived , errorsReceived , latch );
203187 }
204- executor . shutdown ();
205- executor . awaitTermination ( 1 , TimeUnit . HOURS );
188+
189+ latch . await ( );
206190
207191 long timeSpentMillis = System .currentTimeMillis () - startTimeMillis ;
208192 double qps = (double )(resultsReceived .get () - errorsReceived .get ()) / (timeSpentMillis / 1000.0 );
@@ -212,12 +196,31 @@ static void loadTest() throws Exception {
212196 log .info ("QPS: " + qps );
213197 }
214198
199+ static void sendNext (HttpClient client , AtomicLong remaining ,
200+ AtomicLong resultsReceived , AtomicLong errorsReceived ,
201+ CountDownLatch latch ) {
202+ if (remaining .decrementAndGet () < 0 ) {
203+ latch .countDown ();
204+ return ;
205+ }
206+ runSingleQuery (client , "select * from sources * where userQuery()" ,
207+ "guinness world record" , HttpResponse .BodyHandlers .discarding ())
208+ .whenComplete ((resp , ex ) -> {
209+ if (ex != null ) {
210+ log .severe ("Query failed: " + ex .getMessage ());
211+ errorsReceived .incrementAndGet ();
212+ }
213+ resultsReceived .incrementAndGet ();
214+ sendNext (client , remaining , resultsReceived , errorsReceived , latch );
215+ });
216+ }
217+
215218 /**
216219 * Feed documents from a .jsonl file given by {@code filePath}.
217220 */
218221 static void feedFromFile (String filePath ) {
219- try (FileInputStream jsonStream = new FileInputStream (filePath )) {
220- JsonFeeder feeder = createFeeder ();
222+ try (FileInputStream jsonStream = new FileInputStream (filePath );
223+ JsonFeeder feeder = createFeeder ()) {
221224 log .info ("Starting feed" );
222225
223226 AtomicLong resultsReceived = new AtomicLong (0 );
@@ -245,7 +248,6 @@ public void onError(FeedException error) {
245248 });
246249
247250 promise .join ();
248- feeder .close ();
249251
250252 long timeSpentMillis = (System .currentTimeMillis () - startTimeMillis );
251253 double okRatePerSec = (double )(resultsReceived .get () - errorsReceived .get ()) / (timeSpentMillis / 1000.0 );
0 commit comments