Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions src/attr/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@

# This is used at least twice, so cache it here.
_OBJ_SETATTR = object.__setattr__
_BASE_EXCEPTION_REDUCE = BaseException.__dict__.get("__reduce__")
_BASE_EXCEPTION_SETSTATE = BaseException.__dict__.get("__setstate__")
_INIT_FACTORY_PAT = "__attr_factory_%s"
_CLASSVAR_PREFIXES = (
"typing.ClassVar",
Expand Down Expand Up @@ -103,6 +105,37 @@ def __reduce__(self, _none_constructor=type(None), _args=()): # noqa: B008
return _none_constructor, _args


def _reconstruct_exception(cls, args, state):
"""
Reconstruct an attrs exception for pickle without calling __init__.
"""
self = BaseException.__new__(cls)
BaseException.__init__(self, *args)

if state is None:
return self

setstate = getattr(self, "__setstate__", None)
if (
setstate is not None
and getattr(cls, "__setstate__", None) is not _BASE_EXCEPTION_SETSTATE
):
setstate(state)
return self

if isinstance(state, tuple):
inst_dict, slot_state = state
state = {}
if inst_dict:
state.update(inst_dict)
state.update(slot_state)

for name, value in state.items():
_OBJ_SETATTR(self, name, value)

return self


def attrib(
default=NOTHING,
validator=None,
Expand Down Expand Up @@ -1018,6 +1051,34 @@ def __str__(self):
self._cls_dict["__str__"] = self._add_method_dunders(__str__)
return self

def add_exception_reduce(self):
def __reduce__(self):
getstate = getattr(self, "__getstate__", None)
if getstate is not None:
state = getstate()
else:
dict_state = getattr(self, "__dict__", None)
dict_state = dict_state.copy() if dict_state else None

if self.__attrs_props__.is_slotted:
slot_state = {
a.name: getattr(self, a.name)
for a in self.__attrs_attrs__
if a.name != "__weakref__"
}
state = (dict_state, slot_state)
else:
state = dict_state

return (
_reconstruct_exception,
(self.__class__, self.args, state),
)

__reduce__.__attrs_exception_reduce__ = True
self._cls_dict["__reduce__"] = self._add_method_dunders(__reduce__)
return self

def _make_getstate_setstate(self):
"""
Create custom __setstate__ and __getstate__ methods.
Expand Down Expand Up @@ -1563,6 +1624,22 @@ def wrap(cls):
msg = "Invalid value for cache_hash. To use hash caching, init must be True."
raise TypeError(msg)

if (
props.is_exception
and props.added_init
and any(a.init and a.kw_only for a in builder._attrs)
and not _has_own_attribute(cls, "__reduce__")
and (
getattr(cls, "__reduce__", None) is _BASE_EXCEPTION_REDUCE
or getattr(
getattr(cls, "__reduce__", None),
"__attrs_exception_reduce__",
False,
)
)
):
builder.add_exception_reduce()

if PY_3_13_PLUS and not _has_own_attribute(cls, "__replace__"):
builder.add_replace()

Expand Down
108 changes: 108 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,59 @@ class C2Slots:
y = attr.ib(default=attr.Factory(list))


@attr.s(auto_exc=True)
class KwOnlyException(Exception):
value = attr.ib(kw_only=True)


@attr.s(auto_exc=True, frozen=True)
class FrozenKwOnlyException(Exception):
value = attr.ib(kw_only=True)


@attr.s(auto_exc=True, slots=True)
class SlottedKwOnlyException(Exception):
value = attr.ib(kw_only=True)


@attr.s(auto_exc=True, frozen=True, slots=True)
class FrozenSlottedKwOnlyException(Exception):
value = attr.ib(kw_only=True)


@attr.s(auto_exc=True, frozen=True, slots=True, kw_only=True)
class KwOnlyBaseException(Exception):
has_default = attr.ib(default=42)


@attr.s(auto_exc=True, frozen=True, slots=True)
class KwOnlySubException(KwOnlyBaseException):
no_default = attr.ib(kw_only=False)


@attr.s(auto_exc=True)
class CustomReduceKwOnlyException(Exception):
value = attr.ib(kw_only=True)

def __reduce__(self):
return "custom"


class InheritedCustomReduceException(Exception):
def __reduce__(self):
return "inherited custom"


@attr.s(auto_exc=True)
class InheritedCustomReduceKwOnlyException(InheritedCustomReduceException):
value = attr.ib(kw_only=True)


@attr.s(auto_exc=True)
class PositionalOnlyException(Exception):
value = attr.ib()


@attr.s
class Base:
x = attr.ib()
Expand Down Expand Up @@ -624,6 +677,61 @@ class FooError(Exception):

FooError(1)

@pytest.mark.parametrize(
"cls",
[
KwOnlyException,
FrozenKwOnlyException,
SlottedKwOnlyException,
FrozenSlottedKwOnlyException,
],
)
def test_auto_exc_kw_only_pickles(self, cls):
"""
Keyword-only exception fields don't break pickle round-tripping.
"""
exc = cls(value=1)

rt = pickle.loads(pickle.dumps(exc))

assert isinstance(rt, cls)
assert 1 == rt.value
assert exc.args == rt.args

def test_auto_exc_kw_only_pickles_subclass_with_positional_field(self):
"""
Inherited keyword-only fields work with subclass positional fields.
"""
exc = KwOnlySubException("new", has_default=23)

rt = pickle.loads(pickle.dumps(exc))

assert isinstance(rt, KwOnlySubException)
assert 23 == rt.has_default
assert "new" == rt.no_default
assert exc.args == rt.args

def test_auto_exc_does_not_overwrite_custom_reduce(self):
"""
A custom __reduce__ on a keyword-only exception is left alone.
"""
assert "custom" == CustomReduceKwOnlyException(value=1).__reduce__()

def test_auto_exc_does_not_overwrite_inherited_custom_reduce(self):
"""
An inherited custom __reduce__ is left alone.
"""
assert (
"inherited custom"
== InheritedCustomReduceKwOnlyException(value=1).__reduce__()
)

def test_auto_exc_without_kw_only_does_not_add_reduce(self):
"""
Exceptions without keyword-only init fields keep BaseException reduce.
"""
assert "__reduce__" not in PositionalOnlyException.__dict__

def test_eq_only(self, slots, frozen):
"""
Classes with order=False cannot be ordered.
Expand Down
Loading