Skip to content

Commit 3cef889

Browse files
committed
Return NotImplemented where relevant.
1 parent 687be9f commit 3cef889

File tree

2 files changed

+35
-16
lines changed

2 files changed

+35
-16
lines changed

src/mind_the_gaps/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55

66
__all__ = ["Endpoint", "Gaps", "Var", "x"]
77

8-
__version__ = "0.3.3"
8+
__version__ = "0.3.4"

src/mind_the_gaps/gaps.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from bisect import bisect
24
from collections.abc import Callable
35
from dataclasses import dataclass, field
@@ -42,16 +44,21 @@ class Endpoint[T: SupportsLessThan]:
4244
`"]"` is closed and right.
4345
"""
4446

45-
def __lt__(self, other: Self) -> bool:
47+
def __lt__(self, other: Endpoint) -> bool:
48+
if not isinstance(other, Endpoint):
49+
return NotImplemented
50+
4651
if self.value != other.value:
4752
return self.value < other.value
4853

4954
return _BOUNDARY_ORDER[self.boundary] < _BOUNDARY_ORDER[other.boundary]
5055

51-
def __eq__(self, other: Self) -> bool:
56+
def __eq__(self, other: Endpoint) -> bool:
57+
if not isinstance(other, Endpoint):
58+
return NotImplemented
59+
5260
return (
53-
isinstance(other, Endpoint)
54-
and self.value == other.value
61+
self.value == other.value
5562
and _BOUNDARY_ORDER[self.boundary] == _BOUNDARY_ORDER[other.boundary]
5663
)
5764

@@ -155,7 +162,7 @@ class Gaps[T: SupportsLessThan]:
155162

156163
endpoints: list[T | Endpoint[T]] = field(default_factory=list)
157164

158-
def __post_init__(self):
165+
def __post_init__(self) -> None:
159166
if len(self.endpoints) % 2 == 1:
160167
raise ValueError("Need an even number of endpoints.")
161168

@@ -214,19 +221,31 @@ def from_string(cls, gaps: str) -> Self:
214221

215222
return cls(endpoints)
216223

217-
def __or__(self, other: Self) -> Self:
218-
return Gaps(_merge(self.endpoints, other.endpoints, or_))
224+
def __or__(self, other: Gaps) -> Self:
225+
if not isinstance(other, Gaps):
226+
return NotImplemented
227+
228+
return type(self)(_merge(self.endpoints, other.endpoints, or_))
219229

220-
def __and__(self, other: Self) -> Self:
221-
return Gaps(_merge(self.endpoints, other.endpoints, and_))
230+
def __and__(self, other: Gaps) -> Self:
231+
if not isinstance(other, Gaps):
232+
return NotImplemented
222233

223-
def __xor__(self, other: Self) -> Self:
224-
return Gaps(_merge(self.endpoints, other.endpoints, xor))
234+
return type(self)(_merge(self.endpoints, other.endpoints, and_))
225235

226-
def __sub__(self, other: Self) -> Self:
227-
return Gaps(_merge(self.endpoints, other.endpoints, sub))
236+
def __xor__(self, other: Gaps) -> Self:
237+
if not isinstance(other, Gaps):
238+
return NotImplemented
228239

229-
def __bool__(self):
240+
return type(self)(_merge(self.endpoints, other.endpoints, xor))
241+
242+
def __sub__(self, other: Gaps) -> Self:
243+
if not isinstance(other, Gaps):
244+
return NotImplemented
245+
246+
return type(self)(_merge(self.endpoints, other.endpoints, sub))
247+
248+
def __bool__(self) -> bool:
230249
return len(self.endpoints) > 0
231250

232251
def __contains__(self, value: T) -> bool:
@@ -242,5 +261,5 @@ def __contains__(self, value: T) -> bool:
242261

243262
return False
244263

245-
def __str__(self):
264+
def __str__(self) -> str:
246265
return f"{{{", ".join(str(endpoint) for endpoint in self.endpoints)}}}"

0 commit comments

Comments
 (0)