|
15 | 15 | if TYPE_CHECKING: |
16 | 16 | from pyarrow import DataType |
17 | 17 |
|
| 18 | + from .cursor import ResultMetadataV2 |
| 19 | + |
18 | 20 | # Snowflake's central platform dependent directories, if the folder |
19 | 21 | # ~/.snowflake/ (customizable by the environment variable SNOWFLAKE_HOME) exists |
20 | 22 | # we use that folder for everything. Otherwise, we fall back to platformdirs |
|
38 | 40 | class FieldType(NamedTuple): |
39 | 41 | name: str |
40 | 42 | dbapi_type: list[int] |
41 | | - pa_type: Callable[[], DataType] |
| 43 | + pa_type: Callable[[ResultMetadataV2], DataType] |
| 44 | + |
| 45 | + |
| 46 | +def vector_pa_type(metadata: ResultMetadataV2) -> DataType: |
| 47 | + """Generate the arrow type represented by the given vector column metadata.""" |
| 48 | + |
| 49 | + if metadata.fields is None: |
| 50 | + raise ValueError( |
| 51 | + "Invalid result metadata for vector type: expected sub-field metadata" |
| 52 | + ) |
| 53 | + if len(metadata.fields) != 1: |
| 54 | + raise ValueError( |
| 55 | + "Invalid result metadata for vector type: expected a single sub-field metadata" |
| 56 | + ) |
| 57 | + field_type = FIELD_ID_TO_NAME[metadata.fields[0].type_code] |
| 58 | + |
| 59 | + if metadata.vector_dimension is None: |
| 60 | + raise ValueError( |
| 61 | + "Invalid result metadata for vector type: expected a dimension" |
| 62 | + ) |
| 63 | + |
| 64 | + if field_type == "FIXED": |
| 65 | + return pa.list_(pa.int32(), metadata.vector_dimension) |
| 66 | + elif field_type == "REAL": |
| 67 | + return pa.list_(pa.float32(), metadata.vector_dimension) |
| 68 | + else: |
| 69 | + raise ValueError( |
| 70 | + f"Invalid result metadata for vector type: invalid element type: {field_type}" |
| 71 | + ) |
42 | 72 |
|
43 | 73 |
|
44 | 74 | # This type mapping holds column type definitions. |
45 | 75 | # Be careful to not change the ordering as the index is what Snowflake |
46 | 76 | # gives to us as schema |
| 77 | +# |
| 78 | +# `name` is the SQL name of the type, `dbapi_type` is the set of corresponding |
| 79 | +# PEP 249 type objects, and `pa_type` is a callable that takes in a column's |
| 80 | +# result metadata and returns the corresponding Arrow type. |
47 | 81 | FIELD_TYPES: tuple[FieldType, ...] = ( |
48 | | - FieldType(name="FIXED", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda: pa.int64()), |
49 | 82 | FieldType( |
50 | | - name="REAL", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda: pa.float64() |
| 83 | + name="FIXED", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda _: pa.int64() |
51 | 84 | ), |
52 | | - FieldType(name="TEXT", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda: pa.string()), |
53 | 85 | FieldType( |
54 | | - name="DATE", dbapi_type=[DBAPI_TYPE_TIMESTAMP], pa_type=lambda: pa.date64() |
| 86 | + name="REAL", dbapi_type=[DBAPI_TYPE_NUMBER], pa_type=lambda _: pa.float64() |
| 87 | + ), |
| 88 | + FieldType( |
| 89 | + name="TEXT", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda _: pa.string() |
| 90 | + ), |
| 91 | + FieldType( |
| 92 | + name="DATE", dbapi_type=[DBAPI_TYPE_TIMESTAMP], pa_type=lambda _: pa.date64() |
55 | 93 | ), |
56 | 94 | FieldType( |
57 | 95 | name="TIMESTAMP", |
58 | 96 | dbapi_type=[DBAPI_TYPE_TIMESTAMP], |
59 | | - pa_type=lambda: pa.time64("ns"), |
| 97 | + pa_type=lambda _: pa.time64("ns"), |
60 | 98 | ), |
61 | 99 | FieldType( |
62 | | - name="VARIANT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string() |
| 100 | + name="VARIANT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda _: pa.string() |
63 | 101 | ), |
64 | 102 | FieldType( |
65 | 103 | name="TIMESTAMP_LTZ", |
66 | 104 | dbapi_type=[DBAPI_TYPE_TIMESTAMP], |
67 | | - pa_type=lambda: pa.timestamp("ns"), |
| 105 | + pa_type=lambda _: pa.timestamp("ns"), |
68 | 106 | ), |
69 | 107 | FieldType( |
70 | 108 | name="TIMESTAMP_TZ", |
71 | 109 | dbapi_type=[DBAPI_TYPE_TIMESTAMP], |
72 | | - pa_type=lambda: pa.timestamp("ns"), |
| 110 | + pa_type=lambda _: pa.timestamp("ns"), |
73 | 111 | ), |
74 | 112 | FieldType( |
75 | 113 | name="TIMESTAMP_NTZ", |
76 | 114 | dbapi_type=[DBAPI_TYPE_TIMESTAMP], |
77 | | - pa_type=lambda: pa.timestamp("ns"), |
| 115 | + pa_type=lambda _: pa.timestamp("ns"), |
78 | 116 | ), |
79 | 117 | FieldType( |
80 | | - name="OBJECT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string() |
| 118 | + name="OBJECT", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda _: pa.string() |
81 | 119 | ), |
82 | 120 | FieldType( |
83 | | - name="ARRAY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.string() |
| 121 | + name="ARRAY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda _: pa.string() |
84 | 122 | ), |
85 | 123 | FieldType( |
86 | | - name="BINARY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda: pa.binary() |
| 124 | + name="BINARY", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=lambda _: pa.binary() |
87 | 125 | ), |
88 | 126 | FieldType( |
89 | | - name="TIME", dbapi_type=[DBAPI_TYPE_TIMESTAMP], pa_type=lambda: pa.time64("ns") |
90 | | - ), |
91 | | - FieldType(name="BOOLEAN", dbapi_type=[], pa_type=lambda: pa.bool_()), |
92 | | - FieldType( |
93 | | - name="GEOGRAPHY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda: pa.string() |
| 127 | + name="TIME", |
| 128 | + dbapi_type=[DBAPI_TYPE_TIMESTAMP], |
| 129 | + pa_type=lambda _: pa.time64("ns"), |
94 | 130 | ), |
| 131 | + FieldType(name="BOOLEAN", dbapi_type=[], pa_type=lambda _: pa.bool_()), |
95 | 132 | FieldType( |
96 | | - name="GEOMETRY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda: pa.string() |
| 133 | + name="GEOGRAPHY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda _: pa.string() |
97 | 134 | ), |
98 | 135 | FieldType( |
99 | | - # TODO(SNOW-969160): While pa.binary() results in the correct pandas column |
100 | | - # type being generated, it should be switched to pa.list_(...) once parsing |
101 | | - # for the new result metadata fields is added. |
102 | | - name="VECTOR", |
103 | | - dbapi_type=[DBAPI_TYPE_BINARY], |
104 | | - pa_type=lambda: pa.binary(), |
| 136 | + name="GEOMETRY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda _: pa.string() |
105 | 137 | ), |
| 138 | + FieldType(name="VECTOR", dbapi_type=[DBAPI_TYPE_BINARY], pa_type=vector_pa_type), |
106 | 139 | ) |
107 | 140 |
|
108 | 141 | FIELD_NAME_TO_ID: DefaultDict[Any, int] = defaultdict(int) |
|
0 commit comments