|
5 | 5 |
|
6 | 6 | import datetime |
7 | 7 | from decimal import Decimal |
| 8 | +from unittest.mock import MagicMock |
8 | 9 |
|
9 | 10 | import pytest |
10 | 11 |
|
| 12 | +from snowflake.snowpark import Session |
11 | 13 | from snowflake.snowpark._internal.analyzer.datatype_mapper import ( |
12 | 14 | numeric_to_sql_without_cast, |
13 | 15 | schema_expression, |
14 | 16 | to_sql, |
| 17 | + to_sql_no_cast, |
15 | 18 | ) |
| 19 | +from snowflake.snowpark._internal.udf_utils import generate_call_python_sp_sql |
16 | 20 | from snowflake.snowpark.types import ( |
17 | 21 | ArrayType, |
18 | 22 | BinaryType, |
@@ -156,6 +160,118 @@ def test_to_sql(): |
156 | 160 | ) |
157 | 161 |
|
158 | 162 |
|
| 163 | +def test_to_sql_system_function(): |
| 164 | + # Test nulls |
| 165 | + assert to_sql_no_cast(None, NullType()) == "NULL" |
| 166 | + assert to_sql_no_cast(None, ArrayType(DoubleType())) == "NULL" |
| 167 | + assert to_sql_no_cast(None, MapType(IntegerType(), ByteType())) == "NULL" |
| 168 | + assert to_sql_no_cast(None, StructType([])) == "NULL" |
| 169 | + assert to_sql_no_cast(None, GeographyType()) == "NULL" |
| 170 | + assert to_sql_no_cast(None, GeometryType()) == "NULL" |
| 171 | + |
| 172 | + assert to_sql_no_cast(None, IntegerType()) == "NULL" |
| 173 | + assert to_sql_no_cast(None, ShortType()) == "NULL" |
| 174 | + assert to_sql_no_cast(None, ByteType()) == "NULL" |
| 175 | + assert to_sql_no_cast(None, LongType()) == "NULL" |
| 176 | + assert to_sql_no_cast(None, FloatType()) == "NULL" |
| 177 | + assert to_sql_no_cast(None, StringType()) == "NULL" |
| 178 | + assert to_sql_no_cast(None, DoubleType()) == "NULL" |
| 179 | + assert to_sql_no_cast(None, BooleanType()) == "NULL" |
| 180 | + |
| 181 | + assert to_sql_no_cast(None, "Not any of the previous types") == "NULL" |
| 182 | + |
| 183 | + # Test non-nulls |
| 184 | + assert ( |
| 185 | + to_sql_no_cast("\\ ' ' abc \n \\", StringType()) |
| 186 | + == "'\\\\ '' '' abc \\n \\\\'" |
| 187 | + ) |
| 188 | + assert ( |
| 189 | + to_sql_no_cast("\\ ' ' abc \n \\", StringType()) |
| 190 | + == "'\\\\ '' '' abc \\n \\\\'" |
| 191 | + ) |
| 192 | + assert to_sql_no_cast(1, ByteType()) == "1" |
| 193 | + assert to_sql_no_cast(1, ShortType()) == "1" |
| 194 | + assert to_sql_no_cast(1, IntegerType()) == "1" |
| 195 | + assert to_sql_no_cast(1, LongType()) == "1" |
| 196 | + assert to_sql_no_cast(1, BooleanType()) == "1" |
| 197 | + assert to_sql_no_cast(0, ByteType()) == "0" |
| 198 | + assert to_sql_no_cast(0, ShortType()) == "0" |
| 199 | + assert to_sql_no_cast(0, IntegerType()) == "0" |
| 200 | + assert to_sql_no_cast(0, LongType()) == "0" |
| 201 | + assert to_sql_no_cast(0, BooleanType()) == "0" |
| 202 | + |
| 203 | + assert to_sql_no_cast(float("nan"), FloatType()) == "'NAN'" |
| 204 | + assert to_sql_no_cast(float("inf"), FloatType()) == "'INF'" |
| 205 | + assert to_sql_no_cast(float("-inf"), FloatType()) == "'-INF'" |
| 206 | + assert to_sql_no_cast(1.2, FloatType()) == "1.2" |
| 207 | + |
| 208 | + assert to_sql_no_cast(float("nan"), DoubleType()) == "'NAN'" |
| 209 | + assert to_sql_no_cast(float("inf"), DoubleType()) == "'INF'" |
| 210 | + assert to_sql_no_cast(float("-inf"), DoubleType()) == "'-INF'" |
| 211 | + assert to_sql_no_cast(1.2, DoubleType()) == "1.2" |
| 212 | + |
| 213 | + assert to_sql_no_cast(Decimal(0.5), DecimalType(2, 1)) == "0.5" |
| 214 | + |
| 215 | + assert to_sql_no_cast(397, DateType()) == "'1971-02-02'" |
| 216 | + |
| 217 | + assert to_sql_no_cast(datetime.date(1971, 2, 2), DateType()) == "'1971-02-02'" |
| 218 | + |
| 219 | + assert ( |
| 220 | + to_sql_no_cast(1622002533000000, TimestampType()) |
| 221 | + == "'2021-05-26 04:15:33+00:00'" |
| 222 | + ) |
| 223 | + |
| 224 | + assert ( |
| 225 | + to_sql_no_cast(bytearray.fromhex("2Ef0 F1f2 "), BinaryType()) |
| 226 | + == "b'.\\xf0\\xf1\\xf2'" |
| 227 | + ) |
| 228 | + |
| 229 | + assert to_sql_no_cast([1, "2", 3.5], ArrayType()) == "PARSE_JSON('[1, \"2\", 3.5]')" |
| 230 | + assert to_sql_no_cast({"'": '"'}, MapType()) == 'PARSE_JSON(\'{"\'\'": "\\\\""}\')' |
| 231 | + assert to_sql_no_cast([{1: 2}], ArrayType()) == "PARSE_JSON('[{\"1\": 2}]')" |
| 232 | + assert to_sql_no_cast({1: [2]}, MapType()) == "PARSE_JSON('{\"1\": [2]}')" |
| 233 | + |
| 234 | + assert to_sql_no_cast([1, bytearray(1)], ArrayType()) == "PARSE_JSON('[1, \"00\"]')" |
| 235 | + |
| 236 | + assert ( |
| 237 | + to_sql_no_cast(["2", Decimal(0.5)], ArrayType()) == "PARSE_JSON('[\"2\", 0.5]')" |
| 238 | + ) |
| 239 | + |
| 240 | + dt = datetime.datetime.today() |
| 241 | + assert ( |
| 242 | + to_sql_no_cast({1: dt}, MapType()) |
| 243 | + == 'PARSE_JSON(\'{"1": "' + dt.isoformat() + "\"}')" |
| 244 | + ) |
| 245 | + |
| 246 | + assert to_sql_no_cast([1, 2, 3.5], VectorType(float, 3)) == "[1, 2, 3.5]" |
| 247 | + assert ( |
| 248 | + to_sql_no_cast("POINT(-122.35 37.55)", GeographyType()) |
| 249 | + == "TO_GEOGRAPHY('POINT(-122.35 37.55)')" |
| 250 | + ) |
| 251 | + assert ( |
| 252 | + to_sql_no_cast("POINT(-122.35 37.55)", GeometryType()) |
| 253 | + == "TO_GEOMETRY('POINT(-122.35 37.55)')" |
| 254 | + ) |
| 255 | + assert to_sql_no_cast("1", VariantType()) == "PARSE_JSON('\"1\"')" |
| 256 | + assert ( |
| 257 | + to_sql_no_cast([1, 2, 3.5, 4.1234567, -3.8], VectorType("float", 5)) |
| 258 | + == "[1, 2, 3.5, 4.1234567, -3.8]" |
| 259 | + ) |
| 260 | + assert to_sql_no_cast([1, 2, 3], VectorType(int, 3)) == "[1, 2, 3]" |
| 261 | + assert ( |
| 262 | + to_sql_no_cast([1, 2, 31234567, -1928, 0, -3], VectorType(int, 5)) |
| 263 | + == "[1, 2, 31234567, -1928, 0, -3]" |
| 264 | + ) |
| 265 | + |
| 266 | + |
| 267 | +def test_generate_call_python_sp_sql(): |
| 268 | + fake_session = MagicMock(Session) |
| 269 | + assert ( |
| 270 | + generate_call_python_sp_sql(fake_session, "system$wait", 1) |
| 271 | + == "CALL system$wait(1)" |
| 272 | + ) |
| 273 | + |
| 274 | + |
159 | 275 | @pytest.mark.parametrize( |
160 | 276 | "timezone, expected", |
161 | 277 | [ |
|
0 commit comments