|
| 1 | +""" |
| 2 | +Split in multiple dimensions. |
| 3 | +""" |
| 4 | + |
| 5 | +from dataclasses import dataclass |
| 6 | +from functools import cached_property |
| 7 | +from itertools import combinations_with_replacement, product, starmap |
| 8 | +from math import floor, prod, sqrt |
| 9 | +from typing import Iterator |
| 10 | + |
| 11 | + |
| 12 | +@dataclass |
| 13 | +class Bounds: |
| 14 | + """Spatial bounds defined in multiple dimensions.""" |
| 15 | + |
| 16 | + start_point: tuple[int, ...] |
| 17 | + end_point: tuple[int, ...] |
| 18 | + |
| 19 | + def __post_init__(self) -> None: |
| 20 | + """Validate the bounds.""" |
| 21 | + assert len(self.start_point) == len(self.end_point) |
| 22 | + assert self.start_point < self.end_point |
| 23 | + |
| 24 | + def __len__(self): |
| 25 | + """Get the number of points in the bounds (length, area, or volume). |
| 26 | +
|
| 27 | + Example: |
| 28 | + >>> bounds = Bounds((50, 25), (150, 100)) |
| 29 | + >>> len(bounds) |
| 30 | + 7500 |
| 31 | + """ |
| 32 | + return prod( |
| 33 | + map(lambda x: x[1] - x[0], zip(self.start_point, self.end_point)) |
| 34 | + ) |
| 35 | + |
| 36 | + def __iter__(self) -> Iterator[tuple[int, ...]]: |
| 37 | + """Iterate over the bounds, yielding each point like in an odometer. |
| 38 | +
|
| 39 | + Example: |
| 40 | + >>> bounds = Bounds((50, 25), (150, 100)) |
| 41 | + >>> for x, y in bounds: |
| 42 | + ... print(x, y) # doctest: +ELLIPSIS |
| 43 | + 50 25 |
| 44 | + 50 26 |
| 45 | + 50 27 |
| 46 | + ... |
| 47 | + 149 97 |
| 48 | + 149 98 |
| 49 | + 149 99 |
| 50 | + """ |
| 51 | + return product(*starmap(range, zip(self.start_point, self.end_point))) |
| 52 | + |
| 53 | + def slices(self) -> tuple[slice, ...]: |
| 54 | + """Return the slice for each dimension. |
| 55 | +
|
| 56 | + Example: |
| 57 | + >>> bounds = Bounds((50, 25), (150, 100)) |
| 58 | + >>> bounds.slices() |
| 59 | + (slice(50, 150, None), slice(25, 100, None)) |
| 60 | + """ |
| 61 | + return tuple( |
| 62 | + slice(start, end) |
| 63 | + for start, end in zip(self.start_point, self.end_point) |
| 64 | + ) |
| 65 | + |
| 66 | + @cached_property |
| 67 | + def size(self) -> tuple[int, ...]: |
| 68 | + """Return the size of the bounds in each dimension. |
| 69 | +
|
| 70 | + Example: |
| 71 | + >>> bounds = Bounds((50, 25), (150, 100)) |
| 72 | + >>> width, height = bounds.size |
| 73 | + >>> width |
| 74 | + 100 |
| 75 | + >>> height |
| 76 | + 75 |
| 77 | + """ |
| 78 | + return tuple( |
| 79 | + self.end_point[i] - self.start_point[i] |
| 80 | + for i in range(self.num_dimensions) |
| 81 | + ) |
| 82 | + |
| 83 | + @cached_property |
| 84 | + def num_dimensions(self) -> int: |
| 85 | + """Return the number of dimensions. |
| 86 | +
|
| 87 | + Example: |
| 88 | + >>> bounds = Bounds((50, 25), (150, 100)) |
| 89 | + >>> bounds.num_dimensions |
| 90 | + 2 |
| 91 | + """ |
| 92 | + return len(self.start_point) |
| 93 | + |
| 94 | + def offset(self, *coordinates): |
| 95 | + """Return the offset of the given coordinates from the start point. |
| 96 | +
|
| 97 | + Example: |
| 98 | + >>> bounds = Bounds((50, 25), (150, 100)) |
| 99 | + >>> for x, y in bounds: |
| 100 | + ... print(bounds.offset(x, y)) # doctest: +ELLIPSIS |
| 101 | + (0, 0) |
| 102 | + (0, 1) |
| 103 | + (0, 2) |
| 104 | + ... |
| 105 | + (99, 72) |
| 106 | + (99, 73) |
| 107 | + (99, 74) |
| 108 | + """ |
| 109 | + return tuple( |
| 110 | + coordinates[i] - self.start_point[i] |
| 111 | + for i in range(self.num_dimensions) |
| 112 | + ) |
| 113 | + |
| 114 | + |
| 115 | +def split_multi(num_chunks: int, *dimensions: int) -> Iterator[Bounds]: |
| 116 | + """Return a sequence of n-dimensional slices.""" |
| 117 | + num_chunks_along_axis = find_most_even(num_chunks, len(dimensions)) |
| 118 | + for slices_by_dimension in product( |
| 119 | + *starmap(get_slices, zip(dimensions, num_chunks_along_axis)) |
| 120 | + ): |
| 121 | + yield Bounds( |
| 122 | + start_point=tuple(s.start for s in slices_by_dimension), |
| 123 | + end_point=tuple(s.stop for s in slices_by_dimension), |
| 124 | + ) |
| 125 | + |
| 126 | + |
| 127 | +def get_slices(length: int, num_chunks: int) -> Iterator[slice]: |
| 128 | + """Return a sequence of slices for the given length.""" |
| 129 | + chunk_size, remaining = divmod(length, num_chunks) |
| 130 | + for i in range(num_chunks): |
| 131 | + begin = i * chunk_size + min(i, remaining) |
| 132 | + end = (i + 1) * chunk_size + min(i + 1, remaining) |
| 133 | + yield slice(begin, end) |
| 134 | + |
| 135 | + |
| 136 | +def find_most_even(number: int, num_factors: int): |
| 137 | + """Return the most even tuple of integer divisors of a number.""" |
| 138 | + products_by_sum = { |
| 139 | + sum(products): products |
| 140 | + for products in find_products(number, num_factors) |
| 141 | + } |
| 142 | + return products_by_sum[min(products_by_sum)] |
| 143 | + |
| 144 | + |
| 145 | +def find_products(number: int, num_factors: int) -> Iterator[tuple[int, ...]]: |
| 146 | + """Return all possible products of a number.""" |
| 147 | + divisors = find_divisors(number) |
| 148 | + for factors in combinations_with_replacement(divisors, num_factors): |
| 149 | + if prod(factors) == number: |
| 150 | + yield factors |
| 151 | + |
| 152 | + |
| 153 | +def find_divisors(number: int) -> set[int]: |
| 154 | + """Return unique integer divisors of a number.""" |
| 155 | + divisors = {1, number} |
| 156 | + for divisor in range(2, floor(sqrt(number)) + 1): |
| 157 | + factor, remainder = divmod(number, divisor) |
| 158 | + if remainder == 0: |
| 159 | + divisors.add(divisor) |
| 160 | + divisors.add(factor) |
| 161 | + return divisors |
| 162 | + |
| 163 | + |
| 164 | +if __name__ == "__main__": |
| 165 | + import doctest |
| 166 | + |
| 167 | + doctest.testmod() |
0 commit comments