@@ -688,50 +688,50 @@ def write_png(self, fname):
688688 bytes = True , norm = True )
689689 PIL .Image .fromarray (im ).save (fname , format = "png" )
690690
691- def set_data (self , A ):
691+ @staticmethod
692+ def _normalize_image_array (A ):
692693 """
693- Set the image array.
694-
695- Note that this function does *not* update the normalization used.
696-
697- Parameters
698- ----------
699- A : array-like or `PIL.Image.Image`
694+ Check validity of image-like input *A* and normalize it to a format suitable for
695+ Image subclasses.
700696 """
701- if isinstance (A , PIL .Image .Image ):
702- A = pil_to_array (A ) # Needed e.g. to apply png palette.
703- self ._A = cbook .safe_masked_invalid (A , copy = True )
704-
705- if (self ._A .dtype != np .uint8 and
706- not np .can_cast (self ._A .dtype , float , "same_kind" )):
707- raise TypeError (f"Image data of dtype { self ._A .dtype } cannot be "
708- "converted to float" )
709-
710- if self ._A .ndim == 3 and self ._A .shape [- 1 ] == 1 :
711- # If just one dimension assume scalar and apply colormap
712- self ._A = self ._A [:, :, 0 ]
713-
714- if not (self ._A .ndim == 2
715- or self ._A .ndim == 3 and self ._A .shape [- 1 ] in [3 , 4 ]):
716- raise TypeError (f"Invalid shape { self ._A .shape } for image data" )
717-
718- if self ._A .ndim == 3 :
697+ A = cbook .safe_masked_invalid (A , copy = True )
698+ if A .dtype != np .uint8 and not np .can_cast (A .dtype , float , "same_kind" ):
699+ raise TypeError (f"Image data of dtype { A .dtype } cannot be "
700+ f"converted to float" )
701+ if A .ndim == 3 and A .shape [- 1 ] == 1 :
702+ A = A .squeeze (- 1 ) # If just (M, N, 1), assume scalar and apply colormap.
703+ if not (A .ndim == 2 or A .ndim == 3 and A .shape [- 1 ] in [3 , 4 ]):
704+ raise TypeError (f"Invalid shape { A .shape } for image data" )
705+ if A .ndim == 3 :
719706 # If the input data has values outside the valid range (after
720707 # normalisation), we issue a warning and then clip X to the bounds
721708 # - otherwise casting wraps extreme values, hiding outliers and
722709 # making reliable interpretation impossible.
723- high = 255 if np .issubdtype (self . _A .dtype , np .integer ) else 1
724- if self . _A . min () < 0 or high < self . _A .max ():
710+ high = 255 if np .issubdtype (A .dtype , np .integer ) else 1
711+ if A . min () < 0 or high < A .max ():
725712 _log .warning (
726713 'Clipping input data to the valid range for imshow with '
727714 'RGB data ([0..1] for floats or [0..255] for integers).'
728715 )
729- self . _A = np .clip (self . _A , 0 , high )
716+ A = np .clip (A , 0 , high )
730717 # Cast unsupported integer types to uint8
731- if self . _A . dtype != np .uint8 and np .issubdtype (self . _A . dtype ,
732- np .integer ):
733- self . _A = self . _A . astype ( np . uint8 )
718+ if A . dtype != np .uint8 and np .issubdtype (A . dtype , np . integer ):
719+ A = A . astype ( np .uint8 )
720+ return A
734721
722+ def set_data (self , A ):
723+ """
724+ Set the image array.
725+
726+ Note that this function does *not* update the normalization used.
727+
728+ Parameters
729+ ----------
730+ A : array-like or `PIL.Image.Image`
731+ """
732+ if isinstance (A , PIL .Image .Image ):
733+ A = pil_to_array (A ) # Needed e.g. to apply png palette.
734+ self ._A = self ._normalize_image_array (A )
735735 self ._imcache = None
736736 self .stale = True
737737
@@ -1149,23 +1149,15 @@ def set_data(self, x, y, A):
11491149 (M, N) `~numpy.ndarray` or masked array of values to be
11501150 colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array.
11511151 """
1152+ A = self ._normalize_image_array (A )
11521153 x = np .array (x , np .float32 )
11531154 y = np .array (y , np .float32 )
1154- A = cbook .safe_masked_invalid (A , copy = True )
1155- if not (x .ndim == y .ndim == 1 and A .shape [0 :2 ] == y .shape + x .shape ):
1155+ if not (x .ndim == y .ndim == 1 and A .shape [:2 ] == y .shape + x .shape ):
11561156 raise TypeError ("Axes don't match array shape" )
1157- if A .ndim not in [2 , 3 ]:
1158- raise TypeError ("Can only plot 2D or 3D data" )
1159- if A .ndim == 3 and A .shape [2 ] not in [1 , 3 , 4 ]:
1160- raise TypeError ("3D arrays must have three (RGB) "
1161- "or four (RGBA) color components" )
1162- if A .ndim == 3 and A .shape [2 ] == 1 :
1163- A = A .squeeze (axis = - 1 )
11641157 self ._A = A
11651158 self ._Ax = x
11661159 self ._Ay = y
11671160 self ._imcache = None
1168-
11691161 self .stale = True
11701162
11711163 def set_array (self , * args ):
@@ -1307,36 +1299,20 @@ def set_data(self, x, y, A):
13071299 - (M, N, 3): RGB array
13081300 - (M, N, 4): RGBA array
13091301 """
1310- A = cbook .safe_masked_invalid (A , copy = True )
1311- if x is None :
1312- x = np .arange (0 , A .shape [1 ]+ 1 , dtype = np .float64 )
1313- else :
1314- x = np .array (x , np .float64 ).ravel ()
1315- if y is None :
1316- y = np .arange (0 , A .shape [0 ]+ 1 , dtype = np .float64 )
1317- else :
1318- y = np .array (y , np .float64 ).ravel ()
1319-
1320- if A .shape [:2 ] != (y .size - 1 , x .size - 1 ):
1302+ A = self ._normalize_image_array (A )
1303+ x = np .arange (0. , A .shape [1 ] + 1 ) if x is None else np .array (x , float ).ravel ()
1304+ y = np .arange (0. , A .shape [0 ] + 1 ) if y is None else np .array (y , float ).ravel ()
1305+ if A .shape [:2 ] != (y .size - 1 , x .size - 1 ):
13211306 raise ValueError (
13221307 "Axes don't match array shape. Got %s, expected %s." %
13231308 (A .shape [:2 ], (y .size - 1 , x .size - 1 )))
1324- if A .ndim not in [2 , 3 ]:
1325- raise ValueError ("A must be 2D or 3D" )
1326- if A .ndim == 3 :
1327- if A .shape [2 ] == 1 :
1328- A = A .squeeze (axis = - 1 )
1329- elif A .shape [2 ] not in [3 , 4 ]:
1330- raise ValueError ("3D arrays must have RGB or RGBA as last dim" )
1331-
13321309 # For efficient cursor readout, ensure x and y are increasing.
13331310 if x [- 1 ] < x [0 ]:
13341311 x = x [::- 1 ]
13351312 A = A [:, ::- 1 ]
13361313 if y [- 1 ] < y [0 ]:
13371314 y = y [::- 1 ]
13381315 A = A [::- 1 ]
1339-
13401316 self ._A = A
13411317 self ._Ax = x
13421318 self ._Ay = y
0 commit comments