Skip to content

Commit 35c2f6f

Browse files
SNOW-950840 Add support for the vector data type (#1804)
1 parent 7d83846 commit 35c2f6f

File tree

13 files changed

+291
-3
lines changed

13 files changed

+291
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,4 @@ core.*
124124

125125
# Compiled Cython
126126
src/snowflake/connector/arrow_iterator.cpp
127+
src/snowflake/connector/nanoarrow_cpp/ArrowIterator/nanoarrow_arrow_iterator.cpp

DESCRIPTION.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
99
# Release Notes
1010

1111

12+
- v3.6.0(TBD)
13+
14+
- Added support for Vector types
15+
1216
- v3.5.0(November 13,2023)
1317

1418
- As of version 3.5.0, snowflake-connector-python is built purely upon the Apache arrow-nanoarrow project.

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def build_extension(self, ext):
107107
NANOARROW_ARROW_ITERATOR_SRC_DIR, "DecimalConverter.cpp"
108108
),
109109
os.path.join(NANOARROW_ARROW_ITERATOR_SRC_DIR, "DateConverter.cpp"),
110+
os.path.join(
111+
NANOARROW_ARROW_ITERATOR_SRC_DIR, "FixedSizeListConverter.cpp"
112+
),
110113
os.path.join(
111114
NANOARROW_ARROW_ITERATOR_SRC_DIR, "FloatConverter.cpp"
112115
),

src/snowflake/connector/constants.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,14 @@ class FieldType(NamedTuple):
9595
FieldType(
9696
name="GEOMETRY", dbapi_type=[DBAPI_TYPE_STRING], pa_type=lambda: pa.string()
9797
),
98+
FieldType(
99+
# TODO(SNOW-969160): While pa.binary() results in the correct pandas column
100+
# type being generated, it should be switched to pa.list_(...) once parsing
101+
# for the new result metadata fields is added.
102+
name="VECTOR",
103+
dbapi_type=[DBAPI_TYPE_BINARY],
104+
pa_type=lambda: pa.binary(),
105+
),
98106
)
99107

100108
FIELD_NAME_TO_ID: DefaultDict[Any, int] = defaultdict(int)

src/snowflake/connector/converter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import binascii
99
import decimal
10+
import json
1011
import time
1112
from datetime import date, datetime
1213
from datetime import time as dt_t
@@ -320,6 +321,9 @@ def _VARIANT_to_python(self, _: dict[str, Any]) -> Any | None:
320321

321322
_ARRAY_to_python = _VARIANT_to_python
322323

324+
def _VECTOR_to_python(self, ctx: dict[str, Any]) -> Callable:
325+
return lambda v: json.loads(v)
326+
323327
def _BOOLEAN_to_python(
324328
self, ctx: dict[str, str | None] | dict[str, str]
325329
) -> Callable:

src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "BinaryConverter.hpp"
1212
#include "BooleanConverter.hpp"
1313
#include "DateConverter.hpp"
14+
#include "FixedSizeListConverter.hpp"
1415
#include "TimeStampConverter.hpp"
1516
#include "TimeConverter.hpp"
1617
#include <memory>
@@ -438,6 +439,12 @@ void CArrowChunkIterator::initColumnConverters()
438439
break;
439440
}
440441

442+
case SnowflakeType::Type::VECTOR:
443+
{
444+
m_currentBatchConverters.push_back(std::make_shared<sf::FixedSizeListConverter>(array));
445+
break;
446+
}
447+
441448
default:
442449
{
443450
std::string errorInfo = Logger::formatString(
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
//
2+
// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
//
4+
5+
#include "FixedSizeListConverter.hpp"
6+
7+
namespace sf
8+
{
9+
// Logger shared by every FixedSizeListConverter instance; it is allocated
// once during static initialization and never released.
Logger* FixedSizeListConverter::logger{
    new Logger("snowflake.connector.FixedSizeListConverter")};
11+
12+
FixedSizeListConverter::FixedSizeListConverter(ArrowArrayView* array)
13+
: m_array(array)
14+
{
15+
}
16+
17+
void FixedSizeListConverter::generateError(const std::string& msg) const
18+
{
19+
logger->error(__FILE__, __func__, __LINE__, msg.c_str());
20+
PyErr_SetString(PyExc_Exception, msg.c_str());
21+
}
22+
23+
// Converts the fixed size list stored at `rowIndex` into a Python list.
//
// Returns:
//   - None (new reference) when the row is null;
//   - a new Python list of int / float objects on success;
//   - nullptr with a Python exception set when the array does not have exactly
//     one child, is empty, holds an unsupported element type, or when a Python
//     object cannot be created.
//
// NOTE(review): the element position is computed as rowIndex * listSize, i.e.
// the parent array's own offset is not applied here. This presumably relies on
// the arrays received from the backend being unsliced -- confirm.
PyObject* FixedSizeListConverter::toPyObject(int64_t rowIndex) const
{
  if (ArrowArrayViewIsNull(m_array, rowIndex))
  {
    Py_RETURN_NONE;
  }

  if (m_array->n_children != 1)
  {
    // n_children is a 64-bit count in nanoarrow; cast it so that the argument
    // type matches the "%d" conversion in the (unchanged) message.
    std::string errorInfo = Logger::formatString(
        "[Snowflake Exception] invalid arrow element schema for fixed size "
        "list: got (%d) "
        "children",
        static_cast<int>(m_array->n_children));
    this->generateError(errorInfo);
    return nullptr;
  }

  // The per-list size below is derived by dividing by the number of lists, so
  // an empty array must be rejected before the division happens.
  if (m_array->length <= 0)
  {
    this->generateError(
        "[Snowflake Exception] invalid arrow array for fixed size list: "
        "array contains no rows");
    return nullptr;
  }

  // m_array->length is the number of fixed size lists in the array.
  // m_array->children[0] views the element data of all lists, stored
  // back-to-back, so m_array->children[0]->length is the sum of the lengths
  // of the fixed size lists in the array.
  ArrowArrayView* elements = m_array->children[0];
  const int64_t fixedSizeArrayLength = elements->length / m_array->length;

  PyObject* list = PyList_New(fixedSizeArrayLength);
  if (list == nullptr)
  {
    // PyList_New has already set the Python exception.
    return nullptr;
  }

  const int64_t startIndexWithoutOffset = rowIndex * fixedSizeArrayLength;
  for (int64_t i = 0; i < fixedSizeArrayLength; ++i)
  {
    const int64_t bufferIndexWithoutOffset = startIndexWithoutOffset + i;
    PyObject* item = nullptr;

    // Currently, the backend only sends back INT32 and FLOAT32, but the
    // remaining types are enumerated for future use.
    switch (elements->storage_type)
    {
      case NANOARROW_TYPE_INT8:
      case NANOARROW_TYPE_INT16:
      case NANOARROW_TYPE_INT32:
      case NANOARROW_TYPE_INT64:
      {
        item = PyLong_FromLongLong(
            ArrowArrayViewGetIntUnsafe(elements, bufferIndexWithoutOffset));
        break;
      }
      case NANOARROW_TYPE_HALF_FLOAT:
      case NANOARROW_TYPE_FLOAT:
      case NANOARROW_TYPE_DOUBLE:
      {
        item = PyFloat_FromDouble(
            ArrowArrayViewGetDoubleUnsafe(elements, bufferIndexWithoutOffset));
        break;
      }
      default:
      {
        std::string errorInfo = Logger::formatString(
            "[Snowflake Exception] invalid arrow element type for fixed size "
            "list: got (%s)",
            ArrowTypeString(elements->storage_type));
        this->generateError(errorInfo);
        // Release the partially built list so it is not leaked on error.
        Py_DECREF(list);
        return nullptr;
      }
    }

    // PyList_SetItem steals the reference to `item`, even when it fails, so
    // only the list itself has to be released on either failure.
    if (item == nullptr || PyList_SetItem(list, i, item) != 0)
    {
      Py_DECREF(list);
      return nullptr;
    }
  }

  return list;
}
87+
88+
} // namespace sf
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//
2+
// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
//
4+
5+
#ifndef PC_FIXEDSIZELISTCONVERTER_HPP
6+
#define PC_FIXEDSIZELISTCONVERTER_HPP
7+
8+
#include <memory>
9+
10+
#include "IColumnConverter.hpp"
11+
#include "logging.hpp"
12+
#include "nanoarrow.h"
13+
#include "nanoarrow.hpp"
14+
15+
namespace sf
16+
{
17+
18+
class FixedSizeListConverter : public IColumnConverter
19+
{
20+
public:
21+
explicit FixedSizeListConverter(ArrowArrayView* array);
22+
PyObject* toPyObject(int64_t rowIndex) const override;
23+
24+
private:
25+
void generateError(const std::string& msg) const;
26+
27+
ArrowArrayView* m_array;
28+
29+
static Logger* logger;
30+
};
31+
32+
} // namespace sf
33+
34+
#endif // PC_FIXEDSIZELISTCONVERTER_HPP

src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ std::unordered_map<std::string, SnowflakeType::Type>
2929
{"TIMESTAMP_LTZ", SnowflakeType::Type::TIMESTAMP_LTZ},
3030
{"TIMESTAMP_NTZ", SnowflakeType::Type::TIMESTAMP_NTZ},
3131
{"TIMESTAMP_TZ", SnowflakeType::Type::TIMESTAMP_TZ},
32-
{"VARIANT", SnowflakeType::Type::VARIANT}};
32+
{"VARIANT", SnowflakeType::Type::VARIANT},
33+
{"VECTOR", SnowflakeType::Type::VECTOR}};
3334

3435
} // namespace sf

src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ class SnowflakeType
3232
TIMESTAMP_LTZ = 12,
3333
TIMESTAMP_NTZ = 13,
3434
TIMESTAMP_TZ = 14,
35-
VARIANT = 15
35+
VARIANT = 15,
36+
VECTOR = 16
3637
};
3738

3839
static SnowflakeType::Type snowflakeTypeFromString(std::string str)

0 commit comments

Comments
 (0)