Skip to content

Commit 187eeea

Browse files
committed
add decfloat type
1 parent e0d2c65 commit 187eeea

File tree

11 files changed

+166
-5
lines changed

11 files changed

+166
-5
lines changed

src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
DataType,
2626
DateType,
2727
DayTimeIntervalType,
28+
DecFloatType,
2829
DecimalType,
2930
DoubleType,
3031
FileType,
@@ -362,6 +363,9 @@ def to_sql(
362363
if isinstance(datatype, (FloatType, DoubleType)):
363364
if value is None:
364365
return "NULL :: FLOAT"
366+
if isinstance(datatype, DecFloatType):
367+
if value is None:
368+
return "NULL :: DECFLOAT"
365369
if isinstance(datatype, StringType):
366370
if value is None:
367371
return f"NULL :: {analyzer_utils.string(datatype.length)}"
@@ -403,6 +407,13 @@ def to_sql(
403407
if isinstance(datatype, BooleanType):
404408
return f"{value} :: BOOLEAN"
405409

410+
# DecFloatType must use DECFLOAT 'value' syntax (not value::DECFLOAT) to preserve precision
411+
if isinstance(datatype, DecFloatType):
412+
if isinstance(value, str):
413+
return f"DECFLOAT {str_to_sql(value)}"
414+
else:
415+
return f"DECFLOAT '{value}'"
416+
406417
if isinstance(value, float) and isinstance(datatype, _FractionalType):
407418
if math.isnan(value) or math.isinf(value):
408419
return float_nan_inf_to_sql(value)
@@ -495,6 +506,8 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
495506
return "PARSE_JSON('NULL') :: VARIANT"
496507
return "NULL :: " + convert_sp_to_sf_type(data_type)
497508

509+
if isinstance(data_type, DecFloatType):
510+
return "DECFLOAT '0'"
498511
if isinstance(data_type, _NumericType):
499512
return "0 :: " + convert_sp_to_sf_type(data_type)
500513
if isinstance(data_type, StringType):
@@ -596,6 +609,13 @@ def numeric_to_sql_without_cast(value: Any, datatype: DataType) -> str:
596609
# regular to_sql generation
597610
return to_sql(value, datatype)
598611

612+
# DecFloatType must always use DECFLOAT 'value' syntax to preserve precision
613+
if isinstance(datatype, DecFloatType):
614+
if isinstance(value, str):
615+
return f"DECFLOAT {str_to_sql(value)}"
616+
else:
617+
return f"DECFLOAT '{value}'"
618+
599619
if isinstance(value, float) and isinstance(datatype, _FractionalType):
600620
# when the float value is NAN or INF, a cast is still required
601621
if math.isnan(value) or math.isinf(value):

src/snowflake/snowpark/_internal/type_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
DateType,
5252
DayTimeInterval,
5353
DayTimeIntervalType,
54+
DecFloatType,
5455
DecimalType,
5556
DoubleType,
5657
FloatType,
@@ -188,7 +189,7 @@ def convert_metadata_to_sp_type(
188189
return convert_sf_to_sp_type(
189190
column_type_name,
190191
metadata.precision or 0,
191-
metadata.scale or 0,
192+
metadata.scale,
192193
metadata.internal_size or 0,
193194
max_string_size,
194195
)
@@ -197,7 +198,7 @@ def convert_metadata_to_sp_type(
197198
def convert_sf_to_sp_type(
198199
column_type_name: str,
199200
precision: int,
200-
scale: int,
201+
scale: int | None,
201202
internal_size: int,
202203
max_string_size: int,
203204
) -> DataType:
@@ -291,6 +292,8 @@ def convert_sf_to_sp_type(
291292
return TimestampType(timezone=TimestampTimeZone.TZ)
292293
if column_type_name == "DATE":
293294
return DateType()
295+
if column_type_name == "FIXED" and scale is None:
296+
return DecFloatType()
294297
if column_type_name == "DECIMAL" or (
295298
(column_type_name == "FIXED" or column_type_name == "NUMBER") and scale != 0
296299
):
@@ -333,6 +336,8 @@ def convert_sp_to_sf_type(datatype: DataType, nullable_override=None) -> str:
333336
return "FLOAT"
334337
if isinstance(datatype, DoubleType):
335338
return "DOUBLE"
339+
if isinstance(datatype, DecFloatType):
340+
return "DECFLOAT"
336341
# We regard NullType as String, which is required when creating
337342
# a dataframe from local data with all None values
338343
if isinstance(datatype, StringType):
@@ -845,6 +850,7 @@ def snow_type_to_dtype_str(snow_type: DataType) -> str:
845850
BooleanType,
846851
FloatType,
847852
DoubleType,
853+
DecFloatType,
848854
DateType,
849855
TimestampType,
850856
TimeType,

src/snowflake/snowpark/functions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2118,6 +2118,31 @@ def to_double(
21182118
)
21192119

21202120

2121+
@publicapi
2122+
def to_decfloat(
2123+
e: ColumnOrName, fmt: Optional[ColumnOrLiteralStr] = None, _emit_ast: bool = True
2124+
) -> Column:
2125+
"""Converts an input expression to a decimal floating-point number.
2126+
2127+
Example::
2128+
>>> df = session.create_dataframe(['12', '11.3', '-90.12345'], schema=['a'])
2129+
>>> df.select(to_decfloat(col('a')).as_('ans')).collect()
2130+
[Row(ANS=Decimal('12.0')), Row(ANS=Decimal('11.3')), Row(ANS=Decimal('-90.12345'))]
2131+
"""
2132+
ast = (
2133+
build_function_expr("to_decfloat", [e] if fmt is None else [e, fmt])
2134+
if _emit_ast
2135+
else None
2136+
)
2137+
c = _to_col_if_str(e, "to_decfloat")
2138+
fmt_col = _to_col_if_lit(fmt, "to_decfloat") if fmt is not None else None
2139+
return (
2140+
_call_function("to_decfloat", c, _ast=ast, _emit_ast=_emit_ast)
2141+
if fmt_col is None
2142+
else _call_function("to_decfloat", c, fmt_col, _ast=ast, _emit_ast=_emit_ast)
2143+
)
2144+
2145+
21212146
@publicapi
21222147
def div0(
21232148
dividend: Union[ColumnOrName, int, float],

src/snowflake/snowpark/mock/_snowflake_data_type.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
BooleanType,
1212
DataType,
1313
DateType,
14+
DecFloatType,
1415
DecimalType,
1516
DoubleType,
1617
FloatType,
@@ -114,20 +115,27 @@ class SnowDataTypeConversion:
114115
SnowDataTypeConversion(BooleanType, DecimalType, True, False),
115116
SnowDataTypeConversion(BooleanType, StringType, True, True),
116117
SnowDataTypeConversion(BooleanType, VariantType, True, True),
118+
SnowDataTypeConversion(BooleanType, DecFloatType, True, True),
117119
SnowDataTypeConversion(DateType, TimestampType, True, False),
118120
SnowDataTypeConversion(DateType, StringType, True, True),
119121
SnowDataTypeConversion(DateType, VariantType, True, False),
120122
SnowDataTypeConversion(FloatType, BooleanType, True, True),
121123
SnowDataTypeConversion(FloatType, DecimalType, True, True),
122124
SnowDataTypeConversion(FloatType, StringType, True, True),
123125
SnowDataTypeConversion(FloatType, VariantType, True, True),
126+
SnowDataTypeConversion(FloatType, DecFloatType, True, True),
127+
SnowDataTypeConversion(DecFloatType, BooleanType, True, True),
128+
SnowDataTypeConversion(DecFloatType, DecimalType, True, True),
129+
SnowDataTypeConversion(DecFloatType, FloatType, True, True),
130+
SnowDataTypeConversion(DecFloatType, StringType, True, True),
124131
SnowDataTypeConversion(GeographyType, VariantType, True, False),
125132
# SnowDataTypeConversion(GeometryType, VariantType, True, False), # GeometryType isn't available yet.
126133
SnowDataTypeConversion(DecimalType, BooleanType, True, True),
127134
SnowDataTypeConversion(DecimalType, FloatType, True, True),
128135
SnowDataTypeConversion(DecimalType, TimestampType, True, True),
129136
SnowDataTypeConversion(DecimalType, StringType, True, True),
130137
SnowDataTypeConversion(DecimalType, VariantType, True, True),
138+
SnowDataTypeConversion(DecimalType, DecFloatType, True, True),
131139
SnowDataTypeConversion(MapType, ArrayType, True, False),
132140
SnowDataTypeConversion(MapType, StringType, True, False),
133141
SnowDataTypeConversion(MapType, VariantType, True, True),
@@ -140,6 +148,7 @@ class SnowDataTypeConversion:
140148
SnowDataTypeConversion(StringType, BooleanType, True, True),
141149
SnowDataTypeConversion(StringType, DateType, True, True),
142150
SnowDataTypeConversion(StringType, FloatType, True, True),
151+
SnowDataTypeConversion(StringType, DecFloatType, True, True),
143152
SnowDataTypeConversion(StringType, DecimalType, True, True),
144153
SnowDataTypeConversion(StringType, TimeType, True, True),
145154
SnowDataTypeConversion(StringType, TimestampType, True, True),
@@ -271,6 +280,9 @@ def calculate_type(c1: ColumnType, c2: ColumnType, op: Union[str]):
271280
parameters_info={"op": op},
272281
raise_error=NotImplementedError,
273282
)
283+
# DECFLOAT has highest numeric priority
284+
elif isinstance(t1, DecFloatType) or isinstance(t2, DecFloatType):
285+
return ColumnType(DecFloatType(), nullable)
274286
elif isinstance(t1, (FloatType, DoubleType)) or isinstance(
275287
t2, (FloatType, DoubleType)
276288
):

src/snowflake/snowpark/mock/_snowflake_to_pandas_converter.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ByteType,
2222
DataType,
2323
DateType,
24+
DecFloatType,
2425
DecimalType,
2526
DoubleType,
2627
FloatType,
@@ -97,6 +98,19 @@ def _decimal_converter(
9798
)
9899

99100

101+
def _decfloat_converter(
102+
value: str, datatype: DataType, field_optionally_enclosed_by=None, null_if=None
103+
):
104+
if value is None or value == "" or (null_if is not None and value in null_if):
105+
return None
106+
try:
107+
return Decimal(value)
108+
except Exception as exc:
109+
SnowparkLocalTestingException.raise_from_error(
110+
exc, error_message=f"Numeric value '{value}' is not recognized."
111+
)
112+
113+
100114
def _bool_converter(
101115
value: str,
102116
datatype: DataType,
@@ -189,6 +203,7 @@ def _time_converter(
189203
ShortType: _integer_converter,
190204
DoubleType: _fraction_converter,
191205
FloatType: _fraction_converter,
206+
DecFloatType: _decfloat_converter,
192207
DecimalType: _decimal_converter,
193208
BooleanType: _bool_converter,
194209
DateType: _date_converter,

src/snowflake/snowpark/mock/_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
BooleanType,
2020
ByteType,
2121
DateType,
22+
DecFloatType,
2223
DecimalType,
2324
DoubleType,
2425
FloatType,
@@ -244,6 +245,7 @@ def fix_drift_between_column_sf_type_and_dtype(col: ColumnEmulator):
244245
BooleanType: bool,
245246
ByteType: numpy.int8 if not col.sf_type.nullable else "Int8",
246247
DateType: object,
248+
DecFloatType: numpy.float64,
247249
DecimalType: numpy.float64,
248250
DoubleType: numpy.float64,
249251
FloatType: numpy.float64,

src/snowflake/snowpark/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,13 @@ def _fill_ast(self, ast: proto.DataType) -> None:
461461
ast.double_type = True
462462

463463

464+
class DecFloatType(_FractionalType):
465+
"""DecFloat data type. This maps to the DECFLOAT data type in Snowflake."""
466+
467+
def _fill_ast(self, ast: proto.DataType) -> None:
468+
ast.decfloat_type = True
469+
470+
464471
class DecimalType(_FractionalType):
465472
"""Decimal data type. This maps to the NUMBER data type in Snowflake."""
466473

@@ -1108,6 +1115,7 @@ def _fill_ast(self, ast: proto.DataType) -> None:
11081115
BinaryType,
11091116
BooleanType,
11101117
DecimalType,
1118+
DecFloatType,
11111119
FloatType,
11121120
DoubleType,
11131121
ByteType,

tests/integ/test_datatypes.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from snowflake.snowpark.functions import lit
1414
from snowflake.snowpark.types import (
1515
BooleanType,
16+
DecFloatType,
1617
DecimalType,
1718
DoubleType,
1819
FloatType,
@@ -51,13 +52,14 @@ def test_basic_filter(session):
5152

5253
def test_plus_basic(session):
5354
df = session.create_dataframe(
54-
[[1, 1.1, 2.2, 3.3]],
55+
[[1, 1.1, 2.2, 3.3, 4.4]],
5556
schema=StructType(
5657
[
5758
StructField("a", LongType(), nullable=False),
5859
StructField("b", DecimalType(3, 1), nullable=False),
5960
StructField("c", DoubleType(), nullable=False),
6061
StructField("d", DecimalType(4, 2), nullable=False),
62+
StructField("e", DecFloatType(), nullable=False),
6163
]
6264
),
6365
)
@@ -66,27 +68,30 @@ def test_plus_basic(session):
6668
(df["a"] + 1).as_("new_a"),
6769
(df["b"] + df["d"]).as_("new_b"),
6870
(df["c"] + 3).as_("new_c"),
71+
(df["e"] + df["c"]).as_("new_e"),
6972
)
7073
assert repr(df.schema) == repr(
7174
StructType(
7275
[
7376
StructField("NEW_A", LongType(), nullable=False),
7477
StructField("NEW_B", DecimalType(5, 2), nullable=False),
7578
StructField("NEW_C", DoubleType(), nullable=False),
79+
StructField("NEW_E", DecFloatType(), nullable=False),
7680
]
7781
)
7882
)
7983

8084

8185
def test_minus_basic(session):
8286
df = session.create_dataframe(
83-
[[1, 1.1, 2.2, 3.3]],
87+
[[1, 1.1, 2.2, 3.3, 4.4]],
8488
schema=StructType(
8589
[
8690
StructField("a", LongType(), nullable=False),
8791
StructField("b", DecimalType(3, 1), nullable=False),
8892
StructField("c", DoubleType(), nullable=False),
8993
StructField("d", DecimalType(4, 2), nullable=False),
94+
StructField("e", DecFloatType(), nullable=False),
9095
]
9196
),
9297
)
@@ -95,27 +100,30 @@ def test_minus_basic(session):
95100
(df["a"] - 1).as_("new_a"),
96101
(df["b"] - df["d"]).as_("new_b"),
97102
(df["c"] - 3).as_("new_c"),
103+
(df["e"] - df["a"]).as_("new_e"),
98104
)
99105
assert repr(df.schema) == repr(
100106
StructType(
101107
[
102108
StructField("NEW_A", LongType(), nullable=False),
103109
StructField("NEW_B", DecimalType(5, 2), nullable=False),
104110
StructField("NEW_C", DoubleType(), nullable=False),
111+
StructField("NEW_E", DecFloatType(), nullable=False),
105112
]
106113
)
107114
)
108115

109116

110117
def test_multiple_basic(session):
111118
df = session.create_dataframe(
112-
[[1, 1.1, 2.2, 3.3]],
119+
[[1, 1.1, 2.2, 3.3, 4.4]],
113120
schema=StructType(
114121
[
115122
StructField("a", LongType(), nullable=False),
116123
StructField("b", DecimalType(3, 1), nullable=False),
117124
StructField("c", FloatType(), nullable=False),
118125
StructField("d", DecimalType(4, 2), nullable=False),
126+
StructField("e", DecFloatType(), nullable=False),
119127
]
120128
),
121129
)
@@ -124,13 +132,15 @@ def test_multiple_basic(session):
124132
(df["a"] * 1).as_("new_a"),
125133
(df["b"] * df["d"]).as_("new_b"),
126134
(df["c"] * 3).as_("new_c"),
135+
(df["e"] * df["b"]).as_("new_e"),
127136
)
128137
assert repr(df.schema) == repr(
129138
StructType(
130139
[
131140
StructField("NEW_A", LongType(), nullable=False),
132141
StructField("NEW_B", DecimalType(7, 3), nullable=False),
133142
StructField("NEW_C", DoubleType(), nullable=False),
143+
StructField("NEW_E", DecFloatType(), nullable=False),
134144
]
135145
)
136146
)

0 commit comments

Comments
 (0)