Skip to content

Commit 81184b5

Browse files
mdesmetebyhr
authored andcommitted
Support float("inf") and float("nan") with experimental_python_types
1 parent 939cab0 commit 81184b5

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -614,9 +614,8 @@ def test_float_query_param(trino_connection):
614614
assert rows[0][0] == 1.1
615615

616616

617-
@pytest.mark.skip(reason="Nan currently not returning the correct python type for nan")
618617
def test_float_nan_query_param(trino_connection):
619-
cur = trino_connection.cursor()
618+
cur = trino_connection.cursor(experimental_python_types=True)
620619
cur.execute("SELECT ?", params=(float("nan"),))
621620
rows = cur.fetchall()
622621

@@ -625,15 +624,14 @@ def test_float_nan_query_param(trino_connection):
625624
assert math.isnan(rows[0][0])
626625

627626

628-
@pytest.mark.skip(reason="Nan currently not returning the correct python type fon inf")
629627
def test_float_inf_query_param(trino_connection):
630-
cur = trino_connection.cursor()
628+
cur = trino_connection.cursor(experimental_python_types=True)
631629
cur.execute("SELECT ?", params=(float("inf"),))
632630
rows = cur.fetchall()
633631

634632
assert rows[0][0] == float("inf")
635633

636-
cur.execute("SELECT ?", params=(-float("-inf"),))
634+
cur.execute("SELECT ?", params=(float("-inf"),))
637635
rows = cur.fetchall()
638636

639637
assert rows[0][0] == float("-inf")

trino/client.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@
6161

6262
_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$')
6363

64+
INF = float("inf")
65+
NEGATIVE_INF = float("-inf")
66+
NAN = float("nan")
67+
6468

6569
class ClientSession(object):
6670
def __init__(
@@ -526,6 +530,14 @@ def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
526530
return [cls._map_to_python_type((array_item, raw_type)) for array_item in value]
527531
elif "decimal" in raw_type:
528532
return Decimal(value)
533+
elif raw_type == "double":
534+
if value == 'Infinity':
535+
return INF
536+
elif value == '-Infinity':
537+
return NEGATIVE_INF
538+
elif value == 'NaN':
539+
return NAN
540+
return value
529541
elif raw_type == "date":
530542
return datetime.strptime(value, "%Y-%m-%d").date()
531543
elif raw_type == "timestamp with time zone":

0 commit comments

Comments
 (0)