
Commit 3ddf772

sfc-gh-stakeda authored and ankit-bhatnagar167 committed

SNOW-100191: Make pyarrow an optional dependency; the Arrow C++ library's .so files are now bundled internally.

1 parent e5495fd · commit 3ddf772

12 files changed (+151 −54 lines)
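The Python-side mechanics are visible throughout the diffs below: pyarrow imports move inside try/except blocks, and pyarrow is only demanded when the pandas fetch path actually needs it. A generic sketch of that optional-dependency pattern (illustration only, not the connector's exact code):

```python
# Optional-dependency pattern: the module still imports when pyarrow is
# missing; pyarrow-only features check for it at call time instead of
# failing at import time.
try:
    import pyarrow
except ImportError:
    pyarrow = None

def require_pyarrow():
    # mirrors the spirit of cursor.check_can_use_panadas() below
    if pyarrow is None:
        raise RuntimeError(
            "pyarrow package is missing. Install using pip if the platform is supported.")
```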

arrow_iterator.pyx

Lines changed: 99 additions & 9 deletions
```diff
@@ -7,6 +7,11 @@
 
 from logging import getLogger
 from cpython.ref cimport PyObject
+from libc.stdint cimport *
+from libcpp cimport bool as c_bool
+from libcpp.memory cimport shared_ptr
+from libcpp.string cimport string as c_string
+from libcpp.vector cimport vector
 
 logger = getLogger(__name__)
 
@@ -26,12 +31,87 @@ cdef extern from "cpp/ArrowIterator/CArrowIterator.hpp" namespace "sf":
 
 cdef extern from "cpp/ArrowIterator/CArrowChunkIterator.hpp" namespace "sf":
     cdef cppclass CArrowChunkIterator(CArrowIterator):
-        CArrowChunkIterator(PyObject* context, PyObject* batches) except +
+        CArrowChunkIterator(PyObject* context, vector[shared_ptr[CRecordBatch]]* batches) except +
 
 
 cdef extern from "cpp/ArrowIterator/CArrowTableIterator.hpp" namespace "sf":
     cdef cppclass CArrowTableIterator(CArrowIterator):
-        CArrowTableIterator(PyObject* context, PyObject* batches) except +
+        CArrowTableIterator(PyObject* context, vector[shared_ptr[CRecordBatch]]* batches) except +
+
+
+cdef extern from "arrow/api.h" namespace "arrow" nogil:
+    cdef cppclass CStatus "arrow::Status":
+        CStatus()
+
+        c_string ToString()
+        c_string message()
+
+        c_bool ok()
+        c_bool IsIOError()
+        c_bool IsOutOfMemory()
+        c_bool IsInvalid()
+        c_bool IsKeyError()
+        c_bool IsNotImplemented()
+        c_bool IsTypeError()
+        c_bool IsCapacityError()
+        c_bool IsIndexError()
+        c_bool IsSerializationError()
+
+
+    cdef cppclass CBuffer" arrow::Buffer":
+        CBuffer(const uint8_t* data, int64_t size)
+
+    cdef cppclass CRecordBatch" arrow::RecordBatch"
+
+    cdef cppclass CRecordBatchReader" arrow::RecordBatchReader":
+        CStatus ReadNext(shared_ptr[CRecordBatch]* batch)
+
+
+cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
+    cdef cppclass CRecordBatchStreamReader \
+            " arrow::ipc::RecordBatchStreamReader"(CRecordBatchReader):
+        @staticmethod
+        CStatus Open(const InputStream* stream,
+                     shared_ptr[CRecordBatchReader]* out)
+
+
+cdef extern from "arrow/io/api.h" namespace "arrow::io" nogil:
+    enum FileMode" arrow::io::FileMode::type":
+        FileMode_READ" arrow::io::FileMode::READ"
+        FileMode_WRITE" arrow::io::FileMode::WRITE"
+        FileMode_READWRITE" arrow::io::FileMode::READWRITE"
+
+    cdef cppclass FileInterface:
+        CStatus Close()
+        CStatus Tell(int64_t* position)
+        FileMode mode()
+        c_bool closed()
+
+    cdef cppclass Readable:
+        # put overload under a different name to avoid cython bug with multiple
+        # layers of inheritance
+        CStatus ReadBuffer" Read"(int64_t nbytes, shared_ptr[CBuffer]* out)
+        CStatus Read(int64_t nbytes, int64_t* bytes_read, uint8_t* out)
+
+    cdef cppclass InputStream(FileInterface, Readable):
+        pass
+
+    cdef cppclass Seekable:
+        CStatus Seek(int64_t position)
+
+    cdef cppclass RandomAccessFile(InputStream, Seekable):
+        CStatus GetSize(int64_t* size)
+
+        CStatus ReadAt(int64_t position, int64_t nbytes,
+                       int64_t* bytes_read, uint8_t* buffer)
+        CStatus ReadAt(int64_t position, int64_t nbytes,
+                       shared_ptr[CBuffer]* out)
+        c_bool supports_zero_copy()
+
+
+cdef extern from "arrow/python/api.h" namespace "arrow::py" nogil:
+    cdef cppclass PyReadableFile(RandomAccessFile):
+        PyReadableFile(object fo)
 
 
 cdef class EmptyPyArrowIterator:
@@ -53,12 +133,22 @@ cdef class PyArrowIterator(EmptyPyArrowIterator):
     cdef CArrowIterator* cIterator
     cdef str unit
    cdef PyObject* cret
-    cdef list batches
+    cdef vector[shared_ptr[CRecordBatch]] batches
+
+    def __cinit__(self, object py_inputstream, object arrow_context):
+        cdef shared_ptr[InputStream] input_stream
+        cdef shared_ptr[CRecordBatchReader] reader
+        cdef shared_ptr[CRecordBatch] record_batch
+        input_stream.reset(new PyReadableFile(py_inputstream))
+        CRecordBatchStreamReader.Open(input_stream.get(), &reader)
+        while True:
+            reader.get().ReadNext(&record_batch)
+
+            if record_batch.get() is NULL:
+                break
+
+            self.batches.push_back(record_batch)
 
-    def __cinit__(self, object arrow_stream_reader, object arrow_context):
-        self.batches = []
-        for rb in arrow_stream_reader:
-            self.batches.append(rb)
         self.context = arrow_context
         self.cIterator = NULL
         self.unit = ''
@@ -85,8 +175,8 @@ cdef class PyArrowIterator(EmptyPyArrowIterator):
         if iter_unit != ROW_UNIT and iter_unit != TABLE_UNIT:
             raise NotImplementedError
         elif iter_unit == ROW_UNIT:
-            self.cIterator = new CArrowChunkIterator(<PyObject*>self.context, <PyObject*>self.batches)
+            self.cIterator = new CArrowChunkIterator(<PyObject*>self.context, &self.batches)
        elif iter_unit == TABLE_UNIT:
-            self.cIterator = new CArrowTableIterator(<PyObject*>self.context, <PyObject*>self.batches)
+            self.cIterator = new CArrowTableIterator(<PyObject*>self.context, &self.batches)
         self.unit = iter_unit
 
```
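The new `__cinit__` drains the IPC stream on the C++ side: the Python file-like object is wrapped in an `arrow::py::PyReadableFile`, opened with `RecordBatchStreamReader.Open`, and read batch by batch until `ReadNext` yields a null batch. A rough pure-Python equivalent using pyarrow's public API, for illustration only (the point of the change is that this path no longer goes through pyarrow):

```python
import pyarrow as pa

def drain_batches(py_inputstream):
    # open_stream plays the role of CRecordBatchStreamReader.Open
    reader = pa.ipc.open_stream(py_inputstream)
    batches = []
    while True:
        try:
            # read_next_batch ~ ReadNext(&record_batch); StopIteration
            # corresponds to the NULL batch that ends the C++ loop
            batches.append(reader.read_next_batch())
        except StopIteration:
            break
    return batches
```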

arrow_result.pyx

Lines changed: 3 additions & 4 deletions
```diff
@@ -6,14 +6,14 @@
 # cython: language_level=3
 
 from base64 import b64decode
+import io
 from logging import getLogger
 from .telemetry import TelemetryField
 from .time_util import get_time_millis
 try:
-    from pyarrow.ipc import open_stream
-    from pyarrow import concat_tables
     from .arrow_iterator import PyArrowIterator, EmptyPyArrowIterator, ROW_UNIT, TABLE_UNIT, EMPTY_UNIT
     from .arrow_context import ArrowConverterContext
+    from pyarrow import concat_tables
 except ImportError:
     pass
 
@@ -52,9 +52,8 @@ cdef class ArrowResult:
 
         if rowset_b64:
             arrow_bytes = b64decode(rowset_b64)
-            arrow_reader = open_stream(arrow_bytes)
             self._arrow_context = ArrowConverterContext(self._connection._session_parameters)
-            self._current_chunk_row = PyArrowIterator(arrow_reader, self._arrow_context)
+            self._current_chunk_row = PyArrowIterator(io.BytesIO(arrow_bytes), self._arrow_context)
         else:
             self._current_chunk_row = EmptyPyArrowIterator(None, None)
         self._iter_unit = EMPTY_UNIT
```
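Since `PyArrowIterator` now expects a file-like object rather than a pyarrow reader, the base64-decoded rowset is simply wrapped in `io.BytesIO`. A self-contained round trip using today's pyarrow API, with a locally built IPC stream standing in for the server's `rowset_b64` payload (names here are illustrative):

```python
import io
from base64 import b64encode, b64decode
import pyarrow as pa

# Build a tiny Arrow IPC stream and base64-encode it, mimicking the payload
# that ArrowResult receives from the server.
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["c1"])
sink = io.BytesIO()
writer = pa.ipc.new_stream(sink, batch.schema)
writer.write_batch(batch)
writer.close()
rowset_b64 = b64encode(sink.getvalue())

# What the new code path does: decode and hand over the raw binary stream.
arrow_bytes = b64decode(rowset_b64)
stream = io.BytesIO(arrow_bytes)  # this is what PyArrowIterator now consumes
print(pa.ipc.open_stream(stream).read_all())
```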

chunk_downloader.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -19,7 +19,6 @@
 from gzip import GzipFile
 
 try:
-    from pyarrow.ipc import open_stream
     from .arrow_iterator import PyArrowIterator
     from .arrow_context import ArrowConverterContext
 except ImportError:
@@ -331,6 +330,5 @@ def __init__(self, meta, connection):
     """
     def to_iterator(self, raw_data_fd, download_time):
         gzip_decoder = GzipFile(fileobj=raw_data_fd, mode='r')
-        reader = open_stream(gzip_decoder)
-        it = PyArrowIterator(reader, self._arrow_context)
+        it = PyArrowIterator(gzip_decoder, self._arrow_context)
         return it
```
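`GzipFile` is itself file-like, so the downloaded chunk can feed the iterator directly, with decompression happening lazily as the reader pulls bytes. A small demonstration of the same shape, assuming the payload is a gzipped Arrow IPC stream (fabricated locally here):

```python
import gzip
import io
import pyarrow as pa

# Fabricate a gzipped IPC stream standing in for a downloaded result chunk.
batch = pa.RecordBatch.from_arrays([pa.array(["a", "b"])], names=["c1"])
raw = io.BytesIO()
with pa.ipc.new_stream(raw, batch.schema) as w:
    w.write_batch(batch)
raw_data_fd = io.BytesIO(gzip.compress(raw.getvalue()))

# to_iterator's new shape: pass the lazily-decompressing file object straight in.
gzip_decoder = gzip.GzipFile(fileobj=raw_data_fd, mode='r')
print(pa.ipc.open_stream(gzip_decoder).read_all())
```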

cpp/ArrowIterator/CArrowChunkIterator.cpp

Lines changed: 5 additions & 5 deletions
```diff
@@ -17,11 +17,11 @@
 namespace sf
 {
 
-CArrowChunkIterator::CArrowChunkIterator(PyObject* context, PyObject* batches)
+CArrowChunkIterator::CArrowChunkIterator(PyObject* context, std::vector<std::shared_ptr<arrow::RecordBatch>> *batches)
 : CArrowIterator(batches), m_latestReturnedRow(nullptr), m_context(context)
 {
-  m_batchCount = m_cRecordBatches.size();
-  m_columnCount = m_batchCount > 0 ? m_cRecordBatches[0]->num_columns() : 0;
+  m_batchCount = m_cRecordBatches->size();
+  m_columnCount = m_batchCount > 0 ? (*m_cRecordBatches)[0]->num_columns() : 0;
   m_currentBatchIndex = -1;
   m_rowIndexInBatch = -1;
   m_rowCountInBatch = 0;
@@ -50,7 +50,7 @@ PyObject* CArrowChunkIterator::next()
     if (m_currentBatchIndex < m_batchCount)
     {
       m_rowIndexInBatch = 0;
-      m_rowCountInBatch = m_cRecordBatches[m_currentBatchIndex]->num_rows();
+      m_rowCountInBatch = (*m_cRecordBatches)[m_currentBatchIndex]->num_rows();
       this->initColumnConverters();
       if (py::checkPyError())
       {
@@ -90,7 +90,7 @@ void CArrowChunkIterator::initColumnConverters()
 {
   m_currentBatchConverters.clear();
   std::shared_ptr<arrow::RecordBatch> currentBatch =
-      m_cRecordBatches[m_currentBatchIndex];
+      (*m_cRecordBatches)[m_currentBatchIndex];
   std::shared_ptr<arrow::Schema> schema = currentBatch->schema();
   for (int i = 0; i < currentBatch->num_columns(); i++)
   {
```

cpp/ArrowIterator/CArrowChunkIterator.hpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@ class CArrowChunkIterator : public CArrowIterator
   /**
   * Constructor
   */
-  CArrowChunkIterator(PyObject* context, PyObject* batches);
+  CArrowChunkIterator(PyObject* context, std::vector<std::shared_ptr<arrow::RecordBatch>> *);
 
   /**
   * Desctructor
```

cpp/ArrowIterator/CArrowIterator.cpp

Lines changed: 3 additions & 10 deletions
```diff
@@ -9,17 +9,10 @@ namespace sf
 
 Logger CArrowIterator::logger("snowflake.connector.CArrowIterator");
 
-CArrowIterator::CArrowIterator(PyObject* batches)
+CArrowIterator::CArrowIterator(std::vector<std::shared_ptr<arrow::RecordBatch>>* batches) :
+m_cRecordBatches(batches)
 {
-  int pyListSize = PyList_Size(batches);
-  logger.debug("Arrow BatchSize: %d", pyListSize);
-
-  for (int i=0; i<pyListSize; i++)
-  {
-    std::shared_ptr<arrow::RecordBatch> cRecordBatch;
-    arrow::Status status = arrow::py::unwrap_record_batch(PyList_GetItem(batches, i), &cRecordBatch);
-    m_cRecordBatches.push_back(cRecordBatch);
-  }
+  logger.debug("Arrow BatchSize: %d", batches->size());
 }
 
 }
```

cpp/ArrowIterator/CArrowIterator.hpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -24,7 +24,7 @@ namespace sf
 class CArrowIterator
 {
 public:
-  CArrowIterator(PyObject *batches);
+  CArrowIterator(std::vector<std::shared_ptr<arrow::RecordBatch>> * batches);
 
   virtual ~CArrowIterator() = default;
 
@@ -35,7 +35,7 @@
 
 protected:
   /** list of all record batch in current chunk */
-  std::vector<std::shared_ptr<arrow::RecordBatch>> m_cRecordBatches;
+  std::vector<std::shared_ptr<arrow::RecordBatch>> *m_cRecordBatches;
 
   static Logger logger;
 };
```

cpp/ArrowIterator/CArrowTableIterator.cpp

Lines changed: 7 additions & 7 deletions
```diff
@@ -26,9 +26,9 @@ namespace sf
 void CArrowTableIterator::reconstructRecordBatches()
 {
   // Type conversion, the code needs to be optimized
-  for (unsigned int batchIdx = 0; batchIdx < m_cRecordBatches.size(); batchIdx++)
+  for (unsigned int batchIdx = 0; batchIdx < m_cRecordBatches->size(); batchIdx++)
   {
-    std::shared_ptr<arrow::RecordBatch> currentBatch = m_cRecordBatches[batchIdx];
+    std::shared_ptr<arrow::RecordBatch> currentBatch = (*m_cRecordBatches)[batchIdx];
     std::shared_ptr<arrow::Schema> schema = currentBatch->schema();
     for (int colIdx = 0; colIdx < currentBatch->num_columns(); colIdx++)
     {
@@ -127,7 +127,7 @@ void CArrowTableIterator::reconstructRecordBatches()
   }
 }
 
-CArrowTableIterator::CArrowTableIterator(PyObject* context, PyObject* batches)
+CArrowTableIterator::CArrowTableIterator(PyObject* context, std::vector<std::shared_ptr<arrow::RecordBatch>>* batches)
 : CArrowIterator(batches), m_context(context), m_pyTableObjRef(nullptr)
 {
   PyObject* tz = PyObject_GetAttrString(m_context, "_timezone");
@@ -156,7 +156,7 @@ arrow::Status CArrowTableIterator::replaceColumn(
   const std::shared_ptr<arrow::Array>& newColumn)
 {
   // replace the targeted column
-  std::shared_ptr<arrow::RecordBatch> currentBatch = m_cRecordBatches[batchIdx];
+  std::shared_ptr<arrow::RecordBatch> currentBatch = (*m_cRecordBatches)[batchIdx];
   arrow::Status ret = currentBatch->AddColumn(colIdx+1, newField, newColumn, &currentBatch);
   if(!ret.ok())
   {
@@ -167,7 +167,7 @@ arrow::Status CArrowTableIterator::replaceColumn(
   {
     return ret;
   }
-  m_cRecordBatches[batchIdx] = currentBatch;
+  (*m_cRecordBatches)[batchIdx] = currentBatch;
   return ret;
 }
 
@@ -842,10 +842,10 @@ void CArrowTableIterator::convertTimestampTZColumn(
 bool CArrowTableIterator::convertRecordBatchesToTable()
 {
   // only do conversion once and there exist some record batches
-  if (!m_cTable && !m_cRecordBatches.empty())
+  if (!m_cTable && !m_cRecordBatches->empty())
   {
     reconstructRecordBatches();
-    arrow::Table::FromRecordBatches(m_cRecordBatches, &m_cTable);
+    arrow::Table::FromRecordBatches(*m_cRecordBatches, &m_cTable);
     return true;
   }
   return false;
```
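`arrow::Table::FromRecordBatches` now takes the dereferenced vector instead of a member copy; the operation itself is unchanged. Its pyarrow analogue, shown only to illustrate the batches-to-table conversion:

```python
import pyarrow as pa

batches = [
    pa.RecordBatch.from_arrays([pa.array([1, 2])], names=["c1"]),
    pa.RecordBatch.from_arrays([pa.array([3, 4])], names=["c1"]),
]
# Table.from_batches ~ arrow::Table::FromRecordBatches: zero-copy
# concatenation of batches sharing a schema into one logical table.
table = pa.Table.from_batches(batches)
assert table.num_rows == 4
```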

cpp/ArrowIterator/CArrowTableIterator.hpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@ class CArrowTableIterator : public CArrowIterator
   /**
   * Constructor
   */
-  CArrowTableIterator(PyObject* context, PyObject* batches);
+  CArrowTableIterator(PyObject* context, std::vector<std::shared_ptr<arrow::RecordBatch>>* batches);
 
   /**
   * Destructor
```

cursor.py

Lines changed: 22 additions & 8 deletions
```diff
@@ -609,7 +609,7 @@ def _init_result_and_meta(self, data, use_ijson=False):
                                        column[u'nullable']))
 
         if self._query_result_format == 'arrow':
-            self.check_pyarrow_resultset()
+            self.check_can_use_arrow_resultset()
             self._result = ArrowResult(data, self)
         else:
             self._result = self._json_result_class(data, self, use_ijson)
@@ -628,21 +628,15 @@ def _init_result_and_meta(self, data, use_ijson=False):
         else:
             self._total_rowcount += updated_rows
 
-    def check_pyarrow_resultset(self):
+    def check_can_use_arrow_resultset(self):
         global CAN_USE_ARROW_RESULT
-        global pyarrow
 
         if not CAN_USE_ARROW_RESULT:
             if self._connection.application == 'SnowSQL':
                 msg = (
                     "Currently SnowSQL doesn't support the result set in Apache Arrow format."
                 )
                 errno = ER_NO_PYARROW_SNOWSQL
-            elif pyarrow is None:
-                msg = (
-                    "pyarrow package is missing. Install using pip if the platform is supported."
-                )
-                errno = ER_NO_PYARROW
             else:
                 msg = (
                     "The result set in Apache Arrow format is not supported for the platform."
@@ -658,6 +652,24 @@ def check_pyarrow_resultset(self):
                 }
             )
 
+    def check_can_use_panadas(self):
+        global pyarrow
+
+        if pyarrow is None:
+            msg = (
+                "pyarrow package is missing. Install using pip if the platform is supported."
+            )
+            errno = ER_NO_PYARROW
+
+            Error.errorhandler_wrapper(
+                self.connection, self,
+                ProgrammingError,
+                {
+                    u'msg': msg,
+                    u'errno': errno,
+                }
+            )
+
     def query_result(self, qid, _use_ijson=False):
         url = '/queries/{qid}/result'.format(qid=qid)
         ret = self._connection.rest.request(url=url, method='get')
@@ -695,6 +707,7 @@ def fetch_pandas_batches(self, **kwargs):
         Fetch a single Arrow Table
         @param kwargs: will be passed to pyarrow.Table.to_pandas() method
         """
+        self.check_can_use_panadas()
         if self._query_result_format != 'arrow': # TODO: or pandas isn't imported
             raise NotSupportedError
         for df in self._result._fetch_pandas_batches(**kwargs):
@@ -705,6 +718,7 @@ def fetch_pandas_all(self, **kwargs):
         Fetch Pandas dataframes in batch, where 'batch' refers to Snowflake Chunk
         @param kwargs: will be passed to pyarrow.Table.to_pandas() method
         """
+        self.check_can_use_panadas()
         if self._query_result_format != 'arrow':
             raise NotSupportedError
         return self._result._fetch_pandas_all(**kwargs)
```
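The net effect of the split checks, from a user's perspective: Arrow-format result sets work without pyarrow (the bundled C++ library handles them), while the pandas fetch methods still require it. A hypothetical usage sketch; the connection parameters are placeholders:

```python
import snowflake.connector

cnx = snowflake.connector.connect(account="...", user="...", password="...")
cur = cnx.cursor()
cur.execute("SELECT 1 AS n")
rows = cur.fetchall()        # fine even when pyarrow is not installed
df = cur.fetch_pandas_all()  # raises ProgrammingError (ER_NO_PYARROW) without pyarrow
cnx.close()
```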
