Skip to content

Commit 0afe5cf

Browse files
committed
add default precision
1 parent e7fe97c commit 0afe5cf

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

src/snowflake/snowpark/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@
4343
# Global flag for fix 2360274. When enabled schema queries will use NULL as a place holder for any values inside structured objects
4444
_enable_fix_2360274 = False
4545

46+
# internal only dictionary store the default precision of integral types, if the type does not appear in the
47+
# dictionary, the default precision is None.
48+
# example: _integral_type_default_precision = {"IntegerType": 9}, IntegerType default _precision is 9 now
49+
_integral_type_default_precision = {}
50+
4651

4752
def configure_development_features(
4853
*,

src/snowflake/snowpark/types.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,10 @@ def _fill_ast(self, ast: proto.DataType) -> None:
371371
# Numeric types
372372
class _IntegralType(_NumericType):
373373
def __init__(self, **kwargs) -> None:
374-
self._precision = kwargs.pop("_precision", None)
374+
class_name = type(self).__name__
375+
self._precision = kwargs.pop(
376+
"_precision", context._integral_type_default_precision.get(class_name, None)
377+
)
375378

376379
if kwargs != {}:
377380
raise TypeError(

tests/integ/test_datatypes.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
StringType,
2121
StructField,
2222
StructType,
23+
IntegerType,
24+
ShortType,
2325
)
2426
from tests.utils import Utils
2527

@@ -524,3 +526,22 @@ def test_write_to_sf_with_correct_precision(session, precision):
524526
result = session.sql(f"select * from {table_name}")
525527
datatype = result.schema.fields[0].datatype
526528
assert datatype._precision == precision
529+
530+
531+
@pytest.mark.parametrize(
532+
"mock_dict",
533+
[
534+
{"IntegerType": 5, "LongType": 4},
535+
{"LongType": 19, "IntegerType": 10},
536+
],
537+
)
538+
def test_integral_type_default_precision(mock_dict):
539+
with mock.patch.object(context, "_integral_type_default_precision", mock_dict):
540+
integer_type = IntegerType()
541+
assert integer_type._precision == mock_dict["IntegerType"]
542+
543+
long_type = LongType()
544+
assert long_type._precision == mock_dict["LongType"]
545+
546+
short_type = ShortType()
547+
assert short_type._precision is None

0 commit comments

Comments
 (0)