Skip to content

Commit 60e22d5

Browse files
committed
SNOW-773928: nanoarrow decimal support (#1538)
1 parent 59db9b8 commit 60e22d5

File tree

4 files changed

+171
-8
lines changed

4 files changed

+171
-8
lines changed

src/snowflake/connector/cpp/ArrowIterator/CArrowTableIterator.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ void CArrowTableIterator::convertScaledFixedNumberColumnToDecimalColumn_nanoarro
312312
// create new schema
313313
ArrowSchemaInit(newSchema);
314314
newSchema->flags &= (field->schema->flags & ARROW_FLAG_NULLABLE); // map to nullable()
315-
ArrowSchemaSetType(newSchema, NANOARROW_TYPE_DECIMAL128); // map to arrow:float64()
315+
ArrowSchemaSetTypeDecimal(newSchema, NANOARROW_TYPE_DECIMAL128, 38, scale);
316316
ArrowSchemaSetName(newSchema, field->schema->name);
317317

318318
ArrowError error;
@@ -325,14 +325,17 @@ void CArrowTableIterator::convertScaledFixedNumberColumnToDecimalColumn_nanoarro
325325
logger->error(__FILE__, __func__, __LINE__, errorInfo.c_str());
326326
PyErr_SetString(PyExc_Exception, errorInfo.c_str());
327327
}
328-
328+
ArrowArrayStartAppending(newArray);
329329
for(int64_t rowIdx = 0; rowIdx < columnArray->array->length; rowIdx++)
330330
{
331331
if(ArrowArrayViewIsNull(columnArray, rowIdx)) {
332332
ArrowArrayAppendNull(newArray, 1);
333333
} else {
334334
auto originalVal = ArrowArrayViewGetIntUnsafe(columnArray, rowIdx);
335-
// TODO: nanoarrow is missing appending a decimal value to array
335+
std::shared_ptr<ArrowDecimal> arrowDecimal = std::make_shared<ArrowDecimal>();
336+
ArrowDecimalInit(arrowDecimal.get(), 128, 38, scale);
337+
ArrowDecimalSetInt(arrowDecimal.get(), originalVal);
338+
ArrowArrayAppendDecimal(newArray, arrowDecimal.get());
336339
}
337340
}
338341
ArrowArrayFinishBuildingDefault(newArray, &error);

src/snowflake/connector/cpp/ArrowIterator/nanoarrow.h

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,13 @@ typedef int ArrowErrorCode;
204204
#define NANOARROW_RETURN_NOT_OK(EXPR) \
205205
_NANOARROW_RETURN_NOT_OK_IMPL(_NANOARROW_MAKE_NAME(errno_status_, __COUNTER__), EXPR)
206206

207+
/// \brief Detect the byte order of the host at runtime
///
/// Returns 1 when the first byte in memory of a 32-bit integer holding the
/// value 1 is non-zero (little endian) and 0 otherwise.
///
/// Declared static inline, like the other helpers in this header, so that
/// translation units that never call it do not emit an unused-function warning.
static inline char _ArrowIsLittleEndian(void) {
  const uint32_t probe = 1;
  char first_byte;
  // Copy instead of casting the pointer to inspect the first byte in memory
  memcpy(&first_byte, &probe, sizeof(char));
  return first_byte;
}
213+
207214
/// \brief Arrow type enumerator
208215
/// \ingroup nanoarrow-utils
209216
///
@@ -586,6 +593,90 @@ struct ArrowArrayPrivateData {
586593
int8_t union_type_id_is_child_index;
587594
};
588595

596+
/// \brief A representation of a fixed-precision decimal number
597+
/// \ingroup nanoarrow-utils
598+
///
599+
/// This structure should be initialized with ArrowDecimalInit() once and
600+
/// values set using ArrowDecimalSetInt() or
601+
/// ArrowDecimalSetBytes().
602+
struct ArrowDecimal {
603+
/// \brief Storage for the value as 64-bit words in native-endian order; n_words of them are in use
604+
uint64_t words[4];
605+
606+
/// \brief The number of significant digits this decimal number can represent
607+
int32_t precision;
608+
609+
/// \brief The number of digits after the decimal point. This can be negative.
610+
int32_t scale;
611+
612+
/// \brief The number of entries of words in use, derived from the bitwidth by ArrowDecimalInit()
613+
int n_words;
614+
615+
/// \brief Index into words of the most significant word, set by ArrowDecimalInit()
616+
int high_word_index;
617+
618+
/// \brief Index into words of the least significant word, set by ArrowDecimalInit()
619+
int low_word_index;
620+
};
621+
622+
/// \brief Initialize a decimal with a given set of type parameters
623+
/// \ingroup nanoarrow-utils
624+
static inline void ArrowDecimalInit(struct ArrowDecimal* decimal, int32_t bitwidth,
625+
int32_t precision, int32_t scale) {
626+
memset(decimal->words, 0, sizeof(decimal->words));
627+
decimal->precision = precision;
628+
decimal->scale = scale;
629+
decimal->n_words = bitwidth / 8 / sizeof(uint64_t);
630+
631+
if (_ArrowIsLittleEndian()) {
632+
decimal->low_word_index = 0;
633+
decimal->high_word_index = decimal->n_words - 1;
634+
} else {
635+
decimal->low_word_index = decimal->n_words - 1;
636+
decimal->high_word_index = 0;
637+
}
638+
}
639+
640+
/// \brief Get a signed integer value of a sufficiently small ArrowDecimal
641+
///
642+
/// This does not check if the decimal's precision sufficiently small to fit
643+
/// within the signed 64-bit integer range (A precision less than or equal
644+
/// to 18 is sufficiently small).
645+
static inline int64_t ArrowDecimalGetIntUnsafe(struct ArrowDecimal* decimal) {
646+
return (int64_t)decimal->words[decimal->low_word_index];
647+
}
648+
649+
/// \brief Copy the bytes of this decimal into a sufficiently large buffer
650+
/// \ingroup nanoarrow-utils
651+
static inline void ArrowDecimalGetBytes(struct ArrowDecimal* decimal, uint8_t* out) {
652+
memcpy(out, decimal->words, decimal->n_words * sizeof(uint64_t));
653+
}
654+
655+
/// \brief Returns 1 if the value represented by decimal is >= 0 or -1 otherwise
656+
/// \ingroup nanoarrow-utils
657+
static inline int64_t ArrowDecimalSign(struct ArrowDecimal* decimal) {
658+
return 1 | ((int64_t)(decimal->words[decimal->high_word_index]) >> 63);
659+
}
660+
661+
/// \brief Sets the integer value of this decimal
662+
/// \ingroup nanoarrow-utils
663+
static inline void ArrowDecimalSetInt(struct ArrowDecimal* decimal, int64_t value) {
664+
if (value < 0) {
665+
memset(decimal->words, 0xff, decimal->n_words * sizeof(uint64_t));
666+
} else {
667+
memset(decimal->words, 0, decimal->n_words * sizeof(uint64_t));
668+
}
669+
670+
decimal->words[decimal->low_word_index] = value;
671+
}
672+
673+
/// \brief Copy bytes from a buffer into this decimal
674+
/// \ingroup nanoarrow-utils
675+
static inline void ArrowDecimalSetBytes(struct ArrowDecimal* decimal,
676+
const uint8_t* value) {
677+
memcpy(decimal->words, value, decimal->n_words * sizeof(uint64_t));
678+
}
679+
589680
#ifdef __cplusplus
590681
}
591682
#endif
@@ -1417,12 +1508,20 @@ static inline ArrowErrorCode ArrowArrayAppendBytes(struct ArrowArray* array,
14171508
struct ArrowBufferView value);
14181509

14191510
/// \brief Append a string value to an array
1511+
///
14201512
/// Returns NANOARROW_OK if value can be exactly represented by
14211513
/// the underlying storage type or EINVAL otherwise (e.g.,
14221514
/// the underlying array is not a string or large string array).
14231515
static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array,
14241516
struct ArrowStringView value);
14251517

1518+
/// \brief Append a decimal value to an array
1519+
///
1520+
/// Returns NANOARROW_OK if array is a decimal array with the appropriate
1521+
/// bitwidth or EINVAL otherwise.
/// EINVAL is returned when the array's storage type is neither decimal128 nor
/// decimal256, or when value->n_words is not 2 (decimal128) or 4 (decimal256).
/// An error reported while appending to the data buffer or the validity
/// bitmap is returned to the caller as-is.
1522+
static inline ArrowErrorCode ArrowArrayAppendDecimal(struct ArrowArray* array,
1523+
struct ArrowDecimal* value);
1524+
14261525
/// \brief Finish a nested array element
14271526
///
14281527
/// Appends a non-null element to the array based on the first child's current
@@ -1559,6 +1658,14 @@ static inline struct ArrowStringView ArrowArrayViewGetStringUnsafe(
15591658
static inline struct ArrowBufferView ArrowArrayViewGetBytesUnsafe(
15601659
struct ArrowArrayView* array_view, int64_t i);
15611660

1661+
/// \brief Get an element in an ArrowArrayView as an ArrowDecimal
1662+
///
1663+
/// This function does not check for null values. The out parameter must
1664+
/// be initialized with ArrowDecimalInit() with the proper parameters for this
1665+
/// type before calling this for the first time.
/// The array's offset is added to i before the value is read and i is not
/// bounds checked. If the view's storage type is neither decimal128 nor
/// decimal256, the words of out are set to zero.
1666+
static inline void ArrowArrayViewGetDecimalUnsafe(struct ArrowArrayView* array_view,
1667+
int64_t i, struct ArrowDecimal* out);
1668+
15621669
/// @}
15631670

15641671
/// \defgroup nanoarrow-basic-array-stream Basic ArrowArrayStream implementation
@@ -2622,6 +2729,41 @@ static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array,
26222729
}
26232730
}
26242731

2732+
// Append one decimal value to an array that is being built.
//
// The array's storage type selects the number of 64-bit words expected in
// value (2 for decimal128, 4 for decimal256). Returns EINVAL when the storage
// type is not a decimal type or when value->n_words does not match, propagates
// any error from the buffer or bitmap append, and returns NANOARROW_OK after
// incrementing the array length.
static inline ArrowErrorCode ArrowArrayAppendDecimal(struct ArrowArray* array,
                                                     struct ArrowDecimal* value) {
  struct ArrowArrayPrivateData* builder_state =
      (struct ArrowArrayPrivateData*)array->private_data;
  struct ArrowBuffer* values = ArrowArrayBuffer(array, 1);

  int expected_words;
  if (builder_state->storage_type == NANOARROW_TYPE_DECIMAL128) {
    expected_words = 2;
  } else if (builder_state->storage_type == NANOARROW_TYPE_DECIMAL256) {
    expected_words = 4;
  } else {
    return EINVAL;
  }

  if (value->n_words != expected_words) {
    return EINVAL;
  }

  NANOARROW_RETURN_NOT_OK(
      ArrowBufferAppend(values, value->words, expected_words * sizeof(uint64_t)));

  // A validity bit is only recorded when the bitmap buffer already holds data
  if (builder_state->bitmap.buffer.data != NULL) {
    NANOARROW_RETURN_NOT_OK(ArrowBitmapAppend(ArrowArrayValidityBitmap(array), 1, 1));
  }

  array->length++;
  return NANOARROW_OK;
}
2766+
26252767
static inline ArrowErrorCode ArrowArrayFinishElement(struct ArrowArray* array) {
26262768
struct ArrowArrayPrivateData* private_data =
26272769
(struct ArrowArrayPrivateData*)array->private_data;
@@ -2931,6 +3073,23 @@ static inline struct ArrowBufferView ArrowArrayViewGetBytesUnsafe(
29313073
return view;
29323074
}
29333075

3076+
static inline void ArrowArrayViewGetDecimalUnsafe(struct ArrowArrayView* array_view,
3077+
int64_t i, struct ArrowDecimal* out) {
3078+
i += array_view->array->offset;
3079+
const uint8_t* data_view = array_view->buffer_views[1].data.as_uint8;
3080+
switch (array_view->storage_type) {
3081+
case NANOARROW_TYPE_DECIMAL128:
3082+
ArrowDecimalSetBytes(out, data_view + (i * 16));
3083+
break;
3084+
case NANOARROW_TYPE_DECIMAL256:
3085+
ArrowDecimalSetBytes(out, data_view + (i * 32));
3086+
break;
3087+
default:
3088+
memset(out->words, 0, sizeof(out->words));
3089+
break;
3090+
}
3091+
}
3092+
29343093
#ifdef __cplusplus
29353094
}
29363095
#endif

src/snowflake/connector/cpp/ArrowIterator/nanoarrow.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
//
2+
// Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved.
3+
//
4+
15
// Licensed to the Apache Software Foundation (ASF) under one
26
// or more contributor license agreements. See the NOTICE file
37
// distributed with this work for additional information
@@ -278,10 +282,10 @@ class VectorArrayStream : public EmptyArrayStream {
278282

279283
protected:
280284
VectorArrayStream(struct ArrowSchema* schema, std::vector<UniqueArray> arrays)
281-
: EmptyArrayStream(schema), offset_(0), arrays_(std::move(arrays)) {}
285+
: EmptyArrayStream(schema), arrays_(std::move(arrays)), offset_(0) {}
282286

283287
int get_next(struct ArrowArray* array) {
284-
if (offset_ < arrays_.size()) {
288+
if (offset_ < static_cast<int64_t>(arrays_.size())) {
285289
arrays_[offset_++].move(array);
286290
} else {
287291
array->release = nullptr;

test/integ/pandas/test_arrow_pandas.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -936,7 +936,6 @@ def test_query_resultscan_combos(conn_cnx, query_format, resultscan_format):
936936
],
937937
)
938938
def test_number_fetchall_retrieve_type(conn_cnx, use_decimal, expected):
939-
pytest.skip("missing decimal")
940939
with conn_cnx(arrow_number_to_decimal=use_decimal) as con:
941940
with con.cursor() as cur:
942941
cur.execute("SELECT 12345600.87654301::NUMBER(18, 8) a")
@@ -956,7 +955,6 @@ def test_number_fetchall_retrieve_type(conn_cnx, use_decimal, expected):
956955
],
957956
)
958957
def test_number_fetchbatches_retrieve_type(conn_cnx, use_decimal: bool, expected: type):
959-
pytest.skip("missing decimal")
960958
with conn_cnx(arrow_number_to_decimal=use_decimal) as con:
961959
with con.cursor() as cur:
962960
cur.execute("SELECT 12345600.87654301::NUMBER(18, 8) a")
@@ -1249,7 +1247,6 @@ def test_timestamp_tz(conn_cnx):
12491247

12501248

12511249
def test_arrow_number_to_decimal(conn_cnx):
1252-
pytest.skip("missing decimal")
12531250
with conn_cnx(
12541251
session_parameters={
12551252
PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "arrow_force"

0 commit comments

Comments
 (0)