Skip to content

Commit e7fe97c

Browse files
committed
store as decimal with precision in compatiable mode
1 parent af1eef4 commit e7fe97c

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

src/snowflake/snowpark/_internal/type_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,9 @@ def convert_sf_to_sp_type(
316316

317317

318318
def convert_sp_to_sf_type(datatype: DataType, nullable_override=None) -> str:
319+
if context._is_snowpark_connect_compatible_mode:
320+
if isinstance(datatype, _IntegralType) and datatype._precision is not None:
321+
return f"NUMBER({datatype._precision}, 0)"
319322
if isinstance(datatype, DecimalType):
320323
return f"NUMBER({datatype.precision}, {datatype.scale})"
321324
if isinstance(datatype, IntegerType):

tests/integ/test_datatypes.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import os
66
import tempfile
77
from decimal import Decimal
8+
from unittest import mock
89

910
import pytest
1011

11-
from snowflake.snowpark import DataFrame, Row
12+
from snowflake.snowpark import DataFrame, Row, context
1213
from snowflake.snowpark.functions import lit
1314
from snowflake.snowpark.types import (
1415
BooleanType,
@@ -437,7 +438,6 @@ def test_numeric_type_store_precision_and_scale(session, massive_number, precisi
437438
# does not have precision information, thus set to default 38.
438439
df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
439440
result = session.sql(f"select * from {table_name}")
440-
session.sql(f"describe table {table_name}").show()
441441
datatype = result.schema.fields[0].datatype
442442
assert isinstance(datatype, LongType)
443443
assert datatype._precision == 38
@@ -502,3 +502,25 @@ def write_csv(data):
502502
def test_illegal_argument_intergraltype():
503503
with pytest.raises(TypeError, match="takes 0 argument but 1 were given"):
504504
LongType(b=10)
505+
506+
507+
@pytest.mark.skipif(
508+
"config.getoption('local_testing_mode', default=False)",
509+
reason="session.sql not supported by local testing mode",
510+
)
511+
@pytest.mark.parametrize("precision", [38, 19, 5, 3])
512+
def test_write_to_sf_with_correct_precision(session, precision):
513+
table_name = Utils.random_table_name()
514+
515+
with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True):
516+
df = session.create_dataframe(
517+
[],
518+
StructType([StructField("large_value", DecimalType(precision, 0), True)]),
519+
)
520+
datatype = df.schema.fields[0].datatype
521+
assert datatype._precision == precision
522+
523+
df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
524+
result = session.sql(f"select * from {table_name}")
525+
datatype = result.schema.fields[0].datatype
526+
assert datatype._precision == precision

0 commit comments

Comments
 (0)