Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/spatch/_spatch_example/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,27 @@ def divide2(x, y):
@backend2.set_should_run(divide2)
def _(info, x, y):
return True


class StatefulClassImpl:
@backend1.implements("spatch._spatch_example.library:StatefulClass.apply")
@classmethod
def _from_apply(cls, original_self, x, y):
impl = cls()
impl.method = original_self.method
return impl

def apply(self, x, y):
if self.method == "add":
res = x + y
elif self.method == "sub":
res = x - y
else:
raise ValueError(f"Unknown method: {self.method}")

self._last_result = res
return res

@property
def last_result(self):
return self._last_result
6 changes: 6 additions & 0 deletions src/spatch/_spatch_example/entry_point.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,9 @@ uses_context = true
function = "spatch._spatch_example.backend:divide"
should_run = "spatch._spatch_example.backend:divide._should_run"
additional_docs = """This implementation works well on floats."""



[functions."spatch._spatch_example.library:StatefulClass.apply"]
function = "spatch._spatch_example.backend:StatefulClassImpl._from_apply"
uses_context = false
27 changes: 26 additions & 1 deletion src/spatch/_spatch_example/library.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from spatch.backend_system import BackendSystem
from spatch.backend_system import BackendSystem, dispatchable_stateful_class

_backend_system = BackendSystem(
"_spatch_example_backends", # entry point group
Expand All @@ -17,3 +17,28 @@ def divide(x, y):
# We could allow context being passed in to do this check.
raise TypeError("x and y must be an integer")
return x // y


@dispatchable_stateful_class()
class StatefulClass:
def __init__(self, method):
self.method = method

@_backend_system.stateful_dispatching(["x", "y"])
def apply(self, x, y):
if not isinstance(x, int) or not isinstance(y, int):
raise TypeError("x and y must be an integer")

if self.method == "add":
res = x + y
elif self.method == "sub":
res = x - y
else:
raise ValueError(f"Unknown method: {self.method}")

self._last_result = res
return res

@property
def last_result(self):
return self._last_result
84 changes: 84 additions & 0 deletions src/spatch/backend_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,27 @@ def wrap_callable(func):

return wrap_callable

def stateful_dispatching(self, dispatch_args=None, *, module=None, qualname=None):
"""
Mark a method as a stateful dispatching one.

Mark a method on a class as one that initiates stateful dispatching, i.e.
when this method is called, it will dispatch if dispatching has not yet happened.

See `~spatch.backend_system.BackendSystem.dispatchable` for information about
arguments.
"""
def wrap_callable(func):
@functools.wraps(func)
def no_impl_dispatching(*args, **kwargs):
return None

disp = Dispatchable(self, no_impl_dispatching, dispatch_args)
func._spatch_implementation_dispatcher = disp
return func

return wrap_callable

@functools.cached_property
def backend_opts(self):
"""Property returning a :py:class:`BackendOpts` class specific to this library
Expand Down Expand Up @@ -1031,3 +1052,66 @@ def __call__(self, *args, **kwargs):
call_trace.append(("default fallback", "called"))

return self._default_func(*args, **kwargs)


class DispatchableMethod:
def __init__(self, state_name, method_name, original_method, dispatcher=None):
# TODO: need backendsystem, although only for tracing?
self._dispatcher = getattr(original_method, "_spatch_implementation_dispatcher", None)
self._state_name = state_name
self._method_name = method_name
self.__wrapped__ = original_method
# add docs, etc.

def __get__(self, obj, objtype=None):
if obj is None:
return self
return MethodType(self, obj)

def __call__(self, original_self, *args, **kwargs):
state = getattr(original_self, self._state_name)
if state is NotImplemented:
if self._dispatcher is None:
state = None # No implementation bound yet.
else:
state = self._dispatcher(original_self, *args, **kwargs)
setattr(original_self, self._state_name, state)

if state is None:
return self.__wrapped__(original_self, *args, **kwargs)

return getattr(state, self._method_name)(*args, **kwargs)


class DispatchableProperty(property):
def __init__(self, state_name, property_name, original_property):
self._state_name = state_name
self._property_name = property_name
self.__wrapped__ = original_property
# documentation?

def __get__(self, instance, owner=None):
if instance is None:
return self
state = getattr(instance, self._state_name)
return getattr(state, self._property_name)

def __set__(self, instance, value):
state = getattr(instance, self._state_name)
setattr(state, self._property_name, value)


def dispatchable_stateful_class(state_name="_implementation"):
def decorator(cls):
setattr(cls, state_name, NotImplemented)

for name, attr in cls.__dict__.items():
if name.startswith("_") and not hasattr(attr, "_spatch_implementation_dispatcher"):
continue
elif callable(attr):
setattr(cls, name, DispatchableMethod(state_name, name, attr))
elif isinstance(attr, property):
setattr(cls, name, DispatchableProperty(state_name, name, attr))
return cls

return decorator
6 changes: 5 additions & 1 deletion src/spatch/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,15 @@ def update_entrypoint(filepath: str):

func_info = functions[info.api_identity]

if info.func.__doc__ is not None:
docs = tomlkit.string(inspect.cleandoc(info.func.__doc__), multiline=True)
else:
docs = None
new_values = {
"function": info.impl_identity,
"should_run": info.should_run_identity,
"uses_context": info.uses_context,
"additional_docs": tomlkit.string(inspect.cleandoc(info.func.__doc__), multiline=True),
"additional_docs": docs,
}

for attr, value in new_values.items():
Expand Down
Loading