|
5 | 5 | import os |
6 | 6 | import tempfile |
7 | 7 | from decimal import Decimal |
| 8 | +from unittest import mock |
8 | 9 |
|
9 | 10 | import pytest |
10 | 11 |
|
11 | | -from snowflake.snowpark import DataFrame, Row |
| 12 | +from snowflake.snowpark import DataFrame, Row, context |
12 | 13 | from snowflake.snowpark.functions import lit |
13 | 14 | from snowflake.snowpark.types import ( |
14 | 15 | BooleanType, |
@@ -437,7 +438,6 @@ def test_numeric_type_store_precision_and_scale(session, massive_number, precisi |
437 | 438 | # does not have precision information, thus set to default 38. |
438 | 439 | df.write.save_as_table(table_name, mode="overwrite", table_type="temp") |
439 | 440 | result = session.sql(f"select * from {table_name}") |
440 | | - session.sql(f"describe table {table_name}").show() |
441 | 441 | datatype = result.schema.fields[0].datatype |
442 | 442 | assert isinstance(datatype, LongType) |
443 | 443 | assert datatype._precision == 38 |
@@ -502,3 +502,25 @@ def write_csv(data): |
502 | 502 | def test_illegal_argument_intergraltype(): |
503 | 503 | with pytest.raises(TypeError, match="takes 0 argument but 1 were given"): |
504 | 504 | LongType(b=10) |
| 505 | + |
| 506 | + |
| 507 | +@pytest.mark.skipif( |
| 508 | + "config.getoption('local_testing_mode', default=False)", |
| 509 | + reason="session.sql not supported by local testing mode", |
| 510 | +) |
| 511 | +@pytest.mark.parametrize("precision", [38, 19, 5, 3]) |
| 512 | +def test_write_to_sf_with_correct_precision(session, precision): |
| 513 | + table_name = Utils.random_table_name() |
| 514 | + |
| 515 | + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): |
| 516 | + df = session.create_dataframe( |
| 517 | + [], |
| 518 | + StructType([StructField("large_value", DecimalType(precision, 0), True)]), |
| 519 | + ) |
| 520 | + datatype = df.schema.fields[0].datatype |
| 521 | + assert datatype._precision == precision |
| 522 | + |
| 523 | + df.write.save_as_table(table_name, mode="overwrite", table_type="temp") |
| 524 | + result = session.sql(f"select * from {table_name}") |
| 525 | + datatype = result.schema.fields[0].datatype |
| 526 | + assert datatype._precision == precision |
0 commit comments