@@ -27,17 +27,26 @@ def fit_background(spot_colors: np.ndarray, weight_shift: float = 0) -> Tuple[np
2727 - background_vectors `float [n_channels x n_rounds x n_channels]`.
2828 background_vectors[c] is the background vector for channel c.
2929 """
30- weight_shift = np .clip (weight_shift , 1e-20 , np .inf ) # ensure weight_shift > 1e-20 to avoid blow up to infinity.
30+ # Preserve the spot colours datatype throughout.
31+ dtype = spot_colors .dtype
32+ weight_shift = np .clip (
33+ weight_shift , 1e-20 , np .inf , dtype = dtype
34+ ) # ensure weight_shift > 1e-20 to avoid blow up to infinity.
3135
3236 n_rounds , n_channels = spot_colors [0 ].shape
33- background_vectors = np .repeat (np .expand_dims (np .eye (n_channels ), axis = 1 ), n_rounds , axis = 1 )
37+ background_vectors = np .repeat (np .expand_dims (np .eye (n_channels ), axis = 1 ), n_rounds , axis = 1 ). astype ( dtype )
3438 # give background_vectors an L2 norm of 1 so can compare coefficients with other genes.
3539 background_vectors = background_vectors / np .linalg .norm (background_vectors , axis = (1 , 2 ), keepdims = True )
3640
3741 weight_factor = 1 / (np .abs (spot_colors ) + weight_shift )
3842 spot_weight = spot_colors * weight_factor
39- background_weight = np .ones ((1 , n_rounds , n_channels )) * background_vectors [0 , 0 , 0 ] * weight_factor
43+ background_weight = np .ones ((1 , n_rounds , n_channels ), dtype = dtype ) * background_vectors [0 , 0 , 0 ] * weight_factor
44+ # Avoid overflow from squaring the background_weight.
45+ background_weight = np .clip (background_weight , None , np .sqrt (np .finfo (dtype ).max ))
4046 coef = np .sum (spot_weight * background_weight , axis = 1 ) / np .sum (background_weight ** 2 , axis = 1 )
41- residual = spot_colors - np .expand_dims (coef , 1 ) * np .ones ((1 , n_rounds , n_channels )) * background_vectors [0 , 0 , 0 ]
47+ residual = (
48+ spot_colors
49+ - np .expand_dims (coef , 1 ) * np .ones ((1 , n_rounds , n_channels ), dtype = dtype ) * background_vectors [0 , 0 , 0 ]
50+ )
4251
4352 return residual , coef , background_vectors
0 commit comments