Skip to content

Commit 2247ae5

Browse files
authored
SNOW-1853347: Add mechanism to allow changing type strs when printing schema (#2819)
1 parent 888cec5 commit 2247ae5

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

src/snowflake/snowpark/dataframe.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5655,7 +5655,10 @@ def convert(col: ColumnOrName) -> Expression:
56555655
return exprs
56565656

56575657
def _format_schema(
5658-
self, level: Optional[int] = None, translate_columns: Optional[dict] = None
5658+
self,
5659+
level: Optional[int] = None,
5660+
translate_columns: Optional[dict] = None,
5661+
translate_types: Optional[dict] = None,
56595662
) -> str:
56605663
def _format_datatype(name, dtype, nullable=None, depth=0):
56615664
if level is not None and depth >= level:
@@ -5669,6 +5672,10 @@ def _format_datatype(name, dtype, nullable=None, depth=0):
56695672
extra_lines = []
56705673
type_str = dtype.__class__.__name__
56715674

5675+
translated = None
5676+
if translate_types:
5677+
translated = translate_types.get(type_str, type_str)
5678+
56725679
# Structured Type format their parameters on multiple lines.
56735680
if isinstance(dtype, ArrayType):
56745681
extra_lines = [
@@ -5695,7 +5702,7 @@ def _format_datatype(name, dtype, nullable=None, depth=0):
56955702

56965703
return "\n".join(
56975704
[
5698-
f"{prefix} |-- {name}: {type_str}{nullable_str}",
5705+
f"{prefix} |-- {name}: {translated or type_str}{nullable_str}",
56995706
]
57005707
+ [f"{line}" for line in extra_lines if line]
57015708
)

tests/integ/scala/test_datatype_suite.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,6 +1077,19 @@ def test_structured_type_print_schema(
10771077
== 'root\n |-- "map": MapType (nullable = True)'
10781078
)
10791079

1080+
# Check that column types can be translated
1081+
assert (
1082+
df._format_schema(
1083+
2,
1084+
translate_types={
1085+
"MapType": "dict",
1086+
"StringType": "str",
1087+
"ArrayType": "list",
1088+
},
1089+
)
1090+
== 'root\n |-- "MAP": dict (nullable = True)\n | |-- key: str\n | |-- value: list'
1091+
)
1092+
10801093

10811094
@pytest.mark.skipif(
10821095
"config.getoption('local_testing_mode', default=False)",

0 commit comments

Comments
 (0)