Skip to content

Commit f725954

Browse files
author
Chris Park
committed
Merge branch 'develop'
2 parents 2795ca0 + 4e654ed commit f725954

File tree

4 files changed

+101
-44
lines changed

4 files changed

+101
-44
lines changed

examples/text_embedding.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
Example code to call Rosette API to get text vectors from a piece of text.
5+
"""
6+
7+
import argparse
8+
import json
9+
import os
10+
11+
from rosette.api import API, DocumentParameters
12+
13+
14+
def run(key, altUrl='https://api.rosette.com/rest/v1/'):
    """Send a fixed sample string to the text-embedding endpoint.

    @param key: Rosette API user key
    @param altUrl: service URL handed to the API client
    @return: the value returned by C{API.text_embedding} for the sample text
    """
    # Build the client for the requested service URL
    client = API(user_key=key, service_url=altUrl)

    # Wrap the sample text in a document parameter object
    sample_text = "Cambridge, Massachusetts"
    document = DocumentParameters()
    document["content"] = sample_text

    return client.text_embedding(document)
21+
22+
# Command-line interface: the endpoint name shown in the help text is derived
# from this script's own file name.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description='Calls the ' + os.path.splitext(os.path.basename(__file__))[0] + ' endpoint',
)
parser.add_argument(
    '-k', '--key',
    help='Rosette API Key',
    required=True,
)
parser.add_argument(
    '-u', '--url',
    help="Alternative API URL",
    default='https://api.rosette.com/rest/v1/',
)


if __name__ == '__main__':
    args = parser.parse_args()
    result = run(args.key, args.url)
    # Pretty-print the response as UTF-8 encoded JSON with sorted keys
    print(json.dumps(result, indent=2, ensure_ascii=False, sort_keys=True).encode("utf8"))

rosette/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
limitations under the License.
1717
"""
1818

19-
__version__ = '1.2.0'
19+
__version__ = '1.3.0'

rosette/api.py

Lines changed: 42 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,16 @@
2424
import sys
2525
import time
2626
import os
27-
from socket import gaierror
2827
import requests
2928
import re
3029
import warnings
3130

32-
_BINDING_VERSION = '1.1.1'
31+
_BINDING_VERSION = '1.3.0'
3332
_GZIP_BYTEARRAY = bytearray([0x1F, 0x8b, 0x08])
3433

3534
_IsPy3 = sys.version_info[0] == 3
3635

3736

38-
try:
39-
import urlparse
40-
except ImportError:
41-
import urllib.parse as urlparse
42-
try:
43-
import httplib
44-
except ImportError:
45-
import http.client as httplib
46-
4737
if _IsPy3:
4838
_GZIP_SIGNATURE = _GZIP_BYTEARRAY
4939
else:
@@ -482,8 +472,8 @@ def call(self, parameters):
482472
'application/json')}
483473
request = requests.Request(
484474
'POST', url, files=files, headers=headers)
485-
prepared_request = request.prepare()
486475
session = requests.Session()
476+
prepared_request = session.prepare_request(request)
487477
resp = session.send(prepared_request)
488478
rdata = resp.content
489479
response_headers = {"responseHeaders": dict(resp.headers)}
@@ -512,7 +502,6 @@ def __init__(
512502
user_key=None,
513503
service_url='https://api.rosette.com/rest/v1/',
514504
retries=5,
515-
reuse_connection=True,
516505
refresh_duration=0.5,
517506
debug=False):
518507
""" Create an L{API} object.
@@ -534,22 +523,18 @@ def __init__(
534523
refresh_duration = 0
535524

536525
self.num_retries = retries
537-
self.reuse_connection = reuse_connection
538526
self.connection_refresh_duration = refresh_duration
539-
self.http_connection = None
540527
self.options = {}
541528
self.customHeaders = {}
529+
self.maxPoolSize = 1
530+
self.session = requests.Session()
542531

543-
def _connect(self, parsedUrl):
544-
""" Simple connection method
545-
@param parsedUrl: The URL on which to process
546-
"""
547-
if not self.reuse_connection or self.http_connection is None:
548-
loc = parsedUrl.netloc
549-
if parsedUrl.scheme == "https":
550-
self.http_connection = httplib.HTTPSConnection(loc)
551-
else:
552-
self.http_connection = httplib.HTTPConnection(loc)
532+
def _set_pool_size(self):
    """ Mount a new HTTP adapter on C{self.session} sized to C{self.maxPoolSize}.

    The adapter is mounted only for the scheme (https or http) that
    C{self.service_url} starts with.
    """
    # maxPoolSize can hold a raw response-header value (a string), so coerce
    # it to an integer before handing it to the adapter.
    adapter = requests.adapters.HTTPAdapter(pool_maxsize=int(self.maxPoolSize))
    # Decide on the scheme by prefix rather than substring, so an "https:"
    # that merely appears later in an http URL does not select the wrong mount.
    if self.service_url.lower().startswith('https:'):
        self.session.mount('https://', adapter)
    else:
        self.session.mount('http://', adapter)
553538

554539
def _make_request(self, op, url, data, headers):
555540
"""
@@ -561,32 +546,34 @@ def _make_request(self, op, url, data, headers):
561546
@param headers: request headers
562547
"""
563548
headers['User-Agent'] = "RosetteAPIPython/" + _BINDING_VERSION
564-
parsedUrl = urlparse.urlparse(url)
565-
566-
self._connect(parsedUrl)
567549

568550
message = None
569551
code = "unknownError"
570552
rdata = None
571553
response_headers = {}
554+
555+
request = requests.Request(op, url, data=data, headers=headers)
556+
session = requests.Session()
557+
prepared_request = session.prepare_request(request)
558+
572559
for i in range(self.num_retries + 1):
573560
try:
574-
self.http_connection.request(op, url, data, headers)
575-
response = self.http_connection.getresponse()
576-
status = response.status
577-
rdata = response.read()
578-
response_headers["responseHeaders"] = (
579-
dict(response.getheaders()))
561+
response = session.send(prepared_request)
562+
status = response.status_code
563+
rdata = response.content
564+
dict_headers = dict(response.headers)
565+
response_headers = {"responseHeaders": dict_headers}
566+
if 'x-rosetteapi-concurrency' in dict_headers:
567+
if dict_headers['x-rosetteapi-concurrency'] != self.maxPoolSize:
568+
self.maxPoolSize = dict_headers['x-rosetteapi-concurrency']
569+
self._set_pool_size()
570+
580571
if status == 200:
581-
if not self.reuse_connection:
582-
self.http_connection.close()
583572
return rdata, status, response_headers
584573
if status == 429:
585574
code = status
586575
message = "{0} ({1})".format(rdata, i)
587576
time.sleep(self.connection_refresh_duration)
588-
self.http_connection.close()
589-
self._connect(parsedUrl)
590577
continue
591578
if rdata is not None:
592579
try:
@@ -598,17 +585,15 @@ def _make_request(self, op, url, data, headers):
598585
else:
599586
code = status
600587
raise RosetteException(code, message, url)
588+
601589
except:
602590
raise
603-
except (httplib.BadStatusLine, gaierror):
591+
except requests.exceptions.RequestException as e:
604592
raise RosetteException(
605-
"ConnectionError",
593+
e.message,
606594
"Unable to establish connection to the Rosette API server",
607595
url)
608596

609-
if not self.reuse_connection:
610-
self.http_connection.close()
611-
612597
raise RosetteException(code, message, url)
613598

614599
def _get_http(self, url, headers):
@@ -644,6 +629,12 @@ def _post_http(self, url, data, headers):
644629

645630
return _ReturnObject(_my_loads(rdata, response_headers), status)
646631

632+
def getPoolSize(self):
    """
    Reports the maximum pool size as an integer; this is the
    x-rosetteapi-concurrency value once one has been returned
    """
    pool_size = int(self.maxPoolSize)
    return pool_size
637+
647638
def setOption(self, name, value):
648639
"""
649640
Sets an option
@@ -842,3 +833,11 @@ def matched_name(self, parameters):
842833
@type parameters: L{NameSimilarityParameters}
843834
@return: A python dictionary containing the results of name matching."""
844835
return self.name_similarity(parameters)
836+
837+
def text_embedding(self, parameters):
    """
    Build an L{EndpointCaller} for the text-embedding endpoint and invoke it
    on the supplied input to obtain its text vectors.
    @type parameters: L{DocumentParameters} or L{str}
    @return: A python dictionary containing the results of text embedding."""
    caller = EndpointCaller(self, "text-embedding")
    return caller.call(parameters)

tests/test_rosette_api.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,24 @@ def test_for_409(api, json_409):
170170
httpretty.disable()
171171
httpretty.reset()
172172

173+
# Test the maxPoolSize
174+
175+
176+
def test_the_max_pool_size(json_response, doc_params):
    """The pool size starts at 1 and follows the x-rosetteapi-concurrency response header."""
    httpretty.enable()
    try:
        httpretty.register_uri(httpretty.POST, "https://api.rosette.com/rest/v1/language",
                               body=json_response, status=200, content_type="application/json",
                               adding_headers={
                                   'x-rosetteapi-concurrency': 5
                               })
        api = API('bogus_key')
        assert api.getPoolSize() == 1
        result = api.language(doc_params)
        assert result["name"] == "Rosette API"
        assert api.getPoolSize() == 5
    finally:
        # Always undo the HTTP interception, even when an assertion fails,
        # so a failure here cannot leak mocked state into later tests.
        httpretty.disable()
        httpretty.reset()
190+
173191
# Test the language endpoint
174192

175193

@@ -561,3 +579,14 @@ def test_for_name_translation_required_parameters(api, json_response):
561579

562580
httpretty.disable()
563581
httpretty.reset()
582+
583+
584+
def test_the_text_embedded_endpoint(api, json_response, doc_params):
    """Calling text_embedding posts to the text-embedding URL and returns the parsed body."""
    httpretty.enable()
    try:
        httpretty.register_uri(httpretty.POST, "https://api.rosette.com/rest/v1/text-embedding",
                               body=json_response, status=200, content_type="application/json")

        result = api.text_embedding(doc_params)
        assert result["name"] == "Rosette API"
    finally:
        # Always undo the HTTP interception, even when an assertion fails,
        # so a failure here cannot leak mocked state into later tests.
        httpretty.disable()
        httpretty.reset()

0 commit comments

Comments
 (0)