
Commit d2cc2b8

SNOW-1846962: remove type conversion when calling a system function (#2737)
1 parent 0488021 commit d2cc2b8

File tree

4 files changed: +171 −3 lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions

@@ -28,6 +28,10 @@
 - Added support for mixed case field names in struct type columns.
 - Added support for `SeriesGroupBy.unique`

+#### Bug Fixes
+
+- Fixed a bug where system functions called through `session.call` had incorrect type conversion applied to their arguments.
+
 #### Improvements
 - Improve performance of `DataFrame.map`, `Series.apply` and `Series.map` methods by mapping numpy functions to snowpark functions if possible.
 - Updated integration testing for `session.lineage.trace` to exclude deleted objects
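
To make the fix concrete, here is a hedged usage sketch of the affected call path. The connection parameters are placeholders and SYSTEM$WAIT is just one example of a system function; the exact cast-wrapped SQL that was generated before the fix is not shown in this commit.

from snowflake.snowpark import Session

# Placeholder connection parameters; substitute real credentials.
session = Session.builder.configs(
    {"account": "<account>", "user": "<user>", "password": "<password>"}
).create()

# Calling a system function through session.call now renders the argument as a
# bare literal, so the generated statement is CALL system$wait(1) rather than a
# cast-wrapped form (the incorrect conversion the changelog entry refers to).
session.call("system$wait", 1)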

src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py

Lines changed: 48 additions & 2 deletions

@@ -63,9 +63,55 @@ def float_nan_inf_to_sql(value: float) -> str:
     return f"{cast_value} :: FLOAT"


-def to_sql(value: Any, datatype: DataType, from_values_statement: bool = False) -> str:
-    """Convert a value with DataType to a snowflake compatible sql"""
+def to_sql_no_cast(
+    value: Any,
+    datatype: DataType,
+) -> str:
+    if value is None:
+        return "NULL"
+    if isinstance(datatype, VariantType):
+        # PARSE_JSON returns VARIANT, so no need to append :: VARIANT here explicitly.
+        return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))})"
+    if isinstance(value, str):
+        if isinstance(datatype, GeographyType):
+            return f"TO_GEOGRAPHY({str_to_sql(value)})"
+        if isinstance(datatype, GeometryType):
+            return f"TO_GEOMETRY({str_to_sql(value)})"
+        return str_to_sql(value)
+    if isinstance(value, float) and (math.isnan(value) or math.isinf(value)):
+        cast_value = float_nan_inf_to_sql(value)
+        return cast_value[:-9]
+    if isinstance(value, (list, bytes, bytearray)) and isinstance(datatype, BinaryType):
+        return str(bytes(value))
+    if isinstance(value, (list, tuple, array)) and isinstance(datatype, ArrayType):
+        return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))})"
+    if isinstance(value, dict) and isinstance(datatype, MapType):
+        return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))})"
+    if isinstance(datatype, DateType):
+        if isinstance(value, int):
+            # add value as number of days to 1970-01-01
+            target_date = date(1970, 1, 1) + timedelta(days=value)
+            return f"'{target_date.isoformat()}'"
+        elif isinstance(value, date):
+            return f"'{value.isoformat()}'"

+    if isinstance(datatype, TimestampType):
+        if isinstance(value, (int, datetime)):
+            if isinstance(value, int):
+                # add value as microseconds to 1970-01-01 00:00:00.00.
+                value = datetime(1970, 1, 1, tzinfo=timezone.utc) + timedelta(
+                    microseconds=value
+                )
+            return f"'{value}'"
+    return f"{value}"
+
+
+def to_sql(
+    value: Any,
+    datatype: DataType,
+    from_values_statement: bool = False,
+) -> str:
+    """Convert a value with DataType to a snowflake compatible sql"""
     # Handle null values
     if isinstance(
         datatype,

src/snowflake/snowpark/_internal/udf_utils.py

Lines changed: 3 additions & 1 deletion

@@ -29,7 +29,7 @@
 import snowflake.snowpark
 from snowflake.connector.options import installed_pandas, pandas
 from snowflake.snowpark._internal import code_generation, type_utils
-from snowflake.snowpark._internal.analyzer.datatype_mapper import to_sql
+from snowflake.snowpark._internal.analyzer.datatype_mapper import to_sql, to_sql_no_cast
 from snowflake.snowpark._internal.telemetry import TelemetryField
 from snowflake.snowpark._internal.type_utils import (
     NoneType,
@@ -1481,6 +1481,8 @@ def generate_call_python_sp_sql(
     for arg in args:
         if isinstance(arg, snowflake.snowpark.Column):
             sql_args.append(session._analyzer.analyze(arg._expression, {}))
+        elif "system$" in sproc_name.lower():
+            sql_args.append(to_sql_no_cast(arg, infer_type(arg)))
         else:
             sql_args.append(to_sql(arg, infer_type(arg)))
     return f"CALL {sproc_name}({', '.join(sql_args)})"

tests/unit/test_datatype_mapper.py

Lines changed: 116 additions & 0 deletions

@@ -5,14 +5,18 @@

 import datetime
 from decimal import Decimal
+from unittest.mock import MagicMock

 import pytest

+from snowflake.snowpark import Session
 from snowflake.snowpark._internal.analyzer.datatype_mapper import (
     numeric_to_sql_without_cast,
     schema_expression,
     to_sql,
+    to_sql_no_cast,
 )
+from snowflake.snowpark._internal.udf_utils import generate_call_python_sp_sql
 from snowflake.snowpark.types import (
     ArrayType,
     BinaryType,
@@ -156,6 +160,118 @@ def test_to_sql():
     )


+def test_to_sql_system_function():
+    # Test nulls
+    assert to_sql_no_cast(None, NullType()) == "NULL"
+    assert to_sql_no_cast(None, ArrayType(DoubleType())) == "NULL"
+    assert to_sql_no_cast(None, MapType(IntegerType(), ByteType())) == "NULL"
+    assert to_sql_no_cast(None, StructType([])) == "NULL"
+    assert to_sql_no_cast(None, GeographyType()) == "NULL"
+    assert to_sql_no_cast(None, GeometryType()) == "NULL"
+
+    assert to_sql_no_cast(None, IntegerType()) == "NULL"
+    assert to_sql_no_cast(None, ShortType()) == "NULL"
+    assert to_sql_no_cast(None, ByteType()) == "NULL"
+    assert to_sql_no_cast(None, LongType()) == "NULL"
+    assert to_sql_no_cast(None, FloatType()) == "NULL"
+    assert to_sql_no_cast(None, StringType()) == "NULL"
+    assert to_sql_no_cast(None, DoubleType()) == "NULL"
+    assert to_sql_no_cast(None, BooleanType()) == "NULL"
+
+    assert to_sql_no_cast(None, "Not any of the previous types") == "NULL"
+
+    # Test non-nulls
+    assert (
+        to_sql_no_cast("\\ ' ' abc \n \\", StringType())
+        == "'\\\\ '' '' abc \\n \\\\'"
+    )
+    assert (
+        to_sql_no_cast("\\ ' ' abc \n \\", StringType())
+        == "'\\\\ '' '' abc \\n \\\\'"
+    )
+    assert to_sql_no_cast(1, ByteType()) == "1"
+    assert to_sql_no_cast(1, ShortType()) == "1"
+    assert to_sql_no_cast(1, IntegerType()) == "1"
+    assert to_sql_no_cast(1, LongType()) == "1"
+    assert to_sql_no_cast(1, BooleanType()) == "1"
+    assert to_sql_no_cast(0, ByteType()) == "0"
+    assert to_sql_no_cast(0, ShortType()) == "0"
+    assert to_sql_no_cast(0, IntegerType()) == "0"
+    assert to_sql_no_cast(0, LongType()) == "0"
+    assert to_sql_no_cast(0, BooleanType()) == "0"
+
+    assert to_sql_no_cast(float("nan"), FloatType()) == "'NAN'"
+    assert to_sql_no_cast(float("inf"), FloatType()) == "'INF'"
+    assert to_sql_no_cast(float("-inf"), FloatType()) == "'-INF'"
+    assert to_sql_no_cast(1.2, FloatType()) == "1.2"
+
+    assert to_sql_no_cast(float("nan"), DoubleType()) == "'NAN'"
+    assert to_sql_no_cast(float("inf"), DoubleType()) == "'INF'"
+    assert to_sql_no_cast(float("-inf"), DoubleType()) == "'-INF'"
+    assert to_sql_no_cast(1.2, DoubleType()) == "1.2"
+
+    assert to_sql_no_cast(Decimal(0.5), DecimalType(2, 1)) == "0.5"
+
+    assert to_sql_no_cast(397, DateType()) == "'1971-02-02'"
+
+    assert to_sql_no_cast(datetime.date(1971, 2, 2), DateType()) == "'1971-02-02'"
+
+    assert (
+        to_sql_no_cast(1622002533000000, TimestampType())
+        == "'2021-05-26 04:15:33+00:00'"
+    )
+
+    assert (
+        to_sql_no_cast(bytearray.fromhex("2Ef0 F1f2 "), BinaryType())
+        == "b'.\\xf0\\xf1\\xf2'"
+    )
+
+    assert to_sql_no_cast([1, "2", 3.5], ArrayType()) == "PARSE_JSON('[1, \"2\", 3.5]')"
+    assert to_sql_no_cast({"'": '"'}, MapType()) == 'PARSE_JSON(\'{"\'\'": "\\\\""}\')'
+    assert to_sql_no_cast([{1: 2}], ArrayType()) == "PARSE_JSON('[{\"1\": 2}]')"
+    assert to_sql_no_cast({1: [2]}, MapType()) == "PARSE_JSON('{\"1\": [2]}')"
+
+    assert to_sql_no_cast([1, bytearray(1)], ArrayType()) == "PARSE_JSON('[1, \"00\"]')"
+
+    assert (
+        to_sql_no_cast(["2", Decimal(0.5)], ArrayType()) == "PARSE_JSON('[\"2\", 0.5]')"
+    )
+
+    dt = datetime.datetime.today()
+    assert (
+        to_sql_no_cast({1: dt}, MapType())
+        == 'PARSE_JSON(\'{"1": "' + dt.isoformat() + "\"}')"
+    )
+
+    assert to_sql_no_cast([1, 2, 3.5], VectorType(float, 3)) == "[1, 2, 3.5]"
+    assert (
+        to_sql_no_cast("POINT(-122.35 37.55)", GeographyType())
+        == "TO_GEOGRAPHY('POINT(-122.35 37.55)')"
+    )
+    assert (
+        to_sql_no_cast("POINT(-122.35 37.55)", GeometryType())
+        == "TO_GEOMETRY('POINT(-122.35 37.55)')"
+    )
+    assert to_sql_no_cast("1", VariantType()) == "PARSE_JSON('\"1\"')"
+    assert (
+        to_sql_no_cast([1, 2, 3.5, 4.1234567, -3.8], VectorType("float", 5))
+        == "[1, 2, 3.5, 4.1234567, -3.8]"
+    )
+    assert to_sql_no_cast([1, 2, 3], VectorType(int, 3)) == "[1, 2, 3]"
+    assert (
+        to_sql_no_cast([1, 2, 31234567, -1928, 0, -3], VectorType(int, 5))
+        == "[1, 2, 31234567, -1928, 0, -3]"
+    )
+
+
+def test_generate_call_python_sp_sql():
+    fake_session = MagicMock(Session)
+    assert (
+        generate_call_python_sp_sql(fake_session, "system$wait", 1)
+        == "CALL system$wait(1)"
+    )
+
+
 @pytest.mark.parametrize(
     "timezone, expected",
     [
