diff --git a/src/datastar_py/attributes.py b/src/datastar_py/attributes.py index 6a9f99c..9542b97 100644 --- a/src/datastar_py/attributes.py +++ b/src/datastar_py/attributes.py @@ -285,7 +285,7 @@ def view_transition(self, expression: str) -> BaseAttr: @property def json_signals(self) -> BaseAttr: """Create a signal that contains the JSON representation of the signals.""" - return BaseAttr("json-signals", alias=self._alias) + return JsonSignalsAttr(alias=self._alias) @property def ignore_morph(self) -> BaseAttr: @@ -529,13 +529,16 @@ def trust(self) -> Self: class PersistAttr(BaseAttr): _attr = "persist" - def __call__(self, signal_names: str | Iterable[str] | None = None) -> Self: - if not signal_names: - return self - if isinstance(signal_names, str): - self._value = signal_names - else: - self._value = " ".join(signal_names) + def __call__( + self, + storage_key: str | None = None, + include: str | None = None, + exclude: str | None = None, + ) -> Self: + if storage_key: + self._key = storage_key + if include or exclude: + self._value = json.dumps(_filter_dict(include=include, exclude=exclude)) return self @property @@ -550,12 +553,7 @@ class JsonSignalsAttr(BaseAttr): def __call__(self, include: str | None = None, exclude: str | None = None) -> Self: if include or exclude: - filter_object = {} - if include: - filter_object["include"] = include - if exclude: - filter_object["exclude"] = exclude - self._value = json.dumps(filter_object) + self._value = json.dumps(_filter_dict(include=include, exclude=exclude)) return self @property @@ -694,13 +692,11 @@ class OnSignalPatchAttr(BaseAttr, TimingMod, DelayMod): def filter(self, include: str | None = None, exclude: str | None = None) -> Self: """Filter the signal patch events.""" if include or exclude: - filter_object = {} - if include: - filter_object["include"] = include - if exclude: - filter_object["exclude"] = exclude self._other_attrs = [ - BaseAttr("on-signal-patch-filter", value=json.dumps(filter_object)) + BaseAttr( + "on-signal-patch-filter", + value=json.dumps(_filter_dict(include=include, exclude=exclude)), + ) ] return self @@ -714,12 +710,7 @@ class QueryStringAttr(BaseAttr): def __call__(self, include: str | None = None, exclude: str | None = None) -> Self: if include or exclude: - filter_object = {} - if include: - filter_object["include"] = include - if exclude: - filter_object["exclude"] = exclude - self._value = json.dumps(filter_object) + self._value = json.dumps(_filter_dict(include=include, exclude=exclude)) return self @property @@ -738,6 +729,15 @@ def _escape(s: str) -> str: ) +def _filter_dict(include: str | None = None, exclude: str | None = None) -> dict: + filter_dict = {} + if include: + filter_dict["include"] = include + if exclude: + filter_dict["exclude"] = exclude + return filter_dict + + def _js_object(obj: dict) -> str: """Create a JS object where the values are expressions rather than strings.""" return (