11import warnings
2- from typing import Callable , Union
2+ from typing import Callable , Optional , Union
33
44import numpy as np
5- from numpy .typing import ArrayLike
65from scipy .special import logsumexp as LSE
76from sklearn .base import BaseEstimator
87from sklearn .utils .validation import check_is_fitted , check_random_state
@@ -33,19 +32,18 @@ class SparseKDE(BaseEstimator):
3332 weights: numpy.ndarray, default=None
3433 Weights of the descriptors.
3534 If None, all weights are set to `1/n_descriptors`.
36- metric : Callable[[ArrayLike, ArrayLike, bool, dict], ArrayLike],
37- default=:func:`skmatter.metrics.pairwise_euclidean_distances()`
35+ metric : Callable, default=None
3836 The metric to use. Your metric should be able to take at least three arguments
3937 in secquence: `X`, `Y`, and `squared=True`. Here, `X` and `Y` are two array-like
4038 of shape (n_samples, n_components). The return of the metric is an array-like of
41- shape (n_samples, n_samples). If you want to use periodic boundary
42- conditions, be sure to provide the cell size in the metric_params and
43- provide a metric that can take the cell argument.
39+ shape (n_samples, n_samples). If you want to use periodic boundary conditions,
40+ be sure to provide the cell size in the metric_params and provide a metric that
41+ can take the cell argument. If :obj:`None`, the
42+ :func:`skmatter.metrics.periodic_pairwise_euclidean_distances()` is used.
4443 metric_params : dict, default=None
45- Additional parameters to be passed to the use of
46- metric. i.e. the cell dimension for
47- :func:`skmatter.metrics.pairwise_euclidean_distances()`
48- `{'cell_length': [side_length_1, ..., side_length_n]}`
44+ Additional parameters to be passed to the use of metric. i.e. the cell
45+ dimension for :func:`skmatter.metrics.periodic_pairwise_euclidean_distances()`
46+ ``{'cell_length': [side_length_1, ..., side_length_n]}``
4947 fspread : float, default=-1.0
5048 The fractional "space" occupied by the voronoi cell of each grid. Use this when
5149 each cell is of a similar size.
@@ -106,11 +104,9 @@ class SparseKDE(BaseEstimator):
106104 def __init__ (
107105 self ,
108106 descriptors : np .ndarray ,
109- weights : Union [np .ndarray , None ] = None ,
110- metric : Callable [
111- [ArrayLike , ArrayLike , bool , dict ], ArrayLike
112- ] = periodic_pairwise_euclidean_distances ,
113- metric_params : Union [dict , None ] = None ,
107+ weights : Optional [np .ndarray ] = None ,
108+ metric : Optional [Callable ] = None ,
109+ metric_params : Optional [dict ] = None ,
114110 fspread : float = - 1.0 ,
115111 fpoints : float = 0.15 ,
116112 kernel : str = "gaussian" ,
@@ -119,6 +115,10 @@ def __init__(
119115 self .metric_params = (
120116 metric_params if metric_params is not None else {"cell_length" : None }
121117 )
118+
119+ if metric is None :
120+ metric = periodic_pairwise_euclidean_distances
121+
122122 self .metric = lambda X , Y : metric (X , Y , squared = True , ** self .metric_params )
123123 self .cell = metric_params ["cell_length" ] if metric_params is not None else None
124124 self ._check_dimension (descriptors )
0 commit comments