Skip to content

Commit 9e1dd4f

Browse files
sfc-gh-tkissingersfc-gh-pczajka
authored andcommitted
SNOW-1915469 Basic support for DECFLOAT type (#2167)
1 parent de3a4c1 commit 9e1dd4f

File tree

11 files changed

+250
-3
lines changed

11 files changed

+250
-3
lines changed

DESCRIPTION.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
4646
- Added a feature to verify if the connection is still good enough to send queries over.
4747
- Added support for base64-encoded DER private key strings in the `private_key` authentication type.
4848

49-
- v3.12.5(TBD)
50-
- Added a feature to limit the sizes of IO-bound ThreadPoolExecutors during PUT and GET commands.
51-
5249
- v3.12.4(December 3,2024)
5350
- Fixed a bug where multipart uploads to Azure would be missing their MD5 hashes.
5451
- Fixed a bug where OpenTelemetry header injection would sometimes cause Exceptions to be thrown.

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def build_extension(self, ext):
101101
"CArrowIterator.cpp",
102102
"CArrowTableIterator.cpp",
103103
"DateConverter.cpp",
104+
"DecFloatConverter.cpp",
104105
"DecimalConverter.cpp",
105106
"FixedSizeListConverter.cpp",
106107
"FloatConverter.cpp",

src/snowflake/connector/arrow_context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,11 @@ def DECIMAL128_to_decimal(self, int128_bytes: bytes, scale: int) -> decimal.Deci
159159
digits = [int(digit) for digit in str(int128) if digit != "-"]
160160
sign = int128 < 0
161161
return decimal.Decimal((sign, digits, -scale))
162+
163+
def DECFLOAT_to_decimal(self, exponent: int, significand: bytes) -> decimal.Decimal:
164+
# significand is two's complement big endian.
165+
significand = int.from_bytes(significand, byteorder="big", signed=True)
166+
return decimal.Decimal(significand).scaleb(exponent)
167+
168+
def DECFLOAT_to_numpy_float64(self, exponent: int, significand: bytes) -> float64:
169+
return numpy.float64(self.DECFLOAT_to_decimal(exponent, significand))

src/snowflake/connector/converter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ def conv(value: str) -> int64:
203203

204204
return conv
205205

206+
def _DECFLOAT_numpy_to_python(self, ctx: dict[str, Any]) -> Callable:
207+
return numpy.float64
208+
209+
def _DECFLOAT_to_python(self, ctx: dict[str, Any]) -> Callable:
210+
return decimal.Decimal
211+
206212
def _REAL_to_python(self, _: dict[str, str | None] | dict[str, str]) -> Callable:
207213
return float
208214

src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "BinaryConverter.hpp"
1313
#include "BooleanConverter.hpp"
1414
#include "DateConverter.hpp"
15+
#include "DecFloatConverter.hpp"
1516
#include "DecimalConverter.hpp"
1617
#include "FixedSizeListConverter.hpp"
1718
#include "FloatConverter.hpp"
@@ -471,6 +472,12 @@ std::shared_ptr<sf::IColumnConverter> getConverterFromSchema(
471472
break;
472473
}
473474

475+
case SnowflakeType::Type::DECFLOAT: {
476+
converter = std::make_shared<sf::DecFloatConverter>(*array, schemaView,
477+
*context, useNumpy);
478+
break;
479+
}
480+
474481
default: {
475482
std::string errorInfo = Logger::formatString(
476483
"[Snowflake Exception] unknown snowflake data type : %d", st);
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
2+
//
3+
// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
4+
//
5+
6+
#include "DecFloatConverter.hpp"
7+
8+
#include <cstring>
9+
#include <memory>
10+
11+
#include "Python/Helpers.hpp"
12+
13+
namespace sf {
14+
15+
Logger* DecFloatConverter::logger =
16+
new Logger("snowflake.connector.DecFloatConverter");
17+
18+
const std::string DecFloatConverter::FIELD_NAME_EXPONENT = "exponent";
19+
const std::string DecFloatConverter::FIELD_NAME_SIGNIFICAND = "significand";
20+
21+
DecFloatConverter::DecFloatConverter(ArrowArrayView& array,
22+
ArrowSchemaView& schema, PyObject& context,
23+
bool useNumpy)
24+
: m_context(context),
25+
m_array(array),
26+
m_exponent(nullptr),
27+
m_significand(nullptr),
28+
m_useNumpy(useNumpy) {
29+
if (schema.schema->n_children != 2) {
30+
std::string errorInfo = Logger::formatString(
31+
"[Snowflake Exception] arrow schema field number does not match, "
32+
"expected 2 but got %d instead",
33+
schema.schema->n_children);
34+
logger->error(__FILE__, __func__, __LINE__, errorInfo.c_str());
35+
PyErr_SetString(PyExc_Exception, errorInfo.c_str());
36+
return;
37+
}
38+
for (int i = 0; i < schema.schema->n_children; i += 1) {
39+
ArrowSchema* c_schema = schema.schema->children[i];
40+
if (std::strcmp(c_schema->name,
41+
DecFloatConverter::FIELD_NAME_EXPONENT.c_str()) == 0) {
42+
m_exponent = m_array.children[i];
43+
} else if (std::strcmp(c_schema->name,
44+
DecFloatConverter::FIELD_NAME_SIGNIFICAND.c_str()) ==
45+
0) {
46+
m_significand = m_array.children[i];
47+
}
48+
}
49+
if (!m_exponent || !m_significand) {
50+
std::string errorInfo = Logger::formatString(
51+
"[Snowflake Exception] arrow schema field names do not match, "
52+
"expected %s and %s, but got %s and %s instead",
53+
DecFloatConverter::FIELD_NAME_EXPONENT.c_str(),
54+
DecFloatConverter::FIELD_NAME_SIGNIFICAND.c_str(),
55+
schema.schema->children[0]->name, schema.schema->children[1]->name);
56+
logger->error(__FILE__, __func__, __LINE__, errorInfo.c_str());
57+
PyErr_SetString(PyExc_Exception, errorInfo.c_str());
58+
return;
59+
}
60+
}
61+
62+
PyObject* DecFloatConverter::toPyObject(int64_t rowIndex) const {
63+
if (ArrowArrayViewIsNull(&m_array, rowIndex)) {
64+
Py_RETURN_NONE;
65+
}
66+
int64_t exponent = ArrowArrayViewGetIntUnsafe(m_exponent, rowIndex);
67+
ArrowStringView stringView =
68+
ArrowArrayViewGetStringUnsafe(m_significand, rowIndex);
69+
if (stringView.size_bytes > 16) {
70+
std::string errorInfo = Logger::formatString(
71+
"[Snowflake Exception] only precisions up to 38 supported. "
72+
"Please update to a newer version of the connector.");
73+
logger->error(__FILE__, __func__, __LINE__, errorInfo.c_str());
74+
PyErr_SetString(PyExc_Exception, errorInfo.c_str());
75+
return nullptr;
76+
}
77+
PyObject* significand =
78+
PyBytes_FromStringAndSize(stringView.data, stringView.size_bytes);
79+
80+
PyObject* result = PyObject_CallMethod(
81+
&m_context,
82+
m_useNumpy ? "DECFLOAT_to_numpy_float64" : "DECFLOAT_to_decimal", "iS",
83+
exponent, significand);
84+
Py_XDECREF(significand);
85+
return result;
86+
}
87+
} // namespace sf
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
2+
//
3+
// Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
4+
//
5+
6+
#ifndef PC_DECFLOATCONVERTER_HPP
7+
#define PC_DECFLOATCONVERTER_HPP
8+
9+
#include <memory>
10+
11+
#include "IColumnConverter.hpp"
12+
#include "logging.hpp"
13+
#include "nanoarrow.h"
14+
15+
namespace sf {
16+
17+
class DecFloatConverter : public IColumnConverter {
18+
public:
19+
const static std::string FIELD_NAME_EXPONENT;
20+
const static std::string FIELD_NAME_SIGNIFICAND;
21+
22+
explicit DecFloatConverter(ArrowArrayView& array, ArrowSchemaView& schema,
23+
PyObject& context, bool useNumpy);
24+
25+
PyObject* toPyObject(int64_t rowIndex) const override;
26+
27+
private:
28+
PyObject& m_context;
29+
ArrowArrayView& m_array;
30+
ArrowArrayView* m_exponent;
31+
ArrowArrayView* m_significand;
32+
bool m_useNumpy;
33+
34+
static Logger* logger;
35+
};
36+
37+
} // namespace sf
38+
39+
#endif // PC_DECFLOATCONVERTER_HPP

src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ std::unordered_map<std::string, SnowflakeType::Type>
1717
{"DOUBLE PRECISION", SnowflakeType::Type::REAL},
1818
{"DOUBLE", SnowflakeType::Type::REAL},
1919
{"FIXED", SnowflakeType::Type::FIXED},
20+
{"DECFLOAT", SnowflakeType::Type::DECFLOAT},
2021
{"FLOAT", SnowflakeType::Type::REAL},
2122
{"MAP", SnowflakeType::Type::MAP},
2223
{"OBJECT", SnowflakeType::Type::OBJECT},

src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class SnowflakeType {
3333
VARIANT = 15,
3434
VECTOR = 16,
3535
MAP = 17,
36+
DECFLOAT = 18,
3637
};
3738

3839
static SnowflakeType::Type snowflakeTypeFromString(std::string str) {

test/integ/test_decfloat.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
4+
#
5+
6+
from __future__ import annotations
7+
8+
import decimal
9+
from decimal import Decimal
10+
11+
import numpy
12+
13+
import snowflake.connector
14+
15+
16+
def test_decfloat_bindings(conn_cnx):
17+
# set required decimal precision
18+
decimal.getcontext().prec = 38
19+
original_style = snowflake.connector.paramstyle
20+
snowflake.connector.paramstyle = "qmark"
21+
try:
22+
with conn_cnx() as cnx:
23+
# test decfloat bindings
24+
ret = (
25+
cnx.cursor()
26+
.execute("select ?", [("DECFLOAT", Decimal("-1234e4000"))])
27+
.fetchone()
28+
)
29+
assert isinstance(ret[0], Decimal)
30+
assert ret[0] == Decimal("-1234e4000")
31+
ret = cnx.cursor().execute("select ?", [("DECFLOAT", -1e3)]).fetchone()
32+
assert isinstance(ret[0], Decimal)
33+
assert ret[0] == Decimal("-1e3")
34+
# test 38 digits
35+
ret = (
36+
cnx.cursor()
37+
.execute(
38+
"select ?",
39+
[("DECFLOAT", Decimal("12345678901234567890123456789012345678"))],
40+
)
41+
.fetchone()
42+
)
43+
assert isinstance(ret[0], Decimal)
44+
assert ret[0] == Decimal("12345678901234567890123456789012345678")
45+
# test w/o explicit type specification
46+
ret = cnx.cursor().execute("select ?", [-1e3]).fetchone()
47+
assert isinstance(ret[0], float)
48+
ret = cnx.cursor().execute("select ?", [Decimal("-1e3")]).fetchone()
49+
assert isinstance(ret[0], int)
50+
finally:
51+
snowflake.connector.paramstyle = original_style
52+
53+
54+
def test_decfloat_from_compiler(conn_cnx):
55+
# set required decimal precision
56+
decimal.getcontext().prec = 38
57+
# test both result formats
58+
for fmt in ["json", "arrow"]:
59+
with conn_cnx(
60+
session_parameters={
61+
"PYTHON_CONNECTOR_QUERY_RESULT_FORMAT": fmt,
62+
"use_cached_result": "false",
63+
}
64+
) as cnx:
65+
# test endianess
66+
ret = cnx.cursor().execute("SELECT 555::decfloat").fetchone()
67+
assert isinstance(ret[0], Decimal)
68+
assert ret[0] == Decimal("555")
69+
# test with decimal separator
70+
ret = cnx.cursor().execute("SELECT 123456789.12345678::decfloat").fetchone()
71+
assert isinstance(ret[0], Decimal)
72+
assert ret[0] == Decimal("123456789.12345678")
73+
# test 38 digits
74+
ret = (
75+
cnx.cursor()
76+
.execute("SELECT '12345678901234567890123456789012345678'::decfloat")
77+
.fetchone()
78+
)
79+
assert isinstance(ret[0], Decimal)
80+
assert ret[0] == Decimal("12345678901234567890123456789012345678")
81+
# test numpy
82+
with conn_cnx(numpy=True) as cnx:
83+
ret = (
84+
cnx.cursor()
85+
.execute(
86+
"SELECT 1.234::decfloat",
87+
None,
88+
)
89+
.fetchone()
90+
)
91+
assert isinstance(ret[0], numpy.float64)
92+
assert ret[0] == numpy.float64("1.234")

0 commit comments

Comments
 (0)