Skip to content

Commit c4a26ef

Browse files
committed
Migrate from built-in json to faster orjson
It is way faster than the current json parser: https://github.com/ijl/orjson?tab=readme-ov-file#performance
1 parent 6ea25a3 commit c4a26ef

File tree

3 files changed

+18
-10
lines changed

3 files changed

+18
-10
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
python_requires=">=3.9",
8585
install_requires=[
8686
"lz4",
87+
"orjson >= 3.11.0 ; platform_python_implementation != 'PyPy'",
8788
"python-dateutil",
8889
"pytz",
8990
# requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q

tests/unit/test_client.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12-
import json
1312
import threading
1413
import time
1514
import urllib
@@ -26,6 +25,10 @@
2625
import gssapi
2726
import httpretty
2827
import keyring
28+
try:
29+
import orjson as json
30+
except ImportError:
31+
import json
2932
import pytest
3033
import requests
3134
from httpretty import httprettified
@@ -61,8 +64,7 @@
6164

6265
@mock.patch("trino.client.TrinoRequest.http")
6366
def test_trino_initial_request(mock_requests, sample_post_response_data):
64-
mock_requests.Response.return_value.json.return_value = sample_post_response_data
65-
67+
mock_requests.Response.return_value.text = json.dumps(sample_post_response_data)
6668
req = TrinoRequest(
6769
host="coordinator",
6870
port=8080,
@@ -692,7 +694,7 @@ def run(self) -> None:
692694

693695
@mock.patch("trino.client.TrinoRequest.http")
694696
def test_trino_fetch_request(mock_requests, sample_get_response_data):
695-
mock_requests.Response.return_value.json.return_value = sample_get_response_data
697+
mock_requests.Response.return_value.text = json.dumps(sample_get_response_data)
696698

697699
req = TrinoRequest(
698700
host="coordinator",
@@ -718,7 +720,7 @@ def test_trino_fetch_request(mock_requests, sample_get_response_data):
718720

719721
@mock.patch("trino.client.TrinoRequest.http")
720722
def test_trino_fetch_request_data_none(mock_requests, sample_get_response_data_none):
721-
mock_requests.Response.return_value.json.return_value = sample_get_response_data_none
723+
mock_requests.Response.return_value.text = json.dumps(sample_get_response_data_none)
722724

723725
req = TrinoRequest(
724726
host="coordinator",
@@ -744,7 +746,7 @@ def test_trino_fetch_request_data_none(mock_requests, sample_get_response_data_n
744746

745747
@mock.patch("trino.client.TrinoRequest.http")
746748
def test_trino_fetch_error(mock_requests, sample_get_error_response_data):
747-
mock_requests.Response.return_value.json.return_value = sample_get_error_response_data
749+
mock_requests.Response.return_value.text = json.dumps(sample_get_error_response_data)
748750

749751
req = TrinoRequest(
750752
host="coordinator",
@@ -1154,8 +1156,9 @@ def headers(self):
11541156
'X-Trino-Fake-2': 'two',
11551157
}
11561158

1157-
def json(self):
1158-
return sample_get_response_data
1159+
@property
1160+
def text(self):
1161+
return json.dumps(sample_get_response_data)
11591162

11601163
req = TrinoRequest(
11611164
host="coordinator",

trino/client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import base64
4040
import copy
4141
import functools
42-
import json
4342
import os
4443
import random
4544
import re
@@ -66,6 +65,11 @@
6665
from zoneinfo import ZoneInfo
6766

6867
import lz4.block
68+
try:
69+
import orjson as json
70+
except ImportError:
71+
import json
72+
6973
import requests
7074
import zstandard
7175
from requests import Response
@@ -687,7 +691,7 @@ def process(self, http_response: Response) -> TrinoStatus:
687691
self.raise_response_error(http_response)
688692

689693
http_response.encoding = "utf-8"
690-
response = http_response.json()
694+
response = json.loads(http_response.text)
691695
if "error" in response and response["error"]:
692696
raise self._process_error(response["error"], response.get("id"))
693697

0 commit comments

Comments
 (0)