Skip to content

Commit 202ba7f

Browse files
sfc-gh-stakedaankit-bhatnagar167
authored and committed
SNOW-87349: Output cython code-generated .cpp code to build directory
1 parent 28dc2ed commit 202ba7f

File tree

8 files changed

+146
-23
lines changed

8 files changed

+146
-23
lines changed

arrow_iterator.pyx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
#
2+
# Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved.
3+
#
4+
15
# distutils: language = c++
26

37
from cpython.ref cimport PyObject

cpp/ArrowIterator/CArrowChunkIterator.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
#include "CArrowChunkIterator.hpp"
55
#include "IntConverter.hpp"
66
#include "StringConverter.hpp"
7+
#include "FloatConverter.hpp"
78

8-
sf::CArrowChunkIterator::CArrowChunkIterator()
9+
sf::CArrowChunkIterator::CArrowChunkIterator() : m_latestReturnedRow(nullptr)
910
{
1011
this->reset();
1112
}
@@ -25,11 +26,15 @@ void sf::CArrowChunkIterator::reset()
2526
m_currentBatchIndex = -1;
2627
m_rowIndexInBatch = -1;
2728
m_rowCountInBatch = 0;
29+
Py_XDECREF(m_latestReturnedRow);
30+
m_latestReturnedRow = nullptr;
2831
}
2932

3033
PyObject * sf::CArrowChunkIterator::nextRow()
3134
{
3235
m_rowIndexInBatch ++;
36+
Py_XDECREF(m_latestReturnedRow);
37+
m_latestReturnedRow = nullptr;
3338

3439
if (m_rowIndexInBatch < m_rowCountInBatch)
3540
{
@@ -57,7 +62,7 @@ PyObject * sf::CArrowChunkIterator::currentRowAsTuple()
5762
{
5863
PyTuple_SET_ITEM(tuple, i, m_currentBatchConverters[i]->toPyObject(m_rowIndexInBatch));
5964
}
60-
return tuple;
65+
return m_latestReturnedRow = tuple;
6166
}
6267

6368
void sf::CArrowChunkIterator::initColumnConverters()
@@ -97,6 +102,11 @@ void sf::CArrowChunkIterator::initColumnConverters()
97102
std::make_shared<sf::StringConverter>(columnArray.get()));
98103
break;
99104

105+
case arrow::Type::type::DOUBLE:
106+
m_currentBatchConverters.push_back(
107+
std::make_shared<sf::FloatConverter>(columnArray.get()));
108+
break;
109+
100110
default:
101111
break;
102112
}

cpp/ArrowIterator/CArrowChunkIterator.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ class CArrowChunkIterator
2828
/**
2929
* Destructor
3030
*/
31-
~CArrowChunkIterator() = default;
31+
~CArrowChunkIterator()
32+
{
33+
Py_XDECREF(m_latestReturnedRow);
34+
}
3235

3336
/**
3437
* Add Arrow RecordBatch to current chunk
@@ -59,6 +62,9 @@ class CArrowChunkIterator
5962

6063
/** total number of rows inside current record batch */
6164
int64_t m_rowCountInBatch;
65+
66+
/** pointer to the latest returned python tuple(row) result */
67+
PyObject* m_latestReturnedRow;
6268

6369
/** list of column converters*/
6470
std::vector<std::shared_ptr<sf::IColumnConverter>> m_currentBatchConverters;
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/*
2+
* Copyright (c) 2013-2019 Snowflake Computing
3+
*/
4+
#include "FloatConverter.hpp"
5+
6+
sf::FloatConverter::FloatConverter(arrow::Array* array)
7+
{
8+
/** snowflake float is 64-bit precision, which refers to double here */
9+
m_array = dynamic_cast<arrow::DoubleArray*>(array);
10+
}
11+
12+
PyObject* sf::FloatConverter::toPyObject(int64_t rowIndex)
13+
{
14+
return (m_array->IsValid(rowIndex)) ? PyFloat_FromDouble(m_array->Value(rowIndex)) : Py_None;
15+
}
16+
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (c) 2013-2019 Snowflake Computing
3+
*/
4+
#ifndef PC_FLOATCONVERTER_HPP
5+
#define PC_FLOATCONVERTER_HPP
6+
7+
#include "IColumnConverter.hpp"
8+
9+
namespace sf
10+
{
11+
12+
class FloatConverter : public IColumnConverter
13+
{
14+
public:
15+
explicit FloatConverter(arrow::Array* array);
16+
17+
PyObject* toPyObject(long rowIndex) override;
18+
19+
private:
20+
arrow::DoubleArray* m_array;
21+
};
22+
}
23+
24+
#endif

cpp/ArrowIterator/IntConverter.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
#include "IColumnConverter.hpp"
88

9-
109
namespace sf
1110
{
1211

setup.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from os import path
88
import os
99
import sys
10+
from sys import platform
1011
from shutil import copy
1112
import glob
1213

@@ -38,20 +39,16 @@
3839

3940
if isBuildExtEnabled == 'true':
4041
from Cython.Distutils import build_ext
42+
from Cython.Build import cythonize
4143
import os
4244
import pyarrow
4345

44-
extensions = [
45-
Extension(name='snowflake.connector.arrow_result', sources=['arrow_result.pyx']),
46-
Extension(name='snowflake.connector.arrow_iterator', sources=['arrow_iterator.pyx',
47-
'cpp/ArrowIterator/CArrowChunkIterator.cpp',
48-
'cpp/ArrowIterator/IntConverter.cpp',
49-
'cpp/ArrowIterator/StringConverter.cpp'],
50-
extra_compile_args=['-std=c++11'],
51-
extra_link_args=['-Wl,-rpath,$ORIGIN'],
52-
language='c++'
53-
),
54-
]
46+
extensions = cythonize(
47+
[
48+
Extension(name='snowflake.connector.arrow_iterator', sources=['arrow_iterator.pyx']),
49+
Extension(name='snowflake.connector.arrow_result', sources=['arrow_result.pyx'])
50+
],
51+
build_dir=os.path.join('build', 'cython'))
5552

5653
class MyBuildExt(build_ext):
5754

@@ -61,36 +58,50 @@ def build_extension(self, ext):
6158
if ext.name == 'snowflake.connector.arrow_iterator':
6259
self._copy_arrow_lib()
6360

64-
ext.include_dirs.append(self._get_arrow_include_dir())
61+
ext.sources += ['cpp/ArrowIterator/CArrowChunkIterator.cpp',
62+
'cpp/ArrowIterator/FloatConverter.cpp',
63+
'cpp/ArrowIterator/IntConverter.cpp',
64+
'cpp/ArrowIterator/StringConverter.cpp']
65+
ext.include_dirs.append('cpp/ArrowIterator/')
66+
ext.include_dirs.append(pyarrow.get_include())
67+
68+
ext.extra_compile_args.append('-std=c++11')
69+
6570
ext.library_dirs.append(os.path.join(current_dir, self.build_lib, 'snowflake', 'connector'))
6671
ext.extra_link_args += self._get_arrow_lib_as_linker_input()
72+
ext.extra_link_args += ['-Wl,-rpath,$ORIGIN']
6773

6874
build_ext.build_extension(self, ext)
6975

70-
def _get_arrow_include_dir(self):
71-
return pyarrow.get_include()
72-
7376
def _get_arrow_lib_dir(self):
7477
return pyarrow.get_library_dirs()[0]
7578

7679
def _copy_arrow_lib(self):
7780
arrow_lib = pyarrow.get_libraries() + \
7881
['arrow_flight', 'arrow_boost_regex', 'arrow_boost_system', 'arrow_boost_filesystem']
7982
for lib in arrow_lib:
80-
lib_pattern = '{}/lib{}.so*'.format(self._get_arrow_lib_dir(), lib)
83+
lib_pattern = self._get_pyarrow_lib_pattern(lib)
8184
source = glob.glob(lib_pattern)[0]
8285
copy(source, os.path.join(self.build_lib, 'snowflake', 'connector'))
8386

8487
def _get_arrow_lib_as_linker_input(self):
8588
arrow_lib = pyarrow.get_libraries()
8689
link_lib = []
8790
for lib in arrow_lib:
88-
lib_pattern = '{}/lib{}.so*'.format(self._get_arrow_lib_dir(), lib)
91+
lib_pattern = self._get_pyarrow_lib_pattern(lib)
8992
source = glob.glob(lib_pattern)[0]
9093
link_lib.append(source)
9194

9295
return link_lib
9396

97+
def _get_pyarrow_lib_pattern(self, lib_name):
98+
if platform.startswith('linux'):
99+
return '{}/lib{}.so*'.format(self._get_arrow_lib_dir(), lib_name)
100+
elif platform == 'darwin':
101+
return '{}/lib{}*dylib'.format(self._get_arrow_lib_dir(), lib_name)
102+
else:
103+
raise RuntimeError('Building on platform {} is not supported yet.'.format(platform))
104+
94105
cmd_class = {
95106
"build_ext": MyBuildExt
96107
}
@@ -165,8 +176,8 @@ def _get_arrow_lib_as_linker_input(self):
165176
'keyring!=16.1.0'
166177
],
167178
"arrow-result": [
168-
'pyarrow>=0.13.0;python_version>"3.4"',
169-
'pyarrow>=0.13.0;python_version<"3.0"'
179+
'pyarrow>=0.14.0;python_version>"3.4"',
180+
'pyarrow>=0.14.0;python_version<"3.0"'
170181
]
171182
},
172183

test/test_unit_arrow_chunk_iterator.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,56 @@ def test_iterate_over_int64_chunk():
124124
assert count == 100
125125
break
126126

127+
128+
@pytest.mark.skip(
129+
reason="Cython is not enabled in build env")
130+
def test_iterate_over_float_chunk():
131+
stream = BytesIO()
132+
field_foo = pyarrow.field("column_foo", pyarrow.float64(), True)
133+
field_bar = pyarrow.field("column_bar", pyarrow.float64(), True)
134+
schema = pyarrow.schema([field_foo, field_bar])
135+
column_meta = [
136+
("column_foo", "FLOAT", None, 0, 0, 0, 0),
137+
("column_bar", "FLOAT", None, 0, 0, 0, 0)
138+
]
139+
140+
column_size = 2
141+
batch_row_count = 10
142+
batch_count = 10
143+
expected_data = []
144+
writer = RecordBatchStreamWriter(stream, schema)
145+
146+
for i in range(batch_count):
147+
column_arrays = []
148+
py_arrays = []
149+
for j in range(column_size):
150+
column_data = []
151+
for k in range(batch_row_count):
152+
data = None if bool(random.getrandbits(1)) else random.uniform(-100.0, 100.0)
153+
column_data.append(data)
154+
column_arrays.append(column_data)
155+
py_arrays.append(pyarrow.array(column_data))
156+
157+
expected_data.append(column_arrays)
158+
rb = RecordBatch.from_arrays(py_arrays, ["column_foo", "column_bar"])
159+
writer.write_batch(rb)
160+
161+
writer.close()
162+
163+
# seek stream to beginning so that we can read from stream
164+
stream.seek(0)
165+
reader = RecordBatchStreamReader(stream)
166+
it = PyArrowChunkIterator()
167+
for rb in reader:
168+
it.add_record_batch(rb)
169+
170+
count = 0
171+
while True:
172+
try:
173+
val = next(it)
174+
assert val[0] == expected_data[int(count / 10)][0][count % 10]
175+
assert val[1] == expected_data[int(count / 10)][1][count % 10]
176+
count += 1
177+
except StopIteration:
178+
assert count == 100
179+
break

0 commit comments

Comments
 (0)