Skip to content

Commit 35ba194

Browse files
mdesmetebyhr
authored andcommitted
Configure experimental_python_types through dbapi.Connection
1 parent ed2203f commit 35ba194

File tree

6 files changed

+74
-22
lines changed

6 files changed

+74
-22
lines changed

README.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,17 @@ engine = create_engine(
9999
'trino://user@localhost:8080/system',
100100
connect_args={
101101
"session_properties": {'query_max_run_time': '1d'},
102-
"client_tags": ["tag1", "tag2"]
102+
"client_tags": ["tag1", "tag2"],
103+
"experimental_python_types": True,
103104
}
104105
)
105106

106107
# or in connection string
107108
engine = create_engine(
108109
'trino://user@localhost:8080/system?'
109110
'session_properties={"query_max_run_time": "1d"}'
110-
'&client_tags=["tag1", "tag2"]',
111+
'&client_tags=["tag1", "tag2"]'
112+
'&experimental_python_types=true',
111113
)
112114
```
113115

@@ -361,10 +363,11 @@ import pytz
361363
from datetime import datetime
362364

363365
conn = trino.dbapi.connect(
366+
experimental_python_types=True,
364367
...
365368
)
366369

367-
cur = conn.cursor(experimental_python_types=True)
370+
cur = conn.cursor()
368371

369372
params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles'))
370373

tests/integration/test_dbapi_integration.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,34 @@ def test_execute_many_select(trino_connection):
166166
assert "Query must return update type" in str(e.value)
167167

168168

169-
def test_python_types_not_used_when_experimental_python_types_is_not_set(trino_connection):
170-
cur = trino_connection.cursor()
169+
@pytest.mark.parametrize("connection_experimental_python_types,cursor_experimental_python_types,expected",
170+
[
171+
(None, None, False),
172+
(None, False, False),
173+
(None, True, True),
174+
(False, None, False),
175+
(False, False, False),
176+
(False, True, True),
177+
(True, None, True),
178+
(True, False, False),
179+
(True, True, True),
180+
])
181+
def test_experimental_python_types_with_connection_and_cursor(
182+
connection_experimental_python_types,
183+
cursor_experimental_python_types,
184+
expected,
185+
run_trino
186+
):
187+
_, host, port = run_trino
188+
189+
connection = trino.dbapi.Connection(
190+
host=host,
191+
port=port,
192+
user="test",
193+
experimental_python_types=connection_experimental_python_types,
194+
)
195+
196+
cur = connection.cursor(experimental_python_types=cursor_experimental_python_types)
171197

172198
cur.execute("""
173199
SELECT
@@ -180,15 +206,23 @@ def test_python_types_not_used_when_experimental_python_types_is_not_set(trino_c
180206
""")
181207
rows = cur.fetchall()
182208

183-
for value in rows[0]:
184-
assert isinstance(value, str)
185-
186-
assert rows[0][0] == '0.142857'
187-
assert rows[0][1] == '2018-01-01'
188-
assert rows[0][2] == '2019-01-01 00:00:00.000 +01:00'
189-
assert rows[0][3] == '2019-01-01 00:00:00.000 UTC'
190-
assert rows[0][4] == '2019-01-01 00:00:00.000'
191-
assert rows[0][5] == '00:00:00.000'
209+
if expected:
210+
assert rows[0][0] == Decimal('0.142857')
211+
assert rows[0][1] == date(2018, 1, 1)
212+
assert rows[0][2] == datetime(2019, 1, 1, tzinfo=timezone(timedelta(hours=1)))
213+
assert rows[0][3] == datetime(2019, 1, 1, tzinfo=pytz.timezone('UTC'))
214+
assert rows[0][4] == datetime(2019, 1, 1)
215+
assert rows[0][5] == time(0, 0, 0, 0)
216+
else:
217+
for value in rows[0]:
218+
assert isinstance(value, str)
219+
220+
assert rows[0][0] == '0.142857'
221+
assert rows[0][1] == '2018-01-01'
222+
assert rows[0][2] == '2019-01-01 00:00:00.000 +01:00'
223+
assert rows[0][3] == '2019-01-01 00:00:00.000 UTC'
224+
assert rows[0][4] == '2019-01-01 00:00:00.000'
225+
assert rows[0][5] == '00:00:00.000'
192226

193227

194228
def test_decimal_query_param(trino_connection):

tests/unit/sqlalchemy/test_dialect.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def setup(self):
4747
'session_properties={"query_max_run_time": "1d"}'
4848
'&http_headers={"trino": 1}'
4949
'&extra_credential=[("a", "b"), ("c", "d")]'
50-
'&client_tags=[1, "sql"]'),
50+
'&client_tags=[1, "sql"]'
51+
'&experimental_python_types=true'),
5152
list(),
5253
dict(
5354
host="localhost",
@@ -58,7 +59,8 @@ def setup(self):
5859
session_properties={"query_max_run_time": "1d"},
5960
http_headers={"trino": 1},
6061
extra_credential=[("a", "b"), ("c", "d")],
61-
client_tags=[1, "sql"]
62+
client_tags=[1, "sql"],
63+
experimental_python_types=True,
6264
),
6365
),
6466
],

trino/dbapi.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ def __init__(
109109
isolation_level=IsolationLevel.AUTOCOMMIT,
110110
verify=True,
111111
http_session=None,
112-
client_tags=None
112+
client_tags=None,
113+
experimental_python_types=False,
113114
):
114115
self.host = host
115116
self.port = port
@@ -136,6 +137,7 @@ def __init__(
136137
self._isolation_level = isolation_level
137138
self._request = None
138139
self._transaction = None
140+
self.experimental_python_types = experimental_python_types
139141

140142
@property
141143
def isolation_level(self):
@@ -199,17 +201,21 @@ def _create_request(self):
199201
client_tags=self.client_tags
200202
)
201203

202-
def cursor(self, experimental_python_types=False):
204+
def cursor(self, experimental_python_types: bool = None):
203205
"""Return a new :py:class:`Cursor` object using the connection."""
204206
if self.isolation_level != IsolationLevel.AUTOCOMMIT:
205207
if self.transaction is None:
206208
self.start_transaction()
207-
request = self.transaction._request
208-
elif self.transaction is not None:
209-
request = self.transaction._request
209+
if self.transaction is not None:
210+
request = self.transaction.request
210211
else:
211212
request = self._create_request()
212-
return Cursor(self, request, experimental_python_types)
213+
return Cursor(
214+
self,
215+
request,
216+
# if experimental_python_types is not explicitly set in Cursor, take from Connection
217+
experimental_python_types if experimental_python_types is not None else self.experimental_python_types
218+
)
213219

214220

215221
class Cursor(object):

trino/sqlalchemy/dialect.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any
122122
if "client_tags" in url.query:
123123
kwargs["client_tags"] = json.loads(url.query["client_tags"])
124124

125+
if "experimental_python_types" in url.query:
126+
kwargs["experimental_python_types"] = json.loads(url.query["experimental_python_types"])
127+
125128
return args, kwargs
126129

127130
def get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:

trino/transaction.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ def __init__(self, request):
5959
def id(self):
6060
return self._id
6161

62+
@property
63+
def request(self):
64+
return self._request
65+
6266
def begin(self):
6367
response = self._request.post(START_TRANSACTION)
6468
if not response.ok:

0 commit comments

Comments
 (0)