Skip to content

Commit 5e66f1d

Browse files
authored
SNOW-1332387: Fix test_to_arrow_datatypes (#1926)
1 parent d2d3377 commit 5e66f1d

File tree

1 file changed: 70 additions and 70 deletions

test/integ/pandas/test_arrow_pandas.py

Lines changed: 70 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -1205,35 +1205,58 @@ def test_batch_to_pandas_arrow(conn_cnx, result_format):
12051205
@pytest.mark.internal
12061206
@pytest.mark.parametrize("enable_structured_types", [True, False])
12071207
def test_to_arrow_datatypes(enable_structured_types, conn_cnx):
1208+
expected_types = (
1209+
pyarrow.int64(),
1210+
pyarrow.float64(),
1211+
pyarrow.string(),
1212+
pyarrow.date64(),
1213+
pyarrow.timestamp("ns"),
1214+
pyarrow.string(),
1215+
pyarrow.timestamp("ns"),
1216+
pyarrow.timestamp("ns"),
1217+
pyarrow.timestamp("ns"),
1218+
pyarrow.binary(),
1219+
pyarrow.time64("ns"),
1220+
pyarrow.bool_(),
1221+
pyarrow.string(),
1222+
pyarrow.string(),
1223+
pyarrow.list_(pyarrow.float64(), 5),
1224+
)
1225+
1226+
query = """
1227+
select
1228+
1 :: INTEGER as FIXED_type,
1229+
2.0 :: FLOAT as REAL_type,
1230+
'test' :: TEXT as TEXT_type,
1231+
'2024-02-28' :: DATE as DATE_type,
1232+
'2020-03-12 01:02:03.123456789' :: TIMESTAMP as TIMESTAMP_type,
1233+
'{"foo": "bar"}' :: VARIANT as VARIANT_type,
1234+
'2020-03-12 01:02:03.123456789' :: TIMESTAMP_LTZ as TIMESTAMP_LTZ_type,
1235+
'2020-03-12 01:02:03.123456789' :: TIMESTAMP_TZ as TIMESTAMP_TZ_type,
1236+
'2020-03-12 01:02:03.123456789' :: TIMESTAMP_NTZ as TIMESTAMP_NTZ_type,
1237+
'0xAAAA' :: BINARY as BINARY_type,
1238+
'01:02:03.123456789' :: TIME as TIME_type,
1239+
true :: BOOLEAN as BOOLEAN_type,
1240+
TO_GEOGRAPHY('LINESTRING(13.4814 52.5015, -121.8212 36.8252)') as GEOGRAPHY_type,
1241+
TO_GEOMETRY('LINESTRING(13.4814 52.5015, -121.8212 36.8252)') as GEOMETRY_type,
1242+
[1,2,3,4,5] :: vector(float, 5) as VECTOR_type,
1243+
object_construct('k1', 1, 'k2', 2, 'k3', 3, 'k4', 4, 'k5', 5) :: map(varchar, int) as MAP_type,
1244+
object_construct('city', 'san jose', 'population', 0.05) :: object(city varchar, population float) as OBJECT_type,
1245+
[1.0, 3.1, 4.5] :: array(float) as ARRAY_type
1246+
WHERE 1=0
1247+
"""
1248+
1249+
structured_params = {
1250+
"ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE",
1251+
"IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE",
1252+
"FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT",
1253+
}
1254+
12081255
with conn_cnx() as cnx:
12091256
with cnx.cursor() as cur:
12101257
cur.execute(SQL_ENABLE_ARROW)
1211-
1212-
expected_types = (
1213-
pyarrow.int64(),
1214-
pyarrow.float64(),
1215-
pyarrow.string(),
1216-
pyarrow.date64(),
1217-
pyarrow.timestamp("ns"),
1218-
pyarrow.string(),
1219-
pyarrow.timestamp("ns"),
1220-
pyarrow.timestamp("ns"),
1221-
pyarrow.timestamp("ns"),
1222-
pyarrow.binary(),
1223-
pyarrow.time64("ns"),
1224-
pyarrow.bool_(),
1225-
pyarrow.string(),
1226-
pyarrow.string(),
1227-
pyarrow.list_(pyarrow.float64(), 5),
1228-
)
1229-
1230-
if enable_structured_types:
1231-
structured_params = {
1232-
"ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE",
1233-
"IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE",
1234-
"FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT",
1235-
}
1236-
try:
1258+
try:
1259+
if enable_structured_types:
12371260
for param in structured_params:
12381261
cur.execute(f"alter session set {param}=true")
12391262
expected_types += (
@@ -1243,50 +1266,27 @@ def test_to_arrow_datatypes(enable_structured_types, conn_cnx):
12431266
),
12441267
pyarrow.list_(pyarrow.float64()),
12451268
)
1246-
finally:
1247-
cur.execute(f"alter session unset {param}")
1248-
else:
1249-
expected_types += (
1250-
pyarrow.string(),
1251-
pyarrow.string(),
1252-
pyarrow.string(),
1253-
)
1254-
1255-
# Ensure an empty batch to use default typing
1256-
# Otherwise arrow will resize types to save space
1257-
cur.execute(
1258-
"""
1259-
select
1260-
1 :: INTEGER as FIXED_type,
1261-
2.0 :: FLOAT as REAL_type,
1262-
'test' :: TEXT as TEXT_type,
1263-
'2024-02-28' :: DATE as DATE_type,
1264-
'2020-03-12 01:02:03.123456789' :: TIMESTAMP as TIMESTAMP_type,
1265-
'{"foo": "bar"}' :: VARIANT as VARIANT_type,
1266-
'2020-03-12 01:02:03.123456789' :: TIMESTAMP_LTZ as TIMESTAMP_LTZ_type,
1267-
'2020-03-12 01:02:03.123456789' :: TIMESTAMP_TZ as TIMESTAMP_TZ_type,
1268-
'2020-03-12 01:02:03.123456789' :: TIMESTAMP_NTZ as TIMESTAMP_NTZ_type,
1269-
'0xAAAA' :: BINARY as BINARY_type,
1270-
'01:02:03.123456789' :: TIME as TIME_type,
1271-
true :: BOOLEAN as BOOLEAN_type,
1272-
TO_GEOGRAPHY('LINESTRING(13.4814 52.5015, -121.8212 36.8252)') as GEOGRAPHY_type,
1273-
TO_GEOMETRY('LINESTRING(13.4814 52.5015, -121.8212 36.8252)') as GEOMETRY_type,
1274-
[1,2,3,4,5] :: vector(float, 5) as VECTOR_type,
1275-
object_construct('k1', 1, 'k2', 2, 'k3', 3, 'k4', 4, 'k5', 5) :: map(varchar, int) as MAP_type,
1276-
object_construct('city', 'san jose', 'population', 0.05) :: object(city varchar, population float) as OBJECT_type,
1277-
[1.0, 3.1, 4.5] :: array(float) as ARRAY_type
1278-
WHERE 1=0
1279-
"""
1280-
)
1281-
1282-
batches = cur.get_result_batches()
1283-
assert len(batches) == 1
1284-
batch = batches[0]
1285-
arrow_table = batch.to_arrow()
1286-
for actual, expected in zip(arrow_table.schema, expected_types):
1287-
assert (
1288-
actual.type == expected
1289-
), f"Expected {actual.name} :: {actual.type} column to be of type {expected}"
1269+
else:
1270+
expected_types += (
1271+
pyarrow.string(),
1272+
pyarrow.string(),
1273+
pyarrow.string(),
1274+
)
1275+
# Ensure an empty batch to use default typing
1276+
# Otherwise arrow will resize types to save space
1277+
cur.execute(query)
1278+
batches = cur.get_result_batches()
1279+
assert len(batches) == 1
1280+
batch = batches[0]
1281+
arrow_table = batch.to_arrow()
1282+
for actual, expected in zip(arrow_table.schema, expected_types):
1283+
assert (
1284+
actual.type == expected
1285+
), f"Expected {actual.name} :: {actual.type} column to be of type {expected}"
1286+
finally:
1287+
if enable_structured_types:
1288+
for param in structured_params:
1289+
cur.execute(f"alter session unset {param}")
12901290

12911291

12921292
def test_simple_arrow_fetch(conn_cnx):

0 commit comments

Comments (0)