Skip to content

Commit 7c66276

Browse files
committed
NO-SNOW: Add nullifzero function support for local testing (#4036)
1 parent 7d70182 commit 7c66276

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

src/snowflake/snowpark/mock/_functions.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def _check_constant_result(self, input_data, args, result):
114114
return result
115115

116116
def __call__(self, *args, input_data=None, row_number=None, **kwargs):
117-
118117
if self._pass_input_data:
119118
kwargs["raw_input"] = input_data
120119
if self._pass_row_index:
@@ -393,7 +392,7 @@ def mock_stddev(column: ColumnEmulator) -> ColumnEmulator:
393392

394393
@patch("approx_percentile_accumulate")
395394
def mock_approx_percentile_accumulate(
396-
column: Union[TableEmulator, ColumnEmulator]
395+
column: Union[TableEmulator, ColumnEmulator],
397396
) -> ColumnEmulator:
398397
# TODO SNOW-1800512: Fix, returns dummy of 42 for now.
399398
_logger.warning("TODO SNOW-1800512: Returns dummy value of 42 now, need to fix.")
@@ -707,6 +706,18 @@ def mock_abs(expr):
707706
return abs(expr)
708707

709708

709+
@patch("nullifzero")
710+
def mock_nullifzero(expr: ColumnEmulator) -> ColumnEmulator:
711+
def convert_zero_to_null(value):
712+
if value == 0 or value == 0.0:
713+
return None
714+
return value
715+
716+
result = expr.apply(convert_zero_to_null)
717+
result.sf_type = ColumnType(expr.sf_type.datatype, nullable=True)
718+
return result
719+
720+
710721
@patch("to_decimal")
711722
def mock_to_decimal(
712723
e: ColumnEmulator,
@@ -974,7 +985,6 @@ def convert_timestamp(row):
974985
if data is None:
975986
return None
976987
try:
977-
978988
datatype = column.sf_type.datatype
979989
if isinstance(datatype, TimestampType):
980990
# data is datetime.datetime type
@@ -1206,7 +1216,10 @@ def convert_char(data, _fmt):
12061216
return try_convert(convert_numeric_to_str, try_cast, data)
12071217
elif isinstance(source_datatype, (DateType, TimeType)):
12081218
default_format = _DEFAULT_OUTPUT_FORMAT.get(type(source_datatype))
1209-
(format, _,) = convert_snowflake_datetime_format(
1219+
(
1220+
format,
1221+
_,
1222+
) = convert_snowflake_datetime_format(
12101223
_fmt, default_format=default_format, is_input_format=False
12111224
)
12121225
convert_date_time_to_str = (
@@ -1219,7 +1232,10 @@ def convert_char(data, _fmt):
12191232
)
12201233
elif isinstance(source_datatype, TimestampType):
12211234
default_format = _DEFAULT_OUTPUT_FORMAT.get(TimestampType)
1222-
(format, fractional_seconds,) = convert_snowflake_datetime_format(
1235+
(
1236+
format,
1237+
fractional_seconds,
1238+
) = convert_snowflake_datetime_format(
12231239
_fmt, default_format, is_input_format=False
12241240
)
12251241
# handle 3f, can use str index

tests/mock/test_functions.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
lit,
3030
max,
3131
min,
32+
nullifzero,
3233
rank,
3334
row_number,
3435
sum,
@@ -40,7 +41,6 @@
4041
from snowflake.snowpark.mock.exceptions import SnowparkLocalTestingException
4142
from snowflake.snowpark.types import IntegerType
4243
from snowflake.snowpark.window import Window
43-
4444
from tests.utils import Utils
4545

4646

@@ -161,6 +161,26 @@ def test_abs(session):
161161
assert origin_df.select(abs(col("m"))).collect() == [Row(1), Row(1), Row(2)]
162162

163163

164+
def test_nullifzero(session):
165+
origin_df: DataFrame = session.create_dataframe(
166+
[
167+
[0],
168+
[1],
169+
[0.0],
170+
[100],
171+
[-5],
172+
],
173+
schema=["value"],
174+
)
175+
assert origin_df.select(nullifzero(col("value"))).collect() == [
176+
Row(None),
177+
Row(1),
178+
Row(None),
179+
Row(100),
180+
Row(-5),
181+
]
182+
183+
164184
def test_asc_and_desc(session):
165185
origin_df: DataFrame = session.create_dataframe(
166186
[

0 commit comments

Comments
 (0)