Skip to content

Commit ba9907f

Browse files
committed
Migrate from built-in json to faster orjson
orjson is significantly faster than the standard-library json module at both parsing and serialization; see the benchmarks: https://github.com/ijl/orjson?tab=readme-ov-file#performance
1 parent 6ea25a3 commit ba9907f

File tree

5 files changed

+20
-19
lines changed

5 files changed

+20
-19
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
python_requires=">=3.9",
8585
install_requires=[
8686
"lz4",
87+
"orjson",
8788
"python-dateutil",
8889
"pytz",
8990
# requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q

tests/unit/oauth_test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12-
import json
1312
import re
1413
import uuid
1514
from collections import namedtuple
1615

1716
import httpretty
17+
import orjson
1818

1919
from trino import constants
2020

@@ -51,7 +51,7 @@ def __init__(self, redirect_server, token_server, tokens, sample_post_response_d
5151
def __call__(self, request, uri, response_headers):
5252
authorization = request.headers.get("Authorization")
5353
if authorization and authorization.replace("Bearer ", "") in self.tokens:
54-
return [200, response_headers, json.dumps(self.sample_post_response_data)]
54+
return [200, response_headers, orjson.dumps(self.sample_post_response_data)]
5555
elif self.redirect_server is None and self.token_server is not None:
5656
return [401,
5757
{
@@ -127,7 +127,7 @@ def post_statement_callback(self, request, uri, response_headers):
127127
authorization = request.headers.get("Authorization")
128128

129129
if authorization and authorization.replace("Bearer ", "") in self.tokens:
130-
return [200, response_headers, json.dumps(self.sample_post_response_data)]
130+
return [200, response_headers, orjson.dumps(self.sample_post_response_data)]
131131

132132
challenge_id = str(uuid.uuid4())
133133
token = str(uuid.uuid4())

tests/unit/test_client.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12-
import json
1312
import threading
1413
import time
1514
import urllib
@@ -26,6 +25,7 @@
2625
import gssapi
2726
import httpretty
2827
import keyring
28+
import orjson
2929
import pytest
3030
import requests
3131
from httpretty import httprettified
@@ -61,8 +61,7 @@
6161

6262
@mock.patch("trino.client.TrinoRequest.http")
6363
def test_trino_initial_request(mock_requests, sample_post_response_data):
64-
mock_requests.Response.return_value.json.return_value = sample_post_response_data
65-
64+
mock_requests.Response.return_value.text = orjson.dumps(sample_post_response_data)
6665
req = TrinoRequest(
6766
host="coordinator",
6867
port=8080,
@@ -558,7 +557,7 @@ def test_oauth2_header_parsing(header, sample_post_response_data):
558557
def post_statement(request, uri, response_headers):
559558
authorization = request.headers.get("Authorization")
560559
if authorization and authorization.replace("Bearer ", "") in token:
561-
return [200, response_headers, json.dumps(sample_post_response_data)]
560+
return [200, response_headers, orjson.dumps(sample_post_response_data)]
562561
return [401, {'Www-Authenticate': header.format(redirect_server=redirect_server, token_server=token_server),
563562
'Basic realm': '"Trino"'}, ""]
564563

@@ -692,7 +691,7 @@ def run(self) -> None:
692691

693692
@mock.patch("trino.client.TrinoRequest.http")
694693
def test_trino_fetch_request(mock_requests, sample_get_response_data):
695-
mock_requests.Response.return_value.json.return_value = sample_get_response_data
694+
mock_requests.Response.return_value.text = orjson.dumps(sample_get_response_data)
696695

697696
req = TrinoRequest(
698697
host="coordinator",
@@ -718,7 +717,7 @@ def test_trino_fetch_request(mock_requests, sample_get_response_data):
718717

719718
@mock.patch("trino.client.TrinoRequest.http")
720719
def test_trino_fetch_request_data_none(mock_requests, sample_get_response_data_none):
721-
mock_requests.Response.return_value.json.return_value = sample_get_response_data_none
720+
mock_requests.Response.return_value.text = orjson.dumps(sample_get_response_data_none)
722721

723722
req = TrinoRequest(
724723
host="coordinator",
@@ -744,7 +743,7 @@ def test_trino_fetch_request_data_none(mock_requests, sample_get_response_data_n
744743

745744
@mock.patch("trino.client.TrinoRequest.http")
746745
def test_trino_fetch_error(mock_requests, sample_get_error_response_data):
747-
mock_requests.Response.return_value.json.return_value = sample_get_error_response_data
746+
mock_requests.Response.return_value.text = orjson.dumps(sample_get_error_response_data)
748747

749748
req = TrinoRequest(
750749
host="coordinator",
@@ -1154,8 +1153,9 @@ def headers(self):
11541153
'X-Trino-Fake-2': 'two',
11551154
}
11561155

1157-
def json(self):
1158-
return sample_get_response_data
1156+
@property
1157+
def text(self):
1158+
return orjson.dumps(sample_get_response_data)
11591159

11601160
req = TrinoRequest(
11611161
host="coordinator",

trino/auth.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# limitations under the License.
1212
import abc
1313
import importlib
14-
import json
1514
import os
1615
import re
1716
import threading
@@ -25,6 +24,7 @@
2524
from typing import Tuple
2625
from urllib.parse import urlparse
2726

27+
import orjson
2828
from requests import PreparedRequest
2929
from requests import Request
3030
from requests import Response
@@ -368,7 +368,7 @@ def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
368368
password = self._keyring.get_password(key, "token")
369369

370370
try:
371-
password_as_dict = json.loads(str(password))
371+
password_as_dict = orjson.loads(str(password))
372372
if password_as_dict.get("sharded_password"):
373373
# if password was stored sharded, reconstruct it
374374
shard_count = int(password_as_dict.get("shard_count"))
@@ -404,7 +404,7 @@ def store_token_to_cache(self, key: Optional[str], token: str) -> None:
404404
}
405405

406406
# store the "shard info" as the "base" password
407-
self._keyring.set_password(key, "token", json.dumps(shard_info))
407+
self._keyring.set_password(key, "token", orjson.dumps(shard_info))
408408
# then store all shards with the shard number as postfix
409409
for i, s in enumerate(password_shards):
410410
self._keyring.set_password(key, f"token__{i}", s)
@@ -521,7 +521,7 @@ def _get_token(self, token_server: str, response: Response, **kwargs: Any) -> st
521521
with response.connection.send(Request(
522522
method='GET', url=token_server).prepare(), **kwargs) as response:
523523
if response.status_code == 200:
524-
token_response = json.loads(response.text)
524+
token_response = orjson.loads(response.text)
525525
token = token_response.get('token')
526526
if token:
527527
return token

trino/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import base64
4040
import copy
4141
import functools
42-
import json
4342
import os
4443
import random
4544
import re
@@ -66,6 +65,7 @@
6665
from zoneinfo import ZoneInfo
6766

6867
import lz4.block
68+
import orjson
6969
import requests
7070
import zstandard
7171
from requests import Response
@@ -687,7 +687,7 @@ def process(self, http_response: Response) -> TrinoStatus:
687687
self.raise_response_error(http_response)
688688

689689
http_response.encoding = "utf-8"
690-
response = http_response.json()
690+
response = orjson.loads(http_response.text)
691691
if "error" in response and response["error"]:
692692
raise self._process_error(response["error"], response.get("id"))
693693

@@ -1285,7 +1285,7 @@ def __init__(self, mapper: RowMapper) -> None:
12851285
self._mapper = mapper
12861286

12871287
def decode(self, data: bytes, metadata: Dict[str, Any]) -> List[List[Any]]:
1288-
return self._mapper.map(json.loads(data.decode("utf8")))
1288+
return self._mapper.map(orjson.loads(data.decode("utf8")))
12891289

12901290

12911291
class CompressedQueryDataDecoder(QueryDataDecoder):

0 commit comments

Comments
 (0)