Skip to content

Commit dc16ecc

Browse files
authored
SNOW-1000284: Add support for structured types to ResultMetadataV2 (#1890)
1 parent 23eda6e commit dc16ecc

File tree

4 files changed

+172
-43
lines changed

4 files changed

+172
-43
lines changed

DESCRIPTION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
2020
- Add `SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT` flag (usage: `SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT=true`) to make a non-blocking socket.recv call and retry on Error
2121
- Consider using this if running in a containerized environment and externalbrowser auth frequently hangs while waiting for callback
2222
- NOTE: this has not been tested extensively, but has been shown to improve the experience when using WSL
23+
- Added support for parsing structured type information in schema queries.
2324
- Bumped platformdirs from >=2.6.0,<4.0.0 to >=2.6.0,<5.0.0
2425
- Updated diagnostics to use system$allowlist instead of system$whitelist.
2526
- Update `write_pandas` to skip TABLE IF NOT EXISTS in truncate mode.

src/snowflake/connector/constants.py

Lines changed: 68 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -44,38 +44,75 @@ class FieldType(NamedTuple):
4444

4545

4646
def vector_pa_type(metadata: ResultMetadataV2) -> DataType:
47-
"""Generate the Arrow type represented by the given vector column metadata.
48-
47+
"""
48+
Generate the Arrow type represented by the given vector column metadata.
4949
Vectors are represented as Arrow fixed-size lists.
5050
"""
51+
assert (
52+
metadata.fields is not None and len(metadata.fields) == 1
53+
), "Invalid result metadata for vector type: expected a single field to be defined"
54+
assert (
55+
metadata.vector_dimension or 0
56+
) > 0, "Invalid result metadata for vector type: expected a positive dimension"
57+
58+
field_type = FIELD_TYPES[metadata.fields[0].type_code]
59+
return pa.list_(field_type.pa_type(metadata.fields[0]), metadata.vector_dimension)
60+
61+
62+
def array_pa_type(metadata: ResultMetadataV2) -> DataType:
63+
"""
64+
Generate the Arrow type represented by the given array column metadata.
65+
"""
66+
# If fields is missing then structured types are not enabled.
67+
# Fallback to json encoded string
68+
if metadata.fields is None:
69+
return pa.string()
70+
71+
assert (
72+
len(metadata.fields) == 1
73+
), "Invalid result metadata for array type: expected a single field to be defined"
74+
75+
field_type = FIELD_TYPES[metadata.fields[0].type_code]
76+
return pa.list_(field_type.pa_type(metadata.fields[0]))
77+
78+
79+
def map_pa_type(metadata: ResultMetadataV2) -> DataType:
80+
"""
81+
Generate the Arrow type represented by the given map column metadata.
82+
"""
83+
# If fields is missing then structured types are not enabled.
84+
# Fallback to json encoded string
85+
if metadata.fields is None:
86+
return pa.string()
5187

88+
assert (
89+
len(metadata.fields or []) == 2
90+
), "Invalid result metadata for map type: expected a field for key and a field for value"
91+
key_type = FIELD_TYPES[metadata.fields[0].type_code]
92+
value_type = FIELD_TYPES[metadata.fields[1].type_code]
93+
return pa.map_(
94+
key_type.pa_type(metadata.fields[0]), value_type.pa_type(metadata.fields[1])
95+
)
96+
97+
98+
def struct_pa_type(metadata: ResultMetadataV2) -> DataType:
99+
"""
100+
Generate the Arrow type represented by the given struct column metadata.
101+
"""
102+
# If fields is missing then structured types are not enabled.
103+
# Fallback to json encoded string
52104
if metadata.fields is None:
53-
raise ValueError(
54-
"Invalid result metadata for vector type: expected sub-field metadata"
55-
)
56-
if len(metadata.fields) != 1:
57-
raise ValueError(
58-
"Invalid result metadata for vector type: expected a single sub-field metadata"
59-
)
60-
field_type = FIELD_ID_TO_NAME[metadata.fields[0].type_code]
61-
62-
if metadata.vector_dimension is None:
63-
raise ValueError(
64-
"Invalid result metadata for vector type: expected a dimension"
65-
)
66-
elif metadata.vector_dimension <= 0:
67-
raise ValueError(
68-
"Invalid result metadata for vector type: expected a positive dimension"
69-
)
70-
71-
if field_type == "FIXED":
72-
return pa.list_(pa.int32(), metadata.vector_dimension)
73-
elif field_type == "REAL":
74-
return pa.list_(pa.float32(), metadata.vector_dimension)
75-
else:
76-
raise ValueError(
77-
f"Invalid result metadata for vector type: invalid element type: {field_type}"
78-
)
105+
return pa.string()
106+
107+
assert all(
108+
field.name is not None for field in metadata.fields
109+
), "All fields of a struct type must have a name."
110+
return pa.struct(
111+
{
112+
field.name: FIELD_TYPES[field.type_code].pa_type(field)
113+
for field in metadata.fields
114+
}
115+
)
79116

80117

81118
# This type mapping holds column type definitions.
@@ -121,12 +158,8 @@ def vector_pa_type(metadata: ResultMetadataV2) -> DataType:
121158
dbapi_type=[DBAPI_TYPE_TIMESTAMP],
122159
pa_type=lambda _: pa.timestamp("ns"),
123160
),
124-
FieldType(
125-
name="OBJECT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda _: pa.string()
126-
),
127-
FieldType(
128-
name="ARRAY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda _: pa.string()
129-
),
161+
FieldType(name="OBJECT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=struct_pa_type),
162+
FieldType(name="ARRAY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=array_pa_type),
130163
FieldType(
131164
name="BINARY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda _: pa.binary()
132165
),
@@ -143,6 +176,7 @@ def vector_pa_type(metadata: ResultMetadataV2) -> DataType:
143176
name="GEOMETRY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda _: pa.string()
144177
),
145178
FieldType(name="VECTOR", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=vector_pa_type),
179+
FieldType(name="MAP", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=map_pa_type),
146180
)
147181

148182
FIELD_NAME_TO_ID: DefaultDict[Any, int] = defaultdict(int)

src/snowflake/connector/cursor.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -201,28 +201,35 @@ def from_column(cls, col: dict[str, Any]) -> ResultMetadataV2:
201201
This differs from ResultMetadata in that it has newly-added fields which cannot be added to
202202
ResultMetadata since it is a named tuple.
203203
"""
204-
type_code = FIELD_NAME_TO_ID[
204+
col_type = (
205205
col["extTypeName"].upper()
206206
if col.get("extTypeName")
207207
else col["type"].upper()
208-
]
208+
)
209209

210-
fields = None
211-
if type_code == FIELD_NAME_TO_ID["VECTOR"] and col.get("fields") is not None:
212-
fields = [
213-
ResultMetadataV2.from_column({"name": None, **f}) for f in col["fields"]
214-
]
210+
fields = col.get("fields")
211+
processed_fields: Optional[List[ResultMetadataV2]] = None
212+
if fields is not None:
213+
if col_type in {"VECTOR", "ARRAY", "OBJECT", "MAP"}:
214+
processed_fields = [
215+
ResultMetadataV2.from_column({"name": None, **f})
216+
for f in col["fields"]
217+
]
218+
else:
219+
raise ValueError(
220+
f"Field parsing is not supported for columns of type {col_type}."
221+
)
215222

216223
return cls(
217224
col["name"],
218-
type_code,
225+
FIELD_NAME_TO_ID[col_type],
219226
col["nullable"],
220227
None,
221228
col["length"],
222229
col["precision"],
223230
col["scale"],
224231
col.get("vectorDimension"),
225-
fields,
232+
processed_fields,
226233
)
227234

228235
def _to_result_metadata_v1(self):

test/integ/pandas/test_arrow_pandas.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,6 +1202,93 @@ def test_batch_to_pandas_arrow(conn_cnx, result_format):
12021202
assert arrow_table.to_pydict()["FOO"] == list(range(rowcount))
12031203

12041204

1205+
@pytest.mark.internal
1206+
@pytest.mark.parametrize("enable_structured_types", [True, False])
1207+
def test_to_arrow_datatypes(enable_structured_types, conn_cnx):
1208+
with conn_cnx() as cnx:
1209+
with cnx.cursor() as cur:
1210+
cur.execute(SQL_ENABLE_ARROW)
1211+
1212+
expected_types = (
1213+
pyarrow.int64(),
1214+
pyarrow.float64(),
1215+
pyarrow.string(),
1216+
pyarrow.date64(),
1217+
pyarrow.timestamp("ns"),
1218+
pyarrow.string(),
1219+
pyarrow.timestamp("ns"),
1220+
pyarrow.timestamp("ns"),
1221+
pyarrow.timestamp("ns"),
1222+
pyarrow.binary(),
1223+
pyarrow.time64("ns"),
1224+
pyarrow.bool_(),
1225+
pyarrow.string(),
1226+
pyarrow.string(),
1227+
pyarrow.list_(pyarrow.float64(), 5),
1228+
)
1229+
1230+
if enable_structured_types:
1231+
structured_params = {
1232+
"ENABLE_STRUCTURED_TYPES_IN_CLIENT_RESPONSE",
1233+
"IGNORE_CLIENT_VESRION_IN_STRUCTURED_TYPES_RESPONSE",
1234+
"FORCE_ENABLE_STRUCTURED_TYPES_NATIVE_ARROW_FORMAT",
1235+
}
1236+
try:
1237+
for param in structured_params:
1238+
cur.execute(f"alter session set {param}=true")
1239+
expected_types += (
1240+
pyarrow.map_(pyarrow.string(), pyarrow.int64()),
1241+
pyarrow.struct(
1242+
{"city": pyarrow.string(), "population": pyarrow.float64()}
1243+
),
1244+
pyarrow.list_(pyarrow.float64()),
1245+
)
1246+
finally:
1247+
cur.execute(f"alter session unset {param}")
1248+
else:
1249+
expected_types += (
1250+
pyarrow.string(),
1251+
pyarrow.string(),
1252+
pyarrow.string(),
1253+
)
1254+
1255+
# Ensure an empty batch to use default typing
1256+
# Otherwise arrow will resize types to save space
1257+
cur.execute(
1258+
"""
1259+
select
1260+
1 :: INTEGER as FIXED_type,
1261+
2.0 :: FLOAT as REAL_type,
1262+
'test' :: TEXT as TEXT_type,
1263+
'2024-02-28' :: DATE as DATE_type,
1264+
'2020-03-12 01:02:03.123456789' :: TIMESTAMP as TIMESTAMP_type,
1265+
'{"foo": "bar"}' :: VARIANT as VARIANT_type,
1266+
'2020-03-12 01:02:03.123456789' :: TIMESTAMP_LTZ as TIMESTAMP_LTZ_type,
1267+
'2020-03-12 01:02:03.123456789' :: TIMESTAMP_TZ as TIMESTAMP_TZ_type,
1268+
'2020-03-12 01:02:03.123456789' :: TIMESTAMP_NTZ as TIMESTAMP_NTZ_type,
1269+
'0xAAAA' :: BINARY as BINARY_type,
1270+
'01:02:03.123456789' :: TIME as TIME_type,
1271+
true :: BOOLEAN as BOOLEAN_type,
1272+
TO_GEOGRAPHY('LINESTRING(13.4814 52.5015, -121.8212 36.8252)') as GEOGRAPHY_type,
1273+
TO_GEOMETRY('LINESTRING(13.4814 52.5015, -121.8212 36.8252)') as GEOMETRY_type,
1274+
[1,2,3,4,5] :: vector(float, 5) as VECTOR_type,
1275+
object_construct('k1', 1, 'k2', 2, 'k3', 3, 'k4', 4, 'k5', 5) :: map(varchar, int) as MAP_type,
1276+
object_construct('city', 'san jose', 'population', 0.05) :: object(city varchar, population float) as OBJECT_type,
1277+
[1.0, 3.1, 4.5] :: array(float) as ARRAY_type
1278+
WHERE 1=0
1279+
"""
1280+
)
1281+
1282+
batches = cur.get_result_batches()
1283+
assert len(batches) == 1
1284+
batch = batches[0]
1285+
arrow_table = batch.to_arrow()
1286+
for actual, expected in zip(arrow_table.schema, expected_types):
1287+
assert (
1288+
actual.type == expected
1289+
), f"Expected {actual.name} :: {actual.type} column to be of type {expected}"
1290+
1291+
12051292
def test_simple_arrow_fetch(conn_cnx):
12061293
rowcount = 250_000
12071294
with conn_cnx() as cnx:

0 commit comments

Comments
 (0)