Skip to content

Commit 6eec8e3

Browse files
devin-ai-integration[bot]Alek99masenf
authored
Add decimal.Decimal support to serializers and NumberVar (#5226)
* Add serializer for decimal.Decimal type that converts to float Co-Authored-By: Alek Petuskey <[email protected]> * Add tests for decimal.Decimal serializer and NumberVar support Co-Authored-By: Alek Petuskey <[email protected]> * Update NumberVar and related components to support decimal.Decimal Co-Authored-By: Alek Petuskey <[email protected]> * Simplify test_all_number_operations to fix type compatibility with decimal.Decimal Co-Authored-By: Alek Petuskey <[email protected]> * Fix decimal serialization to properly quote string values Co-Authored-By: Alek Petuskey <[email protected]> * Fix decimal serialization functions Co-Authored-By: Alek Petuskey <[email protected]> * Revert "Simplify test_all_number_operations to fix type compatibility with decimal.Decimal" This reverts commit 758d55f. * revert bad test change * add overload for Decimal in Var.create move test_decimal_var to test_var and tweak the expectations override return type for NumberVar.__neg__ * revert changes in `float_input_event` needed to add another `.to` overload for proper type checking * update test_serializers expectation --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Alek Petuskey <[email protected]> Co-authored-by: Masen Furer <[email protected]>
1 parent b5c8b6e commit 6eec8e3

File tree

6 files changed

+92
-10
lines changed

6 files changed

+92
-10
lines changed

reflex/utils/serializers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import contextlib
66
import dataclasses
7+
import decimal
78
import functools
89
import inspect
910
import json
@@ -386,6 +387,19 @@ def serialize_uuid(uuid: UUID) -> str:
386387
return str(uuid)
387388

388389

390+
@serializer(to=float)
391+
def serialize_decimal(value: decimal.Decimal) -> float:
392+
"""Serialize a Decimal to a float.
393+
394+
Args:
395+
value: The Decimal to serialize.
396+
397+
Returns:
398+
The serialized Decimal as a float.
399+
"""
400+
return float(value)
401+
402+
389403
@serializer(to=str)
390404
def serialize_color(color: Color) -> str:
391405
"""Serialize a color.

reflex/vars/base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import uuid
1616
import warnings
1717
from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence
18+
from decimal import Decimal
1819
from types import CodeType, FunctionType
1920
from typing import ( # noqa: UP035
2021
TYPE_CHECKING,
@@ -630,6 +631,14 @@ def create(
630631
_var_data: VarData | None = None,
631632
) -> LiteralNumberVar[float]: ...
632633

634+
@overload
635+
@classmethod
636+
def create(
637+
cls,
638+
value: Decimal,
639+
_var_data: VarData | None = None,
640+
) -> LiteralNumberVar[Decimal]: ...
641+
633642
@overload
634643
@classmethod
635644
def create( # pyright: ignore [reportOverlappingOverload]
@@ -743,7 +752,10 @@ def to(self, output: type[bool]) -> BooleanVar: ...
743752
def to(self, output: type[int]) -> NumberVar[int]: ...
744753

745754
@overload
746-
def to(self, output: type[int] | type[float]) -> NumberVar: ...
755+
def to(self, output: type[float]) -> NumberVar[float]: ...
756+
757+
@overload
758+
def to(self, output: type[Decimal]) -> NumberVar[Decimal]: ...
747759

748760
@overload
749761
def to(

reflex/vars/number.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import dataclasses
6+
import decimal
67
import json
78
import math
89
from collections.abc import Callable
@@ -30,7 +31,10 @@
3031
)
3132

3233
NUMBER_T = TypeVarExt(
33-
"NUMBER_T", bound=(int | float), default=(int | float), covariant=True
34+
"NUMBER_T",
35+
bound=(int | float | decimal.Decimal),
36+
default=(int | float | decimal.Decimal),
37+
covariant=True,
3438
)
3539

3640
if TYPE_CHECKING:
@@ -54,7 +58,7 @@ def raise_unsupported_operand_types(
5458
)
5559

5660

57-
class NumberVar(Var[NUMBER_T], python_types=(int, float)):
61+
class NumberVar(Var[NUMBER_T], python_types=(int, float, decimal.Decimal)):
5862
"""Base class for immutable number vars."""
5963

6064
def __add__(self, other: number_types) -> NumberVar:
@@ -285,13 +289,13 @@ def __rpow__(self, other: number_types) -> NumberVar:
285289

286290
return number_exponent_operation(+other, self)
287291

288-
def __neg__(self):
292+
def __neg__(self) -> NumberVar:
289293
"""Negate the number.
290294
291295
Returns:
292296
The number negation operation.
293297
"""
294-
return number_negate_operation(self)
298+
return number_negate_operation(self) # pyright: ignore [reportReturnType]
295299

296300
def __invert__(self):
297301
"""Boolean NOT the number.
@@ -943,7 +947,7 @@ def boolean_not_operation(value: BooleanVar):
943947
class LiteralNumberVar(LiteralVar, NumberVar[NUMBER_T]):
944948
"""Base class for immutable literal number vars."""
945949

946-
_var_value: float | int = dataclasses.field(default=0)
950+
_var_value: float | int | decimal.Decimal = dataclasses.field(default=0)
947951

948952
def json(self) -> str:
949953
"""Get the JSON representation of the var.
@@ -954,6 +958,8 @@ def json(self) -> str:
954958
Raises:
955959
PrimitiveUnserializableToJSONError: If the var is unserializable to JSON.
956960
"""
961+
if isinstance(self._var_value, decimal.Decimal):
962+
return json.dumps(float(self._var_value))
957963
if math.isinf(self._var_value) or math.isnan(self._var_value):
958964
raise PrimitiveUnserializableToJSONError(
959965
f"No valid JSON representation for {self}"
@@ -969,7 +975,9 @@ def __hash__(self) -> int:
969975
return hash((type(self).__name__, self._var_value))
970976

971977
@classmethod
972-
def create(cls, value: float | int, _var_data: VarData | None = None):
978+
def create(
979+
cls, value: float | int | decimal.Decimal, _var_data: VarData | None = None
980+
):
973981
"""Create the number var.
974982
975983
Args:
@@ -1039,7 +1047,7 @@ def create(cls, value: bool, _var_data: VarData | None = None):
10391047
)
10401048

10411049

1042-
number_types = NumberVar | int | float
1050+
number_types = NumberVar | int | float | decimal.Decimal
10431051
boolean_types = BooleanVar | bool
10441052

10451053

@@ -1112,4 +1120,4 @@ def ternary_operation(
11121120
return value
11131121

11141122

1115-
NUMBER_TYPES = (int, float, NumberVar)
1123+
NUMBER_TYPES = (int, float, decimal.Decimal, NumberVar)

reflex/vars/sequence.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import collections.abc
66
import dataclasses
7+
import decimal
78
import inspect
89
import json
910
import re
@@ -1558,7 +1559,7 @@ def is_tuple_type(t: GenericType) -> bool:
15581559

15591560

15601561
def _determine_value_of_array_index(
1561-
var_type: GenericType, index: int | float | None = None
1562+
var_type: GenericType, index: int | float | decimal.Decimal | None = None
15621563
):
15631564
"""Determine the value of an array index.
15641565

tests/units/test_var.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import decimal
12
import json
23
import math
34
import typing
@@ -1920,3 +1921,43 @@ class StateWithVar(rx.State):
19201921
rx.vstack(
19211922
str(StateWithVar.field),
19221923
)
1924+
1925+
1926+
def test_decimal_number_operations():
1927+
"""Test that decimal.Decimal values work with NumberVar operations."""
1928+
dec_num = Var.create(decimal.Decimal("123.456"))
1929+
assert isinstance(dec_num._var_value, decimal.Decimal)
1930+
assert str(dec_num) == "123.456"
1931+
1932+
result = dec_num + 10
1933+
assert str(result) == "(123.456 + 10)"
1934+
1935+
result = dec_num * 2
1936+
assert str(result) == "(123.456 * 2)"
1937+
1938+
result = dec_num / 2
1939+
assert str(result) == "(123.456 / 2)"
1940+
1941+
result = dec_num > 100
1942+
assert str(result) == "(123.456 > 100)"
1943+
1944+
result = dec_num < 200
1945+
assert str(result) == "(123.456 < 200)"
1946+
1947+
assert dec_num.json() == "123.456"
1948+
1949+
1950+
def test_decimal_var_type_compatibility():
1951+
"""Test that decimal.Decimal values are compatible with NumberVar type system."""
1952+
dec_num = Var.create(decimal.Decimal("123.456"))
1953+
int_num = Var.create(42)
1954+
float_num = Var.create(3.14)
1955+
1956+
result = dec_num + int_num
1957+
assert str(result) == "(123.456 + 42)"
1958+
1959+
result = dec_num * float_num
1960+
assert str(result) == "(123.456 * 3.14)"
1961+
1962+
result = (dec_num + int_num) / float_num
1963+
assert str(result) == "((123.456 + 42) / 3.14)"

tests/units/utils/test_serializers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import decimal
23
import json
34
from enum import Enum
45
from pathlib import Path
@@ -188,6 +189,9 @@ class BaseSubclass(Base):
188189
(Color(color="slate", shade=1), "var(--slate-1)"),
189190
(Color(color="orange", shade=1, alpha=True), "var(--orange-a1)"),
190191
(Color(color="accent", shade=1, alpha=True), "var(--accent-a1)"),
192+
(decimal.Decimal("123.456"), 123.456),
193+
(decimal.Decimal("-0.5"), -0.5),
194+
(decimal.Decimal("0"), 0.0),
191195
],
192196
)
193197
def test_serialize(value: Any, expected: str):
@@ -226,6 +230,8 @@ def test_serialize(value: Any, expected: str):
226230
(Color(color="slate", shade=1), '"var(--slate-1)"', True),
227231
(BaseSubclass, '"BaseSubclass"', True),
228232
(Path(), '"."', True),
233+
(decimal.Decimal("123.456"), "123.456", True),
234+
(decimal.Decimal("-0.5"), "-0.5", True),
229235
],
230236
)
231237
def test_serialize_var_to_str(value: Any, expected: str, exp_var_is_string: bool):

0 commit comments

Comments
 (0)