|
5 | 5 | import os |
6 | 6 | import tempfile |
7 | 7 | from decimal import Decimal |
| 8 | +from unittest import mock |
8 | 9 |
|
9 | 10 | import pytest |
10 | 11 |
|
11 | | -from snowflake.snowpark import DataFrame, Row |
| 12 | +from snowflake.snowpark import DataFrame, Row, context |
12 | 13 | from snowflake.snowpark.functions import lit |
13 | 14 | from snowflake.snowpark.types import ( |
14 | 15 | BooleanType, |
|
19 | 20 | StringType, |
20 | 21 | StructField, |
21 | 22 | StructType, |
| 23 | + IntegerType, |
| 24 | + ShortType, |
22 | 25 | ) |
23 | 26 | from tests.utils import Utils |
24 | 27 |
|
@@ -437,7 +440,6 @@ def test_numeric_type_store_precision_and_scale(session, massive_number, precisi |
437 | 440 | # does not have precision information, thus set to default 38. |
438 | 441 | df.write.save_as_table(table_name, mode="overwrite", table_type="temp") |
439 | 442 | result = session.sql(f"select * from {table_name}") |
440 | | - session.sql(f"describe table {table_name}").show() |
441 | 443 | datatype = result.schema.fields[0].datatype |
442 | 444 | assert isinstance(datatype, LongType) |
443 | 445 | assert datatype._precision == 38 |
@@ -502,3 +504,170 @@ def write_csv(data): |
def test_illegal_argument_intergraltype():
    """Constructing LongType with an unexpected keyword must raise TypeError."""
    expected_message = "takes 0 argument but 1 were given"
    with pytest.raises(TypeError, match=expected_message):
        LongType(b=10)
| 507 | + |
| 508 | + |
@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="session.sql not supported by local testing mode",
)
@pytest.mark.parametrize("precision", [38, 19, 5, 3])
def test_write_to_sf_with_correct_precision(session, precision):
    """Check a DecimalType(precision, 0) column keeps its precision on write.

    With ``context._is_snowpark_connect_compatible_mode`` patched to True, the
    precision is asserted on the freshly created (empty) DataFrame and again on
    the schema read back from the temp table it was saved to.
    """
    table_name = Utils.random_table_name()

    with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True):
        decimal_schema = StructType(
            [StructField("large_value", DecimalType(precision, 0), True)]
        )
        empty_df = session.create_dataframe([], decimal_schema)
        assert empty_df.schema.fields[0].datatype._precision == precision

        # Round trip through a temp table and inspect the schema that comes back.
        empty_df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
        read_back = session.sql(f"select * from {table_name}")
        assert read_back.schema.fields[0].datatype._precision == precision
| 529 | + |
| 530 | + |
@pytest.mark.parametrize(
    "mock_default_precision",
    [
        {IntegerType: 5, LongType: 4},
        {LongType: 19, IntegerType: 10},
    ],
)
def test_integral_type_default_precision(mock_default_precision):
    """Check integral types take their precision from the patched mapping.

    ``snowflake.snowpark.context._integral_type_default_precision`` is replaced
    by ``mock_default_precision``; types present in the mapping must report the
    mapped precision, and ShortType (absent from it) must report ``None``.
    """
    with mock.patch(
        "snowflake.snowpark.context._integral_type_default_precision",
        mock_default_precision,
    ):
        # Types that have an entry in the patched mapping.
        for integral_cls in (IntegerType, LongType):
            assert integral_cls()._precision == mock_default_precision[integral_cls]

        # ShortType has no entry in either parametrized mapping.
        assert ShortType()._precision is None
| 551 | + |
| 552 | + |
@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="session.sql not supported by local testing mode",
)
@pytest.mark.parametrize(
    "mock_default_precision",
    [
        {IntegerType: 5, LongType: 4},
        {LongType: 19, IntegerType: 10},
    ],
)
def test_end_to_end_default_precision(session, mock_default_precision):
    """Check explicit and default precisions before and after a table write.

    Both ``context._is_snowpark_connect_compatible_mode`` and
    ``context._integral_type_default_precision`` are patched. The same six
    per-column precisions are asserted on the created DataFrame and on the
    schema read back from the temp table.
    """
    table_name = Utils.random_table_name()

    compat_mode_patch = mock.patch.object(
        context, "_is_snowpark_connect_compatible_mode", True
    )
    default_precision_patch = mock.patch.object(
        context, "_integral_type_default_precision", mock_default_precision
    )

    with compat_mode_patch, default_precision_patch:
        schema = StructType(
            [
                StructField("D38", DecimalType(38, 0), True),
                StructField("D19", DecimalType(19, 0), True),
                StructField("D5", DecimalType(5, 0), True),
                StructField("D3", DecimalType(3, 0), True),
                StructField("integer_value", IntegerType(), True),
                StructField("long_value", LongType(), True),
            ]
        )

        # Expected precision per column, in schema order.
        expected_precisions = [
            38,
            19,
            5,
            3,
            mock_default_precision[IntegerType],
            mock_default_precision[LongType],
        ]

        def assert_precisions(fields):
            # Index explicitly so a missing column fails instead of being skipped.
            for position, expected in enumerate(expected_precisions):
                assert fields[position].datatype._precision == expected

        df = session.create_dataframe([], schema)
        assert_precisions(df.schema.fields)

        df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
        result = session.sql(f"select * from {table_name}")
        assert_precisions(result.schema.fields)
| 614 | + |
| 615 | + |
@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="relaxed_types not supported by local testing mode",
)
@pytest.mark.parametrize("massive_number", ["9" * 38, "5" * 19, "7" * 5])
def test_default_precision_read_file(session, massive_number):
    """Check schema inference on a staged CSV yields a digit-count precision.

    A one-column CSV holding ``massive_number`` is uploaded to a temporary
    stage and read back with INFER_SCHEMA. The inferred column must be a
    LongType whose precision equals the number of digits in the value.
    """
    mock_default_precision = {LongType: 19, IntegerType: 10}
    with mock.patch.object(
        context, "_is_snowpark_connect_compatible_mode", True
    ), mock.patch.object(
        context, "_integral_type_default_precision", mock_default_precision
    ):
        stage_name = Utils.random_stage_name()
        header = ("BIG_NUM",)
        test_data = [(massive_number,)]

        def write_csv(data):
            # delete=False keeps the file on disk for the upload below; it is
            # removed in the ``finally`` clause.
            with tempfile.NamedTemporaryFile(
                mode="w+",
                delete=False,
                suffix=".csv",
                newline="",
            ) as file:
                writer = csv.writer(file)
                writer.writerow(header)
                writer.writerows(data)
                return file.name

        file_path = write_csv(test_data)

        try:
            Utils.create_stage(session, stage_name, is_temporary=True)
            result = session.file.put(
                file_path, f"@{stage_name}", auto_compress=False, overwrite=True
            )

            # Restrict both schema inference and loading to the uploaded file.
            reader_options = {
                "INFER_SCHEMA": True,
                "INFER_SCHEMA_OPTIONS": {"FILES": [result[0].target]},
                "PARSE_HEADER": True,
                "PATTERN": f".*{result[0].target}",
            }
            inferred_df = session.read.options(reader_options).csv(f"@{stage_name}/")

            inferred_type = inferred_df.schema.fields[0].datatype
            assert isinstance(inferred_type, LongType)
            assert inferred_type._precision == len(massive_number)

        finally:
            Utils.drop_stage(session, stage_name)
            if os.path.exists(file_path):
                os.remove(file_path)
0 commit comments