1+ /*
2+ * Copyright 2023 Basis Technology Corp.
3+ *
4+ * Licensed under the Apache License, Version 2.0 (the "License");
5+ * you may not use this file except in compliance with the License.
6+ * You may obtain a copy of the License at
7+ *
8+ * http://www.apache.org/licenses/LICENSE-2.0
9+ *
10+ * Unless required by applicable law or agreed to in writing, software
11+ * distributed under the License is distributed on an "AS IS" BASIS,
12+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ * See the License for the specific language governing permissions and
14+ * limitations under the License.
15+ */
16+ package com .basistech .rosette .api ;
17+
18+ import com .basistech .rosette .api .common .AbstractRosetteAPI ;
19+ import com .basistech .rosette .apimodel .DocumentRequest ;
20+ import com .basistech .rosette .apimodel .EntitiesOptions ;
21+ import com .basistech .rosette .apimodel .EntitiesResponse ;
22+ import com .basistech .rosette .apimodel .ErrorResponse ;
23+ import com .basistech .rosette .apimodel .LanguageOptions ;
24+ import com .basistech .rosette .apimodel .LanguageResponse ;
25+ import com .basistech .rosette .apimodel .Response ;
26+ import org .junit .jupiter .api .AfterEach ;
27+ import org .junit .jupiter .api .BeforeEach ;
28+ import org .junit .jupiter .api .Test ;
29+ import org .junit .jupiter .api .extension .ExtendWith ;
30+ import org .mockserver .client .MockServerClient ;
31+ import org .mockserver .junit .jupiter .MockServerExtension ;
32+ import org .mockserver .matchers .Times ;
33+ import org .mockserver .model .HttpRequest ;
34+ import org .mockserver .model .HttpResponse ;
35+
36+ import java .io .IOException ;
37+ import java .util .ArrayList ;
38+ import java .util .Date ;
39+ import java .util .List ;
40+ import java .util .concurrent .ExecutionException ;
41+ import java .util .concurrent .ExecutorService ;
42+ import java .util .concurrent .Executors ;
43+ import java .util .concurrent .Future ;
44+ import java .util .concurrent .TimeUnit ;
45+
46+ import static org .junit .jupiter .api .Assertions .assertEquals ;
47+ import static org .junit .jupiter .api .Assertions .assertInstanceOf ;
48+ import static org .junit .jupiter .api .Assertions .assertTrue ;
49+
50+ @ ExtendWith (MockServerExtension .class )
51+ class RosetteRequestTest {
52+ private MockServerClient mockServer ;
53+ private HttpRosetteAPI api ;
54+
55+ @ BeforeEach
56+ void setup (MockServerClient mockServer ) {
57+ this .mockServer = mockServer ;
58+ }
59+
60+ private void setupResponse (String requestPath , String responseString , int statusCode , int delayMillis , int requestTimes ) {
61+ this .mockServer .when (HttpRequest .request ().withPath (requestPath ), Times .exactly (requestTimes ))
62+ .respond (HttpResponse .response ()
63+ .withHeader ("Content-Type" , "application/json" )
64+ .withHeader ("X-RosetteAPI-Concurrency" , "5" )
65+ .withStatusCode (statusCode )
66+ .withBody (responseString )
67+ .withDelay (TimeUnit .MILLISECONDS , delayMillis ));
68+ }
69+
70+
71+ @ Test
72+ void successfulRequest () throws ExecutionException , InterruptedException {
73+ //Api client setup
74+ this .api = new HttpRosetteAPI .Builder ().url (String .format ("http://localhost:%d/rest/v1" , mockServer .getPort ())).build ();
75+
76+ //response setup
77+ String entitiesResponse = "{\" entities\" : [ { \" type\" : \" ORGANIZATION\" , \" mention\" : \" Securities and Exchange Commission\" , \" normalized\" : \" U.S. Securities and Exchange Commission\" , \" count\" : 1, \" mentionOffsets\" : [ { \" startOffset\" : 4, \" endOffset\" : 38 } ], \" entityId\" : \" Q953944\" , \" confidence\" : 0.39934742, \" linkingConfidence\" : 0.67404154 } ] }" ;
78+ setupResponse ("/rest/v1/entities" , entitiesResponse , 200 , 0 , 1 );
79+
80+ //request setup
81+ String entitiesTextData = "The Securities and Exchange Commission today announced the leadership of the agency’s trial unit." ;
82+ DocumentRequest <EntitiesOptions > entitiesRequestData = DocumentRequest .<EntitiesOptions >builder ()
83+ .content (entitiesTextData )
84+ .build ();
85+ RosetteRequest entitiesRequest = this .api .createRosetteRequest (AbstractRosetteAPI .ENTITIES_SERVICE_PATH , entitiesRequestData , EntitiesResponse .class );
86+
87+ //testing the request
88+ ExecutorService threadPool = Executors .newFixedThreadPool (1 );
89+ Future <Response > response = threadPool .submit (entitiesRequest );
90+ assertInstanceOf (EntitiesResponse .class , response .get ());
91+ assertEquals (response .get (), entitiesRequest .getResponse ());
92+ threadPool .shutdownNow ();
93+ }
94+
95+
96+ @ Test
97+ void errorResponse () throws ExecutionException , InterruptedException {
98+ //Api client setup
99+ this .api = new HttpRosetteAPI .Builder ().url (String .format ("http://localhost:%d/rest/v1" , mockServer .getPort ())).build ();
100+
101+ //response setup
102+ String entitiesResponse = "{ \" code\" : \" badRequestFormat\" , \" message\" : \" no content provided; must be one of an attachment, an inline \\ \" content\\ \" field, or an external \\ \" contentUri\\ \" \" }" ;
103+ setupResponse ("/rest/v1/entities" , entitiesResponse , 400 , 0 , 1 );
104+
105+ //request setup
106+ DocumentRequest <EntitiesOptions > entitiesRequestData = DocumentRequest .<EntitiesOptions >builder ()
107+ .build ();
108+ RosetteRequest entitiesRequest = this .api .createRosetteRequest (AbstractRosetteAPI .ENTITIES_SERVICE_PATH , entitiesRequestData , EntitiesResponse .class );
109+
110+ //testing the request
111+ ExecutorService threadPool = Executors .newFixedThreadPool (1 );
112+ Future <Response > response = threadPool .submit (entitiesRequest );
113+ assertInstanceOf (ErrorResponse .class , response .get ());
114+ assertEquals (response .get (), entitiesRequest .getResponse ());
115+ threadPool .shutdownNow ();
116+ }
117+
118+ @ Test
119+ void testTiming () throws ExecutionException , InterruptedException {
120+ int delay = 100 ;
121+ //api setup
122+ this .api = new HttpRosetteAPI .Builder ().url (String .format ("http://localhost:%d/rest/v1" , mockServer .getPort ()))
123+ .connectionConcurrency (1 ).build ();
124+
125+ //responses setup
126+ int entitiesRespCount = 10 ;
127+ int languageRespCount = 4 ;
128+ assertEquals (0 , entitiesRespCount % 2 );
129+ assertEquals (0 , entitiesRespCount % 2 );
130+ String entitiesResponse = "{\" entities\" : [ { \" type\" : \" ORGANIZATION\" , \" mention\" : \" Securities and Exchange Commission\" , \" normalized\" : \" U.S. Securities and Exchange Commission\" , \" count\" : 1, \" mentionOffsets\" : [ { \" startOffset\" : 4, \" endOffset\" : 38 } ], \" entityId\" : \" Q953944\" , \" confidence\" : 0.39934742, \" linkingConfidence\" : 0.67404154 } ] }" ;
131+ setupResponse ("/rest/v1/entities" , entitiesResponse , 200 , delay , entitiesRespCount );
132+ String languageResponse = " {\" code\" : \" badRequestFormat\" , \" message\" : \" no content provided; must be one of an attachment, an inline \\ \" content\\ \" field, or an external \\ \" contentUri\\ \" \" }" ;
133+ setupResponse ("/rest/v1/language" , languageResponse , 400 , delay , languageRespCount );
134+
135+ //requests setup
136+ String entitiesTextData = "The Securities and Exchange Commission today announced the leadership of the agency’s trial unit." ;
137+ DocumentRequest <EntitiesOptions > entitiesRequestData = DocumentRequest .<EntitiesOptions >builder ()
138+ .content (entitiesTextData )
139+ .build ();
140+ DocumentRequest <LanguageOptions > languageRequestData = DocumentRequest .<LanguageOptions >builder ().build ();
141+ List <RosetteRequest > requests = new ArrayList <>();
142+ for (int i = 0 ; i < entitiesRespCount / 2 ; i ++) {
143+ requests .add (this .api .createRosetteRequest (AbstractRosetteAPI .ENTITIES_SERVICE_PATH , entitiesRequestData , EntitiesResponse .class ));
144+ }
145+ for (int i = 0 ; i < languageRespCount / 2 ; i ++) {
146+ requests .add (this .api .createRosetteRequest (AbstractRosetteAPI .LANGUAGE_SERVICE_PATH , languageRequestData , LanguageResponse .class ));
147+ }
148+
149+ //run requests
150+ ExecutorService threadPool = Executors .newFixedThreadPool (7 );
151+ Date d1 = new Date ();
152+ List <Future <Response >> responses = threadPool .invokeAll (requests );
153+ for (int i = 0 ; i < responses .size (); i ++) {
154+ responses .get (i ).get ();
155+ }
156+ Date d2 = new Date ();
157+
158+ assertTrue (d2 .getTime () - d1 .getTime () > delay * requests .size ()); // at least as long as the delay in the request
159+
160+ //run requests concurrently
161+ int concurrency = 3 ;
162+ this .api = new HttpRosetteAPI .Builder ().url (String .format ("http://localhost:%d/rest/v1" , mockServer .getPort ()))
163+ .connectionConcurrency (3 ).build ();
164+
165+ requests = new ArrayList <>();
166+ for (int i = 0 ; i < entitiesRespCount / 2 ; i ++) {
167+ requests .add (this .api .createRosetteRequest (AbstractRosetteAPI .ENTITIES_SERVICE_PATH , entitiesRequestData , EntitiesResponse .class ));
168+ }
169+ for (int i = 0 ; i < entitiesRespCount / 2 ; i ++) {
170+ requests .add (this .api .createRosetteRequest (AbstractRosetteAPI .LANGUAGE_SERVICE_PATH , languageRequestData , LanguageResponse .class ));
171+ }
172+
173+
174+ d1 = new Date ();
175+ responses = threadPool .invokeAll (requests );
176+ for (int i = 0 ; i < responses .size (); i ++) {
177+ responses .get (i ).get ();
178+ }
179+ d2 = new Date ();
180+
181+ assertTrue (d2 .getTime () - d1 .getTime () < delay * requests .size ()); // less than serial requests
182+ assertTrue (d2 .getTime () - d1 .getTime () > requests .size () / concurrency * delay ); // running faster than this would suggest it exceeds the maximum concurrency
183+ }
184+
185+ @ AfterEach
186+ void after () throws IOException {
187+ this .api .close ();
188+ }
189+
190+ }
0 commit comments