1- from typing import Sized , Union
1+ from typing import Union
22
33import numpy as np
44from PIL import Image
55
6+ from fastembed .common .types import NumpyArray
7+
68
79def convert_to_rgb (image : Image .Image ) -> Image .Image :
810 if image .mode == "RGB" :
@@ -13,9 +15,9 @@ def convert_to_rgb(image: Image.Image) -> Image.Image:
1315
1416
1517def center_crop (
16- image : Union [Image .Image , np . ndarray ],
18+ image : Union [Image .Image , NumpyArray ],
1719 size : tuple [int , int ],
18- ) -> np . ndarray :
20+ ) -> NumpyArray :
1921 if isinstance (image , np .ndarray ):
2022 _ , orig_height , orig_width = image .shape
2123 else :
@@ -40,7 +42,7 @@ def center_crop(
4042 new_height = max (crop_height , orig_height )
4143 new_width = max (crop_width , orig_width )
4244 new_shape = image .shape [:- 2 ] + (new_height , new_width )
43- new_image = np .zeros_like (image , shape = new_shape )
45+ new_image = np .zeros_like (image , shape = new_shape , dtype = np . float32 )
4446
4547 top_pad = (new_height - orig_height ) // 2
4648 bottom_pad = top_pad + orig_height
@@ -61,37 +63,34 @@ def center_crop(
6163
6264
def _per_channel_stats(values, num_channels: int, name: str) -> NumpyArray:
    """Return `values` as a float32 vector holding one entry per channel.

    A scalar is repeated `num_channels` times. A 1-D sequence (list, tuple or
    numpy array) must already contain exactly `num_channels` entries.

    Raises:
        ValueError: if `values` has more than one dimension or its length
            differs from `num_channels`.
    """
    stats = np.asarray(values, dtype=np.float32)
    if stats.ndim == 0:
        return np.full(num_channels, stats, dtype=np.float32)
    if stats.ndim != 1:
        raise ValueError(
            f"{name} must be a scalar or a 1-D sequence, got shape {stats.shape}"
        )
    if stats.shape[0] != num_channels:
        raise ValueError(
            f"{name} must have the same number of channels as the image, "
            f"image has {num_channels} channels, got {stats.shape[0]}"
        )
    return stats


def normalize(
    image: NumpyArray,
    mean: Union[float, list[float]],
    std: Union[float, list[float]],
) -> NumpyArray:
    """Normalize an image channel-wise as `(image - mean) / std`.

    The channel axis is axis 1 for 4-dimensional input and axis 0 otherwise.
    Non-floating input is converted to float32 before the arithmetic.

    Args:
        image: array to normalize.
        mean: a single value applied to every channel, or one value per channel.
        std: a single value applied to every channel, or one value per channel.

    Returns:
        The normalized array, with the same shape as `image`.

    Raises:
        ValueError: if `mean` or `std` is a sequence whose length differs from
            the number of channels.
    """
    channel_axis = 1 if image.ndim == 4 else 0
    num_channels = image.shape[channel_axis]

    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)

    mean_arr = _per_channel_stats(mean, num_channels, "mean")
    std_arr = _per_channel_stats(std, num_channels, "std")

    # Reshape the statistics so they broadcast along the channel axis only.
    # Transposing the image instead would line the statistics up with axis 0,
    # which is the batch axis (not the channel axis) for 4-dimensional input.
    stats_shape = [1] * image.ndim
    stats_shape[channel_axis] = num_channels

    return (image - mean_arr.reshape(stats_shape)) / std_arr.reshape(stats_shape)
9695
9796
@@ -114,11 +113,11 @@ def resize(
114113 return image .resize (new_size , resample )
115114
116115
def rescale(image: NumpyArray, scale: float, dtype: type = np.float32) -> NumpyArray:
    """Multiply `image` by `scale` and cast the result to `dtype`."""
    scaled = image * scale
    return scaled.astype(dtype)
119118
120119
def pil2ndarray(image: Union[Image.Image, NumpyArray]) -> NumpyArray:
    """Convert a PIL image to a numpy array with its last axis moved to the front.

    Anything that is not a PIL image is returned unchanged.
    """
    if not isinstance(image, Image.Image):
        return image

    pixels = np.asarray(image)
    # Axis order (0, 1, 2) -> (2, 0, 1); this requires a 3-dimensional array.
    return pixels.transpose((2, 0, 1))
0 commit comments