Skip to content

Commit f13acb1

Browse files
committed
use type directly instead of converting to string
1 parent 90fef57 commit f13acb1

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

src/snowflake/snowpark/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646
# internal only dictionary store the default precision of integral types, if the type does not appear in the
4747
# dictionary, the default precision is None.
48-
# example: _integral_type_default_precision = {"IntegerType": 9}, IntegerType default _precision is 9 now
48+
# example: _integral_type_default_precision = {IntegerType: 9}, IntegerType default _precision is 9 now
4949
_integral_type_default_precision = {}
5050

5151

src/snowflake/snowpark/types.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,8 @@ def _fill_ast(self, ast: proto.DataType) -> None:
371371
# Numeric types
372372
class _IntegralType(_NumericType):
373373
def __init__(self, **kwargs) -> None:
374-
class_name = type(self).__name__
375374
self._precision = kwargs.pop(
376-
"_precision", context._integral_type_default_precision.get(class_name, None)
375+
"_precision", context._integral_type_default_precision.get(type(self), None)
377376
)
378377

379378
if kwargs != {}:

tests/integ/test_datatypes.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,8 @@ def test_integral_type_default_precision(mock_default_precision):
557557
@pytest.mark.parametrize(
558558
"mock_default_precision",
559559
[
560-
{"IntegerType": 5, "LongType": 4},
561-
{"LongType": 19, "IntegerType": 10},
560+
{IntegerType: 5, LongType: 4},
561+
{LongType: 19, IntegerType: 10},
562562
],
563563
)
564564
def test_end_to_end_default_precision(session, precision, mock_default_precision):
@@ -585,21 +585,20 @@ def test_end_to_end_default_precision(session, precision, mock_default_precision
585585
assert df.schema.fields[0].datatype._precision == precision
586586
assert (
587587
df.schema.fields[1].datatype._precision
588-
== mock_default_precision["IntegerType"]
588+
== mock_default_precision[IntegerType]
589589
)
590590
assert (
591-
df.schema.fields[2].datatype._precision
592-
== mock_default_precision["LongType"]
591+
df.schema.fields[2].datatype._precision == mock_default_precision[LongType]
593592
)
594593

595594
df.write.save_as_table(table_name, mode="overwrite", table_type="temp")
596595
result = session.sql(f"select * from {table_name}")
597596
assert result.schema.fields[0].datatype._precision == precision
598597
assert (
599598
result.schema.fields[1].datatype._precision
600-
== mock_default_precision["IntegerType"]
599+
== mock_default_precision[IntegerType]
601600
)
602601
assert (
603602
result.schema.fields[2].datatype._precision
604-
== mock_default_precision["LongType"]
603+
== mock_default_precision[LongType]
605604
)

0 commit comments

Comments
 (0)