diff --git a/src/spatch/_spatch_example/backend.py b/src/spatch/_spatch_example/backend.py index 55d14fc..9dbf7b9 100644 --- a/src/spatch/_spatch_example/backend.py +++ b/src/spatch/_spatch_example/backend.py @@ -43,3 +43,27 @@ def divide2(x, y): @backend2.set_should_run(divide2) def _(info, x, y): return True + + +class StatefulClassImpl: + @backend1.implements("spatch._spatch_example.library:StatefulClass.apply") + @classmethod + def _from_apply(cls, original_self, x, y): + impl = cls() + impl.method = original_self.method + return impl + + def apply(self, x, y): + if self.method == "add": + res = x + y + elif self.method == "sub": + res = x - y + else: + raise ValueError(f"Unknown method: {self.method}") + + self._last_result = res + return res + + @property + def last_result(self): + return self._last_result diff --git a/src/spatch/_spatch_example/entry_point.toml b/src/spatch/_spatch_example/entry_point.toml index e3a972b..363fd9c 100644 --- a/src/spatch/_spatch_example/entry_point.toml +++ b/src/spatch/_spatch_example/entry_point.toml @@ -20,3 +20,9 @@ uses_context = true function = "spatch._spatch_example.backend:divide" should_run = "spatch._spatch_example.backend:divide._should_run" additional_docs = """This implementation works well on floats.""" + + + +[functions."spatch._spatch_example.library:StatefulClass.apply"] +function = "spatch._spatch_example.backend:StatefulClassImpl._from_apply" +uses_context = false diff --git a/src/spatch/_spatch_example/library.py b/src/spatch/_spatch_example/library.py index 66e7027..e18e5f3 100644 --- a/src/spatch/_spatch_example/library.py +++ b/src/spatch/_spatch_example/library.py @@ -1,4 +1,4 @@ -from spatch.backend_system import BackendSystem +from spatch.backend_system import BackendSystem, dispatchable_stateful_class _backend_system = BackendSystem( "_spatch_example_backends", # entry point group @@ -17,3 +17,28 @@ def divide(x, y): # We could allow context being passed in to do this check. raise TypeError("x and y must be an integer") return x // y + + +@dispatchable_stateful_class() +class StatefulClass: + def __init__(self, method): + self.method = method + + @_backend_system.stateful_dispatching(["x", "y"]) + def apply(self, x, y): + if not isinstance(x, int) or not isinstance(y, int): + raise TypeError("x and y must be an integer") + + if self.method == "add": + res = x + y + elif self.method == "sub": + res = x - y + else: + raise ValueError(f"Unknown method: {self.method}") + + self._last_result = res + return res + + @property + def last_result(self): + return self._last_result diff --git a/src/spatch/backend_system.py b/src/spatch/backend_system.py index 32e9931..2fd388c 100644 --- a/src/spatch/backend_system.py +++ b/src/spatch/backend_system.py @@ -731,6 +731,27 @@ def wrap_callable(func): return wrap_callable + def stateful_dispatching(self, dispatch_args=None, *, module=None, qualname=None): + """ + Mark a method as a stateful dispatching one. + + Mark a method on a class as one that initiates stateful dispatching, i.e. + when this method is called, it will dispatch if dispatching has not yet happened. + + See `~spatch.backend_system.BackendSystem.dispatchable` for information about + arguments. + """ + def wrap_callable(func): + @functools.wraps(func) + def no_impl_dispatching(*args, **kwargs): + return None + + disp = Dispatchable(self, no_impl_dispatching, dispatch_args) + func._spatch_implementation_dispatcher = disp + return func + + return wrap_callable + @functools.cached_property def backend_opts(self): """Property returning a :py:class:`BackendOpts` class specific to this library @@ -1031,3 +1052,66 @@ def __call__(self, *args, **kwargs): call_trace.append(("default fallback", "called")) return self._default_func(*args, **kwargs) + + +class DispatchableMethod: + def __init__(self, state_name, method_name, original_method, dispatcher=None): + # TODO: need backendsystem, although only for tracing? + self._dispatcher = getattr(original_method, "_spatch_implementation_dispatcher", None) + self._state_name = state_name + self._method_name = method_name + self.__wrapped__ = original_method + # add docs, etc. + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return MethodType(self, obj) + + def __call__(self, original_self, *args, **kwargs): + state = getattr(original_self, self._state_name) + if state is NotImplemented: + if self._dispatcher is None: + state = None # No implementation bound yet. + else: + state = self._dispatcher(original_self, *args, **kwargs) + setattr(original_self, self._state_name, state) + + if state is None: + return self.__wrapped__(original_self, *args, **kwargs) + + return getattr(state, self._method_name)(*args, **kwargs) + + +class DispatchableProperty(property): + def __init__(self, state_name, property_name, original_property): + self._state_name = state_name + self._property_name = property_name + self.__wrapped__ = original_property + # documentation? + + def __get__(self, instance, owner=None): + if instance is None: + return self + state = getattr(instance, self._state_name) + return getattr(state, self._property_name) + + def __set__(self, instance, value): + state = getattr(instance, self._state_name) + setattr(state, self._property_name, value) + + +def dispatchable_stateful_class(state_name="_implementation"): + def decorator(cls): + setattr(cls, state_name, NotImplemented) + + for name, attr in cls.__dict__.items(): + if name.startswith("_") and not hasattr(attr, "_spatch_implementation_dispatcher"): + continue + elif callable(attr): + setattr(cls, name, DispatchableMethod(state_name, name, attr)) + elif isinstance(attr, property): + setattr(cls, name, DispatchableProperty(state_name, name, attr)) + return cls + + return decorator diff --git a/src/spatch/backend_utils.py b/src/spatch/backend_utils.py index 480cab1..06cba06 100644 --- a/src/spatch/backend_utils.py +++ b/src/spatch/backend_utils.py @@ -205,11 +205,15 @@ def update_entrypoint(filepath: str): func_info = functions[info.api_identity] + if info.func.__doc__ is not None: + docs = tomlkit.string(inspect.cleandoc(info.func.__doc__), multiline=True) + else: + docs = None new_values = { "function": info.impl_identity, "should_run": info.should_run_identity, "uses_context": info.uses_context, - "additional_docs": tomlkit.string(inspect.cleandoc(info.func.__doc__), multiline=True), + "additional_docs": docs, } for attr, value in new_values.items():