Skip to content

Commit eceed98

Browse files
SNOW-950840 Parse new result metadata attributes for internal use (#1806)
1 parent bf7a3c5 commit eceed98

File tree

6 files changed

+439
-112
lines changed

6 files changed

+439
-112
lines changed

src/snowflake/connector/constants.py

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
if TYPE_CHECKING:
1616
from pyarrow import DataType
1717

18+
from .cursor import ResultMetadataV2
19+
1820
# Snowflake's central platform dependent directories, if the folder
1921
# ~/.snowflake/ (customizable by the environment variable SNOWFLAKE_HOME) exists
2022
# we use that folder for everything. Otherwise, we fall back to platformdirs
@@ -38,71 +40,102 @@
3840
class FieldType(NamedTuple):
3941
name: str
4042
dbapi_type: list[int]
41-
pa_type: Callable[[], DataType]
43+
pa_type: Callable[[ResultMetadataV2], DataType]
44+
45+
46+
def vector_pa_type(metadata: ResultMetadataV2) -> DataType:
47+
"""Generate the arrow type represented by the given vector column metadata."""
48+
49+
if metadata.fields is None:
50+
raise ValueError(
51+
"Invalid result metadata for vector type: expected sub-field metadata"
52+
)
53+
if len(metadata.fields) != 1:
54+
raise ValueError(
55+
"Invalid result metadata for vector type: expected a single sub-field metadata"
56+
)
57+
field_type = FIELD_ID_TO_NAME[metadata.fields[0].type_code]
58+
59+
if metadata.vector_dimension is None:
60+
raise ValueError(
61+
"Invalid result metadata for vector type: expected a dimension"
62+
)
63+
64+
if field_type == "FIXED":
65+
return pa.list_(pa.int32(), metadata.vector_dimension)
66+
elif field_type == "REAL":
67+
return pa.list_(pa.float32(), metadata.vector_dimension)
68+
else:
69+
raise ValueError(
70+
f"Invalid result metadata for vector type: invalid element type: {field_type}"
71+
)
4272

4373

4474
# This type mapping holds column type definitions.
4575
# Be careful to not change the ordering as the index is what Snowflake
4676
# gives to as schema
77+
#
78+
# `name` is the SQL name of the type, `dbapi_type` is the set of corresponding
79+
# PEP 249 type objects, and `pa_type` is a lambda that takes in a column's
80+
# result metadata and returns the corresponding Arrow type.
4781
FIELD_TYPES: tuple[FieldType, ...] = (
48-
FieldType(name="FIXED", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda: pa.int64()),
4982
FieldType(
50-
name="REAL", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda: pa.float64()
83+
name="FIXED", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda _: pa.int64()
5184
),
52-
FieldType(name="TEXT", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda: pa.string()),
5385
FieldType(
54-
name="DATE", dbapi_type=[DBAPI_TYPE_TIMESTAMP], pa_type=lambda: pa.date64()
86+
name="REAL", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda _: pa.float64()
87+
),
88+
FieldType(
89+
name="TEXT", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda _: pa.string()
90+
),
91+
FieldType(
92+
name="DATE", dbapi_type=[DBAPI_TYPE_TIMESTAMP], pa_type=lambda _: pa.date64()
5593
),
5694
FieldType(
5795
name="TIMESTAMP",
5896
dbapi_type=[DBAPI_TYPE_TIMESTAMP],
59-
pa_type=lambda: pa.time64("ns"),
97+
pa_type=lambda _: pa.time64("ns"),
6098
),
6199
FieldType(
62-
name="VARIANT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string()
100+
name="VARIANT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda _: pa.string()
63101
),
64102
FieldType(
65103
name="TIMESTAMP_LTZ",
66104
dbapi_type=[DBAPI_TYPE_TIMESTAMP],
67-
pa_type=lambda: pa.timestamp("ns"),
105+
pa_type=lambda _: pa.timestamp("ns"),
68106
),
69107
FieldType(
70108
name="TIMESTAMP_TZ",
71109
dbapi_type=[DBAPI_TYPE_TIMESTAMP],
72-
pa_type=lambda: pa.timestamp("ns"),
110+
pa_type=lambda _: pa.timestamp("ns"),
73111
),
74112
FieldType(
75113
name="TIMESTAMP_NTZ",
76114
dbapi_type=[DBAPI_TYPE_TIMESTAMP],
77-
pa_type=lambda: pa.timestamp("ns"),
115+
pa_type=lambda _: pa.timestamp("ns"),
78116
),
79117
FieldType(
80-
name="OBJECT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string()
118+
name="OBJECT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda _: pa.string()
81119
),
82120
FieldType(
83-
name="ARRAY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string()
121+
name="ARRAY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda _: pa.string()
84122
),
85123
FieldType(
86-
name="BINARY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.binary()
124+
name="BINARY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda _: pa.binary()
87125
),
88126
FieldType(
89-
name="TIME", dbapi_type=[DBAPI_TYPE_TIMESTAMP], pa_type=lambda: pa.time64("ns")
90-
),
91-
FieldType(name="BOOLEAN", dbapi_type=[], pa_type=lambda: pa.bool_()),
92-
FieldType(
93-
name="GEOGRAPHY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda: pa.string()
127+
name="TIME",
128+
dbapi_type=[DBAPI_TYPE_TIMESTAMP],
129+
pa_type=lambda _: pa.time64("ns"),
94130
),
131+
FieldType(name="BOOLEAN", dbapi_type=[], pa_type=lambda _: pa.bool_()),
95132
FieldType(
96-
name="GEOMETRY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda: pa.string()
133+
name="GEOGRAPHY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda _: pa.string()
97134
),
98135
FieldType(
99-
# TODO(SNOW-969160): While pa.binary() results in the correct pandas column
100-
# type being generated, it should be switched to pa.list_(...) once parsing
101-
# for the new result metadata fields is added.
102-
name="VECTOR",
103-
dbapi_type=[DBAPI_TYPE_BINARY],
104-
pa_type=lambda: pa.binary(),
136+
name="GEOMETRY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda _: pa.string()
105137
),
138+
FieldType(name="VECTOR", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=vector_pa_type),
106139
)
107140

108141
FIELD_NAME_TO_ID: DefaultDict[Any, int] = defaultdict(int)

0 commit comments

Comments
 (0)