Commit 0defc75

modify __eq__ instead of feature flag

committed · 1 parent eacb904 · commit 0defc75

File tree

src/snowflake/snowpark/context.py
src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py
src/snowflake/snowpark/types.py
tests/integ/test_datatypes.py

4 files changed: +45 −38 lines

src/snowflake/snowpark/context.py
Lines changed: 0 additions & 1 deletion

@@ -24,7 +24,6 @@
 # If _should_continue_registration is not None, i.e. a caller environment has assigned it an alternate callable, then the callback is responsible for determining the rest of the Snowpark workflow.
 _should_continue_registration: Optional[Callable[..., bool]] = None
 
-_store_precision_and_scale_in_numeric_type: bool = False
 
 # Internal-only global flag that determines if structured type semantics should be used
 _use_structured_type_semantics = False

src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py
Lines changed: 6 additions & 1 deletion

@@ -133,7 +133,12 @@ def __init__(self) -> None:
         super().__init__()
 
     def __eq__(self, other: Any) -> bool:
-        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
+        def filtered(d: dict) -> dict:
+            return {k: v for k, v in d.items() if k not in ("_precision", "_scale")}
+
+        return isinstance(other, self.__class__) and filtered(
+            self.__dict__
+        ) == filtered(other.__dict__)
 
     def __ne__(self, other: Any) -> bool:
         return not self.__eq__(other)
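
The same filtered-__dict__ comparison appears here and in types.py below: equality is decided on the full instance state minus the precision/scale metadata. A minimal standalone sketch of the pattern (the class name and constructor are illustrative, not from the commit):

from typing import Any, Optional

class _MetadataAwareType:
    """Illustrative stand-in for a type object carrying inference metadata."""

    def __init__(self, precision: Optional[int] = None, scale: Optional[int] = None) -> None:
        # Metadata recording how a value was inferred, not what the type is.
        self._precision = precision
        self._scale = scale

    def __eq__(self, other: Any) -> bool:
        # Compare instance state while ignoring the metadata attributes, so two
        # otherwise-identical types with different precision/scale still match.
        def filtered(d: dict) -> dict:
            return {k: v for k, v in d.items() if k not in ("_precision", "_scale")}

        return isinstance(other, self.__class__) and filtered(
            self.__dict__
        ) == filtered(other.__dict__)

    def __ne__(self, other: Any) -> bool:
        return not self.__eq__(other)

# Instances differing only in the ignored metadata compare equal:
assert _MetadataAwareType(precision=38, scale=0) == _MetadataAwareType()

One caveat worth noting: defining __eq__ on a class suppresses the inherited __hash__, so classes taking this approach presumably need to keep or redefine a compatible hash if their instances are used as dict keys or set members.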

src/snowflake/snowpark/types.py
Lines changed: 8 additions & 0 deletions

@@ -173,6 +173,14 @@ def __init__(self, **kwargs) -> None:
         self._precision = kwargs.get("precision", None)
         self._scale = kwargs.get("scale", None)
 
+    def __eq__(self, other):
+        def filtered(d: dict) -> dict:
+            return {k: v for k, v in d.items() if k not in ("_precision", "_scale")}
+
+        return isinstance(other, self.__class__) and filtered(
+            self.__dict__
+        ) == filtered(other.__dict__)
+
 
 class TimestampTimeZone(Enum):
     """

tests/integ/test_datatypes.py
Lines changed: 31 additions & 36 deletions

@@ -5,11 +5,10 @@
 import os
 import tempfile
 from decimal import Decimal
-from unittest.mock import patch
 
 import pytest
 
-from snowflake.snowpark import DataFrame, Row, context
+from snowflake.snowpark import DataFrame, Row
 from snowflake.snowpark.functions import lit
 from snowflake.snowpark.types import (
     BooleanType,
@@ -426,18 +425,15 @@ def test_join_basic(session):
 def test_numeric_type_store_precision_and_scale(session, massive_number, precision):
     table_name = Utils.random_table_name()
     try:
-        with patch.object(context, "_store_precision_and_scale_in_numeric_type", True):
-            df = session.create_dataframe(
-                [Decimal(massive_number)],
-                StructType(
-                    [StructField("large_value", DecimalType(precision, 0), True)]
-                ),
-            )
-            df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
-            result = session.sql(f"select * from {table_name}")
-            datatype = result.schema.fields[0].datatype
-            assert isinstance(datatype, LongType)
-            assert datatype._precision == 38 and datatype._scale == 0
+        df = session.create_dataframe(
+            [Decimal(massive_number)],
+            StructType([StructField("large_value", DecimalType(precision, 0), True)]),
+        )
+        df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
+        result = session.sql(f"select * from {table_name}")
+        datatype = result.schema.fields[0].datatype
+        assert isinstance(datatype, LongType)
+        assert datatype._precision == 38 and datatype._scale == 0
     finally:
         session.sql(f"drop table {table_name}").collect()
@@ -468,28 +464,27 @@ def write_csv(data):
     file_path = write_csv(test_data)
 
     try:
-        with patch.object(context, "_store_precision_and_scale_in_numeric_type", True):
-            Utils.create_stage(session, stage_name, is_temporary=True)
-            result = session.file.put(
-                file_path, f"@{stage_name}", auto_compress=False, overwrite=True
-            )
-
-            # Infer schema from only the short file
-            constrained_reader = session.read.options(
-                {
-                    "INFER_SCHEMA": True,
-                    "INFER_SCHEMA_OPTIONS": {"FILES": [result[0].target]},
-                    "PARSE_HEADER": True,
-                    # Only load the short file
-                    "PATTERN": f".*{result[0].target}",
-                }
-            )
-
-            # df1 uses constrained types
-            df1 = constrained_reader.csv(f"@{stage_name}/")
-            datatype = df1.schema.fields[0].datatype
-            assert isinstance(datatype, LongType)
-            assert datatype._precision == 38 and datatype._scale == 0
+        Utils.create_stage(session, stage_name, is_temporary=True)
+        result = session.file.put(
+            file_path, f"@{stage_name}", auto_compress=False, overwrite=True
+        )
+
+        # Infer schema from only the short file
+        constrained_reader = session.read.options(
+            {
+                "INFER_SCHEMA": True,
+                "INFER_SCHEMA_OPTIONS": {"FILES": [result[0].target]},
+                "PARSE_HEADER": True,
+                # Only load the short file
+                "PATTERN": f".*{result[0].target}",
+            }
+        )
+
+        # df1 uses constrained types
+        df1 = constrained_reader.csv(f"@{stage_name}/")
+        datatype = df1.schema.fields[0].datatype
+        assert isinstance(datatype, LongType)
+        assert datatype._precision == 38 and datatype._scale == 0
 
     finally:
         Utils.drop_stage(session, stage_name)
