
Commit 705341e

Authored and committed by sfc-gh-stakeda and ankit-bhatnagar167
SNOW-119348: support dictionary cursor for ARROW format result set
1 parent 7ffbe1b commit 705341e

9 files changed: +120 −45 lines
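
In short: a new `use_dict_result` flag flows from `SnowflakeCursor` through `ArrowResult` and the chunk downloader into the Cython/C++ Arrow iterators, so that `DictCursor` yields each row as a `dict` keyed by column name (instead of a tuple) when the query result format is ARROW. A quick usage sketch, adapted from the new `test_dict_cursor` below (connection parameters are placeholders):

```python
import snowflake.connector

# Placeholder credentials; see the new test in test/test_arrow_result.py
# below for the original usage this sketch is based on.
cnx = snowflake.connector.connect(account='...', user='...', password='...')
cur = cnx.cursor(snowflake.connector.DictCursor)
cur.execute("alter session set python_connector_query_result_format='ARROW'")
row = cur.execute("select 1 as foo, 2 as bar").fetchone()
assert row == {'FOO': 1, 'BAR': 2}  # dict row instead of the tuple (1, 2)
cnx.close()
```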

arrow_iterator.pyx

Lines changed: 9 additions & 7 deletions
```diff
@@ -35,6 +35,9 @@ cdef extern from "cpp/ArrowIterator/CArrowChunkIterator.hpp" namespace "sf":
     cdef cppclass CArrowChunkIterator(CArrowIterator):
         CArrowChunkIterator(PyObject* context, vector[shared_ptr[CRecordBatch]]* batches) except +
 
+    cdef cppclass DictCArrowChunkIterator(CArrowChunkIterator):
+        DictCArrowChunkIterator(PyObject* context, vector[shared_ptr[CRecordBatch]]* batches) except +
+
 
 cdef extern from "cpp/ArrowIterator/CArrowTableIterator.hpp" namespace "sf":
     cdef cppclass CArrowTableIterator(CArrowIterator):
@@ -117,11 +120,6 @@ cdef extern from "arrow/python/api.h" namespace "arrow::py" nogil:
 
 
 cdef class EmptyPyArrowIterator:
-    def __cinit__(self, object arrow_stream_reader, object arrow_context):
-        pass
-
-    def __dealloc__(self):
-        pass
 
     def __next__(self):
         raise StopIteration
@@ -136,8 +134,9 @@ cdef class PyArrowIterator(EmptyPyArrowIterator):
     cdef str unit
     cdef PyObject* cret
     cdef vector[shared_ptr[CRecordBatch]] batches
+    cdef object use_dict_result
 
-    def __cinit__(self, object py_inputstream, object arrow_context):
+    def __cinit__(self, object py_inputstream, object arrow_context, object use_dict_result):
         cdef shared_ptr[InputStream] input_stream
         cdef shared_ptr[CRecordBatchReader] reader
         cdef shared_ptr[CRecordBatch] record_batch
@@ -175,6 +174,7 @@ cdef class PyArrowIterator(EmptyPyArrowIterator):
         self.context = arrow_context
         self.cIterator = NULL
         self.unit = ''
+        self.use_dict_result = use_dict_result
 
     def __dealloc__(self):
         del self.cIterator
@@ -198,7 +198,9 @@ cdef class PyArrowIterator(EmptyPyArrowIterator):
         if iter_unit != ROW_UNIT and iter_unit != TABLE_UNIT:
             raise NotImplementedError
         elif iter_unit == ROW_UNIT:
-            self.cIterator = new CArrowChunkIterator(<PyObject*>self.context, &self.batches)
+            self.cIterator = new CArrowChunkIterator(<PyObject*>self.context, &self.batches) if not self.use_dict_result \
+                else new DictCArrowChunkIterator(<PyObject*>self.context, &self.batches)
+
         elif iter_unit == TABLE_UNIT:
             self.cIterator = new CArrowTableIterator(<PyObject*>self.context, &self.batches)
         self.unit = iter_unit
```
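
Since `PyArrowIterator.__cinit__` now takes a third argument, every construction site must pass `use_dict_result` explicitly; the updated callers are in arrow_result.pyx, chunk_downloader.py, and test_unit_arrow_chunk_iterator.py below. `EmptyPyArrowIterator` drops its placeholder `__cinit__`/`__dealloc__` and is now constructed with no arguments.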

arrow_result.pyx

Lines changed: 10 additions & 6 deletions
```diff
@@ -6,6 +6,7 @@
 # cython: language_level=3
 
 from base64 import b64decode
+from libcpp cimport bool
 import io
 from logging import getLogger
 from .telemetry import TelemetryField
@@ -35,11 +36,14 @@ cdef class ArrowResult:
     object _chunk_downloader
     object _arrow_context
     str _iter_unit
+    object _use_dict_result
 
-    def __init__(self, raw_response, cursor, _chunk_downloader=None):
+
+    def __init__(self, raw_response, cursor, use_dict_result=False, _chunk_downloader=None):
         self._reset()
         self._cursor = cursor
         self._connection = cursor.connection
+        self._use_dict_result = use_dict_result
         self._chunk_info(raw_response, _chunk_downloader)
 
     def _chunk_info(self, data, _chunk_downloader=None):
@@ -53,10 +57,10 @@ cdef class ArrowResult:
         if rowset_b64:
             arrow_bytes = b64decode(rowset_b64)
             self._arrow_context = ArrowConverterContext(self._connection._session_parameters)
-            self._current_chunk_row = PyArrowIterator(io.BytesIO(arrow_bytes), self._arrow_context)
+            self._current_chunk_row = PyArrowIterator(io.BytesIO(arrow_bytes), self._arrow_context, self._use_dict_result)
         else:
             logger.debug("Data from first gs response is empty")
-            self._current_chunk_row = EmptyPyArrowIterator(None, None)
+            self._current_chunk_row = EmptyPyArrowIterator()
         self._iter_unit = EMPTY_UNIT
 
         if u'chunks' in data:
@@ -127,7 +131,7 @@ cdef class ArrowResult:
                         self._chunk_downloader._total_millis_parsing_chunks)
                     self._chunk_downloader = None
                     self._chunk_count = 0
-                    self._current_chunk_row = EmptyPyArrowIterator(None, None)
+                    self._current_chunk_row = EmptyPyArrowIterator()
                 is_done = True
 
         if is_done:
@@ -149,7 +153,7 @@ cdef class ArrowResult:
     def _reset(self):
         self.total_row_index = -1 # last fetched number of rows
         self._current_chunk_row_count = 0
-        self._current_chunk_row = EmptyPyArrowIterator(None, None)
+        self._current_chunk_row = EmptyPyArrowIterator()
        self._chunk_index = 0
 
         if hasattr(self, u'_chunk_count') and self._chunk_count > 0 and \
@@ -208,7 +212,7 @@ cdef class ArrowResult:
                         self._chunk_downloader._total_millis_parsing_chunks)
                     self._chunk_downloader = None
                     self._chunk_count = 0
-                    self._current_chunk_row = EmptyPyArrowIterator(None, None)
+                    self._current_chunk_row = EmptyPyArrowIterator()
         finally:
             if self._cursor._first_chunk_time:
                 logger.info("fetching data into pandas dataframe done")
```

chunk_downloader.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -249,7 +249,7 @@ def _fetch_chunk(self, url, headers):
         handler = JsonBinaryHandler(is_raw_binary_iterator=True,
                                     use_ijson=self._use_ijson) \
             if self._query_result_format == 'json' else \
-            ArrowBinaryHandler(self._cursor.description, self._connection)
+            ArrowBinaryHandler(self._cursor, self._connection)
 
         return self._connection.rest.fetch(
             u'get', url, headers,
@@ -316,8 +316,8 @@ def to_iterator(self, raw_data_fd, download_time):
 
 class ArrowBinaryHandler(RawBinaryDataHandler):
 
-    def __init__(self, meta, connection):
-        self._meta = meta
+    def __init__(self, cursor, connection):
+        self._cursor = cursor
         self._arrow_context = ArrowConverterContext(connection._session_parameters)
 
     """
@@ -326,5 +326,5 @@ def __init__(self, meta, connection):
     def to_iterator(self, raw_data_fd, download_time):
         from .arrow_iterator import PyArrowIterator
         gzip_decoder = GzipFile(fileobj=raw_data_fd, mode='r')
-        it = PyArrowIterator(gzip_decoder, self._arrow_context)
+        it = PyArrowIterator(gzip_decoder, self._arrow_context, self._cursor._use_dict_result)
         return it
```
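
Passing the whole cursor (rather than just `cursor.description`) to `ArrowBinaryHandler` is what lets each downloaded chunk's iterator honor the same `_use_dict_result` flag as the inline first chunk, keeping row shape consistent across chunk boundaries.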

cpp/ArrowIterator/CArrowChunkIterator.cpp

Lines changed: 24 additions & 6 deletions
```diff
@@ -37,7 +37,7 @@ PyObject* CArrowChunkIterator::next()
 
   if (m_rowIndexInBatch < m_rowCountInBatch)
   {
-    this->currentRowAsTuple();
+    this->createRowPyObject();
     if (py::checkPyError())
     {
       return nullptr;
@@ -60,7 +60,7 @@ PyObject* CArrowChunkIterator::next()
       logger.debug("Current batch index: %d, rows in current batch: %d",
                    m_currentBatchIndex, m_rowCountInBatch);
 
-      this->currentRowAsTuple();
+      this->createRowPyObject();
       if (py::checkPyError())
       {
         return nullptr;
@@ -74,7 +74,7 @@ PyObject* CArrowChunkIterator::next()
   return Py_None;
 }
 
-void CArrowChunkIterator::currentRowAsTuple()
+void CArrowChunkIterator::createRowPyObject()
 {
   m_latestReturnedRow.reset(PyTuple_New(m_columnCount));
   for (int i = 0; i < m_columnCount; i++)
@@ -91,13 +91,13 @@ void CArrowChunkIterator::initColumnConverters()
   m_currentBatchConverters.clear();
   std::shared_ptr<arrow::RecordBatch> currentBatch =
       (*m_cRecordBatches)[m_currentBatchIndex];
-  std::shared_ptr<arrow::Schema> schema = currentBatch->schema();
+  m_currentSchema = currentBatch->schema();
   for (int i = 0; i < currentBatch->num_columns(); i++)
   {
     std::shared_ptr<arrow::Array> columnArray = currentBatch->column(i);
-    std::shared_ptr<arrow::DataType> dt = schema->field(i)->type();
+    std::shared_ptr<arrow::DataType> dt = m_currentSchema->field(i)->type();
     std::shared_ptr<const arrow::KeyValueMetadata> metaData =
-        schema->field(i)->metadata();
+        m_currentSchema->field(i)->metadata();
     SnowflakeType::Type st = SnowflakeType::snowflakeTypeFromString(
         metaData->value(metaData->FindKey("logicalType")));
 
@@ -407,4 +407,22 @@ void CArrowChunkIterator::initColumnConverters()
   }
 }
 
+DictCArrowChunkIterator::DictCArrowChunkIterator(PyObject* context,
+    std::vector<std::shared_ptr<arrow::RecordBatch>> * batches)
+: CArrowChunkIterator(context, batches)
+{
+}
+
+void DictCArrowChunkIterator::createRowPyObject()
+{
+  m_latestReturnedRow.reset(PyDict_New());
+  for (int i = 0; i < m_currentSchema->num_fields(); i++)
+  {
+    PyDict_SetItemString(
+        m_latestReturnedRow.get(), m_currentSchema->field(i)->name().c_str(),
+        m_currentBatchConverters[i]->toPyObject(m_rowIndexInBatch));
+  }
+  return;
+}
+
 } // namespace sf
```
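
The base class's `next()` now calls the virtual `createRowPyObject()`, so the dict-producing subclass only overrides row construction. A minimal Python sketch of that dispatch pattern (hypothetical names; the real classes are the C++ ones above):

```python
class ChunkIterator:
    """Tuple-producing iterator, mirroring CArrowChunkIterator's shape."""

    def __init__(self, names, columns):
        self.names = names      # column names, as in m_currentSchema
        self.columns = columns  # one list of already-converted values per column
        self._row = -1          # like m_rowIndexInBatch

    def __iter__(self):
        return self

    def __next__(self):
        self._row += 1
        if self._row >= len(self.columns[0]):
            raise StopIteration
        return self._create_row(self._row)  # virtual call, like createRowPyObject()

    def _create_row(self, i):
        # base class: a tuple of all column values for row i
        return tuple(col[i] for col in self.columns)


class DictChunkIterator(ChunkIterator):
    """Dict-producing subclass, mirroring DictCArrowChunkIterator."""

    def _create_row(self, i):
        # override: a dict keyed by column name
        return {name: col[i] for name, col in zip(self.names, self.columns)}


for row in DictChunkIterator(['FOO', 'BAR'], [[1, 2], ['a', 'b']]):
    print(row)  # {'FOO': 1, 'BAR': 'a'}, then {'FOO': 2, 'BAR': 'b'}
```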

cpp/ArrowIterator/CArrowChunkIterator.hpp

Lines changed: 33 additions & 14 deletions
```diff
@@ -34,6 +34,24 @@ class CArrowChunkIterator : public CArrowIterator
    */
   PyObject* next() override;
 
+protected:
+  /**
+   * @return python object of tuple which is tuple of all row values
+   */
+  virtual void createRowPyObject();
+
+  /** pointer to the latest returned python tuple(row) result */
+  py::UniqueRef m_latestReturnedRow;
+
+  /** list of column converters*/
+  std::vector<std::shared_ptr<sf::IColumnConverter>> m_currentBatchConverters;
+
+  /** row index inside current record batch (start from 0) */
+  int m_rowIndexInBatch;
+
+  /** schema of current record batch */
+  std::shared_ptr<arrow::Schema> m_currentSchema;
+
 private:
   /** number of columns */
   int m_columnCount;
@@ -44,28 +62,29 @@ class CArrowChunkIterator : public CArrowIterator
   /** current index that iterator points to */
   int m_currentBatchIndex;
 
-  /** row index inside current record batch (start from 0) */
-  int m_rowIndexInBatch;
-
   /** total number of rows inside current record batch */
   int64_t m_rowCountInBatch;
 
-  /** pointer to the latest returned python tuple(row) result */
-  py::UniqueRef m_latestReturnedRow;
-
-  /** list of column converters*/
-  std::vector<std::shared_ptr<sf::IColumnConverter>> m_currentBatchConverters;
-
   /** arrow format convert context for the current session */
   PyObject* m_context;
 
-  /**
-   * @return python object of tuple which is tuple of all row values
-   */
-  void currentRowAsTuple();
-
   void initColumnConverters();
 };
+
+class DictCArrowChunkIterator : public CArrowChunkIterator
+{
+public:
+  DictCArrowChunkIterator(PyObject* context, std::vector<std::shared_ptr<arrow::RecordBatch>> *);
+
+  ~DictCArrowChunkIterator() = default;
+
+private:
+
+  void createRowPyObject() override;
+
+};
+
+
 }
 
 #endif // PC_ARROWCHUNKITERATOR_HPP
```
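
To support the override, the row-building hook and the state it touches (`m_latestReturnedRow`, the converter list, `m_rowIndexInBatch`, and the new `m_currentSchema`) move from `private` to `protected`. Note the doc comment on `createRowPyObject` still says "tuple", which is accurate only for the base-class implementation.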

cursor.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -82,7 +82,14 @@ class SnowflakeCursor(object):
         u(r'alter\s+session\s+set\s+(.*)=\'?([^\']+)\'?\s*;'),
         flags=re.IGNORECASE | re.MULTILINE | re.DOTALL)
 
-    def __init__(self, connection, json_result_class=JsonResult):
+    def __init__(self, connection, use_dict_result=False, json_result_class=JsonResult):
+        """
+        :param connection: connection created this cursor
+        :param use_dict_result: whether use dict result or not. This variable only applied to
+                                arrow result. When result in json, json_result_class will be
+                                honored
+        :param json_result_class: class that used in json result
+        """
         self._connection = connection
 
         self._errorhandler = Error.default_errorhandler
@@ -106,6 +113,7 @@ def __init__(self, connection, json_result_class=JsonResult):
         self._timezone = None
         self._binary_output_format = None
         self._result = None
+        self._use_dict_result = use_dict_result
         self._json_result_class = json_result_class
 
         self._arraysize = 1 # PEP-0249: defaults to 1
@@ -623,7 +631,7 @@ def _init_result_and_meta(self, data, use_ijson=False):
 
         if self._query_result_format == 'arrow':
             self.check_can_use_arrow_resultset()
-            self._result = ArrowResult(data, self)
+            self._result = ArrowResult(data, self, use_dict_result=self._use_dict_result)
         else:
             self._result = self._json_result_class(data, self, use_ijson)
 
@@ -944,4 +952,4 @@ class DictCursor(SnowflakeCursor):
     """
 
     def __init__(self, connection):
-        SnowflakeCursor.__init__(self, connection, DictJsonResult)
+        SnowflakeCursor.__init__(self, connection, use_dict_result=True, json_result_class=DictJsonResult)
```
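
With the new keyword in place, `DictCursor` gets dict rows for both result formats. In principle a cursor could also be constructed with the flag directly (a sketch against the new `__init__` signature above, with `cnx` an open connection; not necessarily an intended public entry point):

```python
from snowflake.connector.cursor import SnowflakeCursor

# Sketch only: dict rows apply to ARROW results; JSON still uses json_result_class.
cur = SnowflakeCursor(cnx, use_dict_result=True)
```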

test/test_arrow_result.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -8,6 +8,7 @@
 import random
 import pytest
 from datetime import datetime
+import snowflake.connector
 try:
     from snowflake.connector.arrow_iterator import PyArrowIterator
     no_arrow_iterator_ext = False
@@ -374,6 +375,26 @@ def test_select_with_large_resultset(conn_cnx):
     iterate_over_test_chunk("large_resultset", conn_cnx, sql_text, row_count, col_count)
 
 
+def test_dict_cursor(conn_cnx):
+    with conn_cnx() as cnx:
+        with cnx.cursor(snowflake.connector.DictCursor) as c:
+            c.execute("alter session set python_connector_query_result_format='ARROW'")
+
+            # first test small result generated by GS
+            ret = c.execute("select 1 as foo, 2 as bar").fetchone()
+            assert ret['FOO'] == 1
+            assert ret['BAR'] == 2
+
+            # test larger result set
+            row_index = 1
+            for row in c.execute("select row_number() over (order by val asc) as foo, "
+                                 "row_number() over (order by val asc) as bar "
+                                 "from (select seq4() as val from table(generator(rowcount=>10000)));"):
+                assert row['FOO'] == row_index
+                assert row['BAR'] == row_index
+                row_index += 1
+
+
 def get_random_seed():
     random.seed(datetime.now())
     return random.randint(0, 10000)
```
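
The two halves of `test_dict_cursor` appear designed to exercise both paths touched by this change: the one-row query is decoded from the inline base64 rowset of the first response (arrow_result.pyx), while the 10,000-row query is large enough to also drive the downloaded-chunk path (chunk_downloader.py).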

test/test_unit_arrow_chunk_iterator.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -538,7 +538,7 @@ def iterate_over_test_chunk(pyarrow_type, column_meta, source_data_generator, ex
     # seek stream to begnning so that we can read from stream
     stream.seek(0)
     context = ArrowConverterContext()
-    it = PyArrowIterator(stream, context)
+    it = PyArrowIterator(stream, context, False)
     it.init(ROW_UNIT)
 
     count = 0
```
