Commit 183f27c (parent: b254146)

SNOW-2887937: fix DataFrameWriter losing decimal precision info when writing back to a Snowflake table (#4024)

4 files changed: +182, -3 lines
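For context, here is a minimal sketch of the round trip this commit fixes, distilled from the new tests below. It assumes an existing Snowpark `session`; the table name is a placeholder, and `_is_snowpark_connect_compatible_mode` is an internal flag, patched here the same way the tests patch it.

```python
from unittest import mock

from snowflake.snowpark import context
from snowflake.snowpark.types import DecimalType, StructField, StructType

# `session` is an existing snowflake.snowpark.Session.
with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True):
    schema = StructType([StructField("amount", DecimalType(19, 0), True)])
    df = session.create_dataframe([], schema)

    df.write.save_as_table("precision_demo", mode="overwrite", table_type="temp")

    # Before this fix, the generated DDL dropped the precision metadata and the
    # column round-tripped as NUMBER(38, 0); with it, NUMBER(19, 0) survives.
    round_tripped = session.table("precision_demo").schema.fields[0].datatype
    print(round_tripped._precision)  # 19
```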

src/snowflake/snowpark/_internal/type_utils.py

Lines changed: 3 additions & 0 deletions
@@ -316,6 +316,9 @@ def convert_sf_to_sp_type(
 
 
 def convert_sp_to_sf_type(datatype: DataType, nullable_override=None) -> str:
+    if context._is_snowpark_connect_compatible_mode:
+        if isinstance(datatype, _IntegralType) and datatype._precision is not None:
+            return f"NUMBER({datatype._precision}, 0)"
     if isinstance(datatype, DecimalType):
        return f"NUMBER({datatype.precision}, {datatype.scale})"
     if isinstance(datatype, IntegerType):
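A quick sketch of what the new branch emits. `convert_sp_to_sf_type` is internal, so this is illustrative only, and the flag is set directly here where the tests patch it instead.

```python
from snowflake.snowpark import context
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
from snowflake.snowpark.types import LongType

context._is_snowpark_connect_compatible_mode = True  # internal flag

# An integral type carrying precision metadata now keeps it in the DDL type...
print(convert_sp_to_sf_type(LongType(_precision=19)))  # NUMBER(19, 0)

# ...while one without metadata falls through to the pre-existing branches.
print(convert_sp_to_sf_type(LongType()))  # plain integral DDL type, as before
```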

src/snowflake/snowpark/context.py

Lines changed: 5 additions & 0 deletions
@@ -43,6 +43,11 @@
 # Global flag for fix 2360274. When enabled schema queries will use NULL as a place holder for any values inside structured objects
 _enable_fix_2360274 = False
 
+# Internal-only dictionary storing the default precision of integral types; if a type
+# does not appear in the dictionary, its default precision is None.
+# Example: _integral_type_default_precision = {IntegerType: 9} makes IntegerType's default _precision 9.
+_integral_type_default_precision = {}
+
 
 def configure_development_features(
     *,
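The dictionary's contract, as the new tests exercise it: a type listed in the map picks up that default precision at construction time, and anything absent stays `None`. A minimal sketch, patching the internal map the way the tests do:

```python
from unittest import mock

from snowflake.snowpark import context
from snowflake.snowpark.types import IntegerType, LongType, ShortType

with mock.patch.object(
    context, "_integral_type_default_precision", {IntegerType: 10, LongType: 19}
):
    assert IntegerType()._precision == 10  # seeded default
    assert LongType()._precision == 19     # seeded default
    assert ShortType()._precision is None  # not in the map, so no default
```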

src/snowflake/snowpark/types.py

Lines changed: 3 additions & 1 deletion
@@ -371,7 +371,9 @@ def _fill_ast(self, ast: proto.DataType) -> None:
 # Numeric types
 class _IntegralType(_NumericType):
     def __init__(self, **kwargs) -> None:
-        self._precision = kwargs.pop("_precision", None)
+        self._precision = kwargs.pop(
+            "_precision", context._integral_type_default_precision.get(type(self), None)
+        )
 
         if kwargs != {}:
             raise TypeError(
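Note that an explicit `_precision` keyword still wins over the context default, because `kwargs.pop` only falls back to the dictionary lookup when the key is absent. A small sketch:

```python
from unittest import mock

from snowflake.snowpark import context
from snowflake.snowpark.types import LongType

with mock.patch.object(context, "_integral_type_default_precision", {LongType: 19}):
    assert LongType()._precision == 19             # falls back to the context default
    assert LongType(_precision=5)._precision == 5  # explicit kwarg takes precedence
```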

tests/integ/test_datatypes.py

Lines changed: 171 additions & 2 deletions
@@ -5,10 +5,11 @@
 import os
 import tempfile
 from decimal import Decimal
+from unittest import mock
 
 import pytest
 
-from snowflake.snowpark import DataFrame, Row
+from snowflake.snowpark import DataFrame, Row, context
 from snowflake.snowpark.functions import lit
 from snowflake.snowpark.types import (
     BooleanType,
@@ -19,6 +20,8 @@
     StringType,
     StructField,
     StructType,
+    IntegerType,
+    ShortType,
 )
 from tests.utils import Utils
 
@@ -437,7 +440,6 @@ def test_numeric_type_store_precision_and_scale(session, massive_number, precisi
     # does not have precision information, thus set to default 38.
     df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
     result = session.sql(f"select * from {table_name}")
-    session.sql(f"describe table {table_name}").show()
     datatype = result.schema.fields[0].datatype
     assert isinstance(datatype, LongType)
     assert datatype._precision == 38
@@ -502,3 +504,170 @@ def write_csv(data):
 def test_illegal_argument_intergraltype():
     with pytest.raises(TypeError, match="takes 0 argument but 1 were given"):
         LongType(b=10)
+
+
+@pytest.mark.skipif(
+    "config.getoption('local_testing_mode', default=False)",
+    reason="session.sql not supported by local testing mode",
+)
+@pytest.mark.parametrize("precision", [38, 19, 5, 3])
+def test_write_to_sf_with_correct_precision(session, precision):
+    table_name = Utils.random_table_name()
+
+    with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True):
+        df = session.create_dataframe(
+            [],
+            StructType([StructField("large_value", DecimalType(precision, 0), True)]),
+        )
+        datatype = df.schema.fields[0].datatype
+        assert datatype._precision == precision
+
+        df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
+        result = session.sql(f"select * from {table_name}")
+        datatype = result.schema.fields[0].datatype
+        assert datatype._precision == precision
+
+
+@pytest.mark.parametrize(
+    "mock_default_precision",
+    [
+        {IntegerType: 5, LongType: 4},
+        {LongType: 19, IntegerType: 10},
+    ],
+)
+def test_integral_type_default_precision(mock_default_precision):
+    with mock.patch(
+        "snowflake.snowpark.context._integral_type_default_precision",
+        mock_default_precision,
+    ):
+        integer_type = IntegerType()
+        assert integer_type._precision == mock_default_precision[IntegerType]
+
+        long_type = LongType()
+        assert long_type._precision == mock_default_precision[LongType]
+
+        short_type = ShortType()
+        assert short_type._precision is None
+
+
+@pytest.mark.skipif(
+    "config.getoption('local_testing_mode', default=False)",
+    reason="session.sql not supported by local testing mode",
+)
+@pytest.mark.parametrize(
+    "mock_default_precision",
+    [
+        {IntegerType: 5, LongType: 4},
+        {LongType: 19, IntegerType: 10},
+    ],
+)
+def test_end_to_end_default_precision(session, mock_default_precision):
+    table_name = Utils.random_table_name()
+
+    with mock.patch.object(
+        context, "_is_snowpark_connect_compatible_mode", True
+    ), mock.patch.object(
+        context, "_integral_type_default_precision", mock_default_precision
+    ):
+
+        schema = StructType(
+            [
+                StructField("D38", DecimalType(38, 0), True),
+                StructField("D19", DecimalType(19, 0), True),
+                StructField("D5", DecimalType(5, 0), True),
+                StructField("D3", DecimalType(3, 0), True),
+                StructField("integer_value", IntegerType(), True),
+                StructField("long_value", LongType(), True),
+            ]
+        )
+
+        df = session.create_dataframe(
+            [],
+            schema,
+        )
+        assert df.schema.fields[0].datatype._precision == 38
+        assert df.schema.fields[1].datatype._precision == 19
+        assert df.schema.fields[2].datatype._precision == 5
+        assert df.schema.fields[3].datatype._precision == 3
+        assert (
+            df.schema.fields[4].datatype._precision
+            == mock_default_precision[IntegerType]
+        )
+        assert (
+            df.schema.fields[5].datatype._precision == mock_default_precision[LongType]
+        )
+
+        df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
+        result = session.sql(f"select * from {table_name}")
+        assert result.schema.fields[0].datatype._precision == 38
+        assert result.schema.fields[1].datatype._precision == 19
+        assert result.schema.fields[2].datatype._precision == 5
+        assert result.schema.fields[3].datatype._precision == 3
+        assert (
+            result.schema.fields[4].datatype._precision
+            == mock_default_precision[IntegerType]
+        )
+        assert (
+            result.schema.fields[5].datatype._precision
+            == mock_default_precision[LongType]
+        )
+
+
+@pytest.mark.skipif(
+    "config.getoption('local_testing_mode', default=False)",
+    reason="relaxed_types not supported by local testing mode",
+)
+@pytest.mark.parametrize("massive_number", ["9" * 38, "5" * 19, "7" * 5])
+def test_default_precision_read_file(session, massive_number):
+    mock_default_precision = {LongType: 19, IntegerType: 10}
+    with mock.patch.object(
+        context, "_is_snowpark_connect_compatible_mode", True
+    ), mock.patch.object(
+        context, "_integral_type_default_precision", mock_default_precision
+    ):
+        stage_name = Utils.random_stage_name()
+        header = ("BIG_NUM",)
+        test_data = [(massive_number,)]
+
+        def write_csv(data):
+            with tempfile.NamedTemporaryFile(
+                mode="w+",
+                delete=False,
+                suffix=".csv",
+                newline="",
+            ) as file:
+                writer = csv.writer(file)
+                writer.writerow(header)
+                for row in data:
+                    writer.writerow(row)
+                return file.name
+
+        file_path = write_csv(test_data)
+
+        try:
+            Utils.create_stage(session, stage_name, is_temporary=True)
+            result = session.file.put(
+                file_path, f"@{stage_name}", auto_compress=False, overwrite=True
+            )
+
+            # Infer schema from only the short file
+            constrained_reader = session.read.options(
+                {
+                    "INFER_SCHEMA": True,
+                    "INFER_SCHEMA_OPTIONS": {"FILES": [result[0].target]},
+                    "PARSE_HEADER": True,
+                    # Only load the short file
+                    "PATTERN": f".*{result[0].target}",
+                }
+            )
+
+            # df1 uses constrained types
+            df1 = constrained_reader.csv(f"@{stage_name}/")
+            datatype = df1.schema.fields[0].datatype
+            assert isinstance(datatype, LongType)
+            assert datatype._precision == len(massive_number)
+
+        finally:
+            Utils.drop_stage(session, stage_name)
+            if os.path.exists(file_path):
+                os.remove(file_path)
