|
1 | | -from typing import Sized, Union, Optional |
| 1 | +from typing import Sized, Union |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | from PIL import Image |
@@ -127,18 +127,24 @@ def pil2ndarray(image: Union[Image.Image, np.ndarray]): |
127 | 127 | def pad2square( |
128 | 128 | image: Image.Image, |
129 | 129 | size: int, |
130 | | - fill_color: Optional[Union[str, int, tuple[int, ...]]] = None, |
| 130 | + fill_color: Union[str, int, tuple[int, ...]] = 0, |
131 | 131 | ) -> Image.Image: |
132 | 132 | height, width = image.height, image.width |
133 | 133 |
|
134 | | - # if the size is larger than the new canvas |
135 | | - if width > size or height > size: |
| 134 | + left, right = 0, width |
| 135 | + top, bottom = 0, height |
| 136 | + |
| 137 | + crop_required = False |
| 138 | + if width > size: |
136 | 139 | left = (width - size) // 2 |
137 | | - top = (height - size) // 2 |
138 | 140 | right = left + size |
| 141 | + crop_required = True |
| 142 | + |
| 143 | + if height > size: |
| 144 | + top = (height - size) // 2 |
139 | 145 | bottom = top + size |
140 | | - image = image.crop((left, top, right, bottom)) |
| 146 | + crop_required = True |
141 | 147 |
|
142 | | - new_image = Image.new(mode="RGB", size=(size, size), color=fill_color or 0) |
143 | | - new_image.paste(image) |
| 148 | + new_image = Image.new(mode="RGB", size=(size, size), color=fill_color) |
| 149 | + new_image.paste(image.crop((left, top, right, bottom)) if crop_required else image) |
144 | 150 | return new_image |
0 commit comments