|
4 | 4 | from collections.abc import Callable |
5 | 5 | from dataclasses import dataclass, field |
6 | 6 | from functools import total_ordering |
| 7 | +from itertools import pairwise |
7 | 8 | from operator import and_, attrgetter, or_, xor |
8 | 9 | from typing import Any, Final, Literal, Protocol, Self |
9 | 10 |
|
@@ -166,6 +167,13 @@ def _merge[T]( |
166 | 167 | return endpoints |
167 | 168 |
|
168 | 169 |
|
| 170 | +def _to_number(text: str) -> int | float: |
| 171 | + try: |
| 172 | + return int(text) |
| 173 | + except ValueError: |
| 174 | + return float(text) |
| 175 | + |
| 176 | + |
169 | 177 | @dataclass |
170 | 178 | class Gaps[T: SupportsLessThan]: |
171 | 179 | """A set of mutually exclusive continuous intervals. |
@@ -208,36 +216,36 @@ def __post_init__(self) -> None: |
208 | 216 | i += 1 |
209 | 217 |
|
210 | 218 | @classmethod |
211 | | - def from_string(cls, gaps: str) -> Self: |
| 219 | + def from_string(cls, gaps: str) -> Self[int | float]: |
212 | 220 | """Create gaps from a string. |
213 | 221 |
|
214 | | - Values can only be int or float. Uses standard interval notation, i.e., `"{(-inf, 1], [2, 3)}"`. |
| 222 | + Values can only be int or float. Uses standard interval notation, i.e., |
| 223 | + `"{(-inf, 1], [2, 3)}"`. If the start and end value of an interval are equal, |
| 224 | + the interval may be expressed as, e.g., `"{[0]}"`. |
215 | 225 | """ |
216 | 226 | if gaps[0] != "{" or gaps[-1] != "}": |
217 | 227 | raise ValueError( |
218 | 228 | "Gap string must start and end with curly braces ('{', '}')." |
219 | 229 | ) |
220 | 230 |
|
221 | | - endpoints = gaps[1:-1].replace(" ", "").split(",") |
222 | | - if len(endpoints) == 1: |
| 231 | + splits = gaps[1:-1].replace(" ", "").split(",") |
| 232 | + if len(splits) == 1 and splits[0] == "": |
223 | 233 | return cls([]) |
224 | 234 |
|
225 | | - for i, endpoint in enumerate(endpoints): |
226 | | - if endpoint.startswith(("(", "[")): |
227 | | - boundary = endpoint[0] |
228 | | - value = endpoint[1:] |
229 | | - elif endpoint.endswith((")", "]")): |
230 | | - boundary = endpoint[-1] |
231 | | - value = endpoint[:-1] |
| 235 | + endpoints: list[Endpoint[int | float]] = [] |
| 236 | + for split in splits: |
| 237 | + if split.startswith("[") and split.endswith("]"): |
| 238 | + value = _to_number(split[1:-1]) |
| 239 | + endpoints.append(Endpoint(value, "[")) |
| 240 | + endpoints.append(Endpoint(value, "]")) |
| 241 | + elif split.startswith(("(", "[")): |
| 242 | + value = _to_number(split[1:]) |
| 243 | + endpoints.append(Endpoint(value, split[0])) |
| 244 | + elif split.endswith((")", "]")): |
| 245 | + value = _to_number(split[:-1]) |
| 246 | + endpoints.append(Endpoint(value, split[-1])) |
232 | 247 | else: |
233 | | - raise ValueError(f"Invalid endpoint ({endpoint!r}).") |
234 | | - |
235 | | - try: |
236 | | - value = int(value) |
237 | | - except ValueError: |
238 | | - value = float(value) |
239 | | - |
240 | | - endpoints[i] = Endpoint(value, boundary) |
| 248 | + raise ValueError(f"Invalid endpoint ({split!r}).") |
241 | 249 |
|
242 | 250 | return cls(endpoints) |
243 | 251 |
|
@@ -282,4 +290,12 @@ def __contains__(self, value: T) -> bool: |
282 | 290 | return False |
283 | 291 |
|
284 | 292 | def __str__(self) -> str: |
285 | | - return f"{{{", ".join(str(endpoint) for endpoint in self.endpoints)}}}" |
| 293 | + endpoints = [] |
| 294 | + for a, b in pairwise(self.endpoints): |
| 295 | + if a == b: |
| 296 | + endpoints.append(f"[{a.value}]") |
| 297 | + else: |
| 298 | + endpoints.append(str(a)) |
| 299 | + if a != b: |
| 300 | + endpoints.append(str(b)) |
| 301 | + return f"{{{", ".join(endpoints)}}}" |
0 commit comments