@@ -28,6 +28,7 @@ def bandwidth(self):
2828
2929 @bandwidth .setter
3030 def bandwidth (self , bandwidth ):
31+ self ._norm_constant = None # reset normalization constant when bandwidth changes
3132 self ._bandwidth = bandwidth
3233 # compute H^(-1/2)
3334 if check_if_mat (bandwidth ):
@@ -41,32 +42,38 @@ def norm_constant(self):
4142 assert self .dim is not None , "Dimension not set."
4243 self ._norm_constant = self ._compute_norm_constant (self .dim )
4344 return self ._norm_constant
45+
46+ def _bw_norm (self , dim ):
47+ if check_if_mat (self ._bandwidth ):
48+ return torch .sqrt (torch .det (self ._bandwidth ))
49+ else :
50+ return self ._bandwidth ** (dim / 2 )
4451
4552 @abstractmethod
4653 def _compute_norm_constant (self , dim ):
4754 pass
4855
4956 @abstractmethod
5057 def __call__ (self , x1 , x2 ):
58+ """This __call__ must be called in child classes to clear the norm constant cache and check the inputs."""
5159 assert self .bandwidth is not None , "Bandwidth not set."
60+ assert x1 .shape [- 1 ] == x2 .shape [- 1 ], "Input data must have the same dimensionality."
61+ new_dim = x1 .shape [- 1 ]
62+ if new_dim != self .dim : # first call or change of dimensionality
63+ self .dim = new_dim
64+ self ._norm_constant = None
5265
5366
5467class GaussianKernel (Kernel ):
5568 def __call__ (self , x1 , x2 ):
5669 super ().__call__ (x1 , x2 )
5770 differences = x1 - x2
58- self .dim = differences .shape [- 1 ]
59- u = kernel_input (self .inv_bandwidth , differences )
71+ u = compute_u (self .inv_bandwidth , differences )
6072
6173 return torch .exp (- u / 2 )
6274
6375 def _compute_norm_constant (self , dim ):
64- if check_if_mat (self ._bandwidth ):
65- # When bandwidth is a matrix, include sqrt(det(bandwidth))
66- bw_norm = torch .sqrt (torch .det (self ._bandwidth ))
67- else :
68- # When bandwidth is a scalar, raise it to the dim/2
69- bw_norm = self ._bandwidth ** (dim / 2 )
76+ bw_norm = self ._bw_norm (dim )
7077 return 1 / ((2 * math .pi )** (dim / 2 ) * bw_norm )
7178
7279
@@ -75,82 +82,72 @@ class TopHatKernel(Kernel):
7582 via a generalized Gaussian."""
7683 def __init__ (self , beta = 8 ):
7784 super ().__init__ ()
78- assert type (beta ) == int , "beta must be an integer."
85+ assert isinstance (beta , int ) , "beta must be an integer."
7986 self .beta = beta
8087
8188 def __call__ (self , x1 , x2 ):
8289 super ().__call__ (x1 , x2 )
8390 differences = x1 - x2
84- self .dim = differences .shape [- 1 ]
85- u = kernel_input (self .inv_bandwidth , differences )
91+ u = compute_u (self .inv_bandwidth , differences )
8692
8793 return torch .exp (- (u ** self .beta )/ 2 )
8894
8995 def _compute_norm_constant (self , dim ):
90- if check_if_mat (self ._bandwidth ):
91- # When bandwidth is a matrix, include sqrt(det(bandwidth))
92- bw_norm = torch .sqrt (torch .det (self ._bandwidth ))
93- else :
94- # When bandwidth is a scalar, raise it to the d/2
95- bw_norm = self ._bandwidth ** (dim / 2 )
96+ bw_norm = self ._bw_norm (dim )
9697 return (self .beta * gamma (dim / 2 ))/ (math .pi ** (dim / 2 ) * \
9798 gamma (dim / (2 * self .beta )) * 2 ** (dim / (2 * self .beta )) * bw_norm )
9899
99100
100101class EpanechnikovKernel (Kernel ):
101102 def __init__ (self ):
102103 super ().__init__ ()
103- self ._intrinsic_norm_constant = None
104+ self ._unit_ball_constant = None
104105
105106 def __call__ (self , x1 , x2 ):
107+ old_dim = self .dim
106108 super ().__call__ (x1 , x2 )
107109 differences = x1 - x2
108- self .dim = differences .shape [- 1 ]
109- c = self .intrinsic_norm_constant
110- u = kernel_input (self .inv_bandwidth , differences )
110+ if old_dim is not None and old_dim != differences .shape [- 1 ]:
111+ self ._unit_ball_constant = None
112+ c = self .unit_ball_constant
113+ u = compute_u (self .inv_bandwidth , differences )
111114
112115 return torch .where (u > 1 , 0 , c * (1 - u ))
113116
114- def _compute_intrinsic_norm_constant (self , dim ):
117+ @Kernel .bandwidth .setter
118+ def bandwidth (self , bandwidth ):
119+ Kernel .bandwidth .fset (self , bandwidth )
120+ self ._unit_ball_constant = None # reset the cached constant when bandwidth changes
121+
122+ def _compute_unit_ball_constant (self , dim ):
115123 return ((dim + 2 )* gamma (dim / 2 + 1 ))/ (2 * math .pi ** (dim / 2 ))
116124
117125 @property
118- def intrinsic_norm_constant (self ):
126+ def unit_ball_constant (self ):
119127 """Return the cached intrinsic normalization constant, computing it if necessary."""
120- if self ._intrinsic_norm_constant is None :
121- self ._intrinsic_norm_constant = self ._compute_intrinsic_norm_constant (self .dim )
122- return self ._intrinsic_norm_constant
128+ if self ._unit_ball_constant is None :
129+ self ._unit_ball_constant = self ._compute_unit_ball_constant (self .dim )
130+ return self ._unit_ball_constant
123131
124132 def _compute_norm_constant (self , dim ):
125- if check_if_mat (self ._bandwidth ):
126- # When bandwidth is a matrix, include sqrt(det(bandwidth))
127- bw_norm = torch .sqrt (torch .det (self ._bandwidth ))
128- else :
129- # When bandwidth is a scalar, raise it to the dim/2
130- bw_norm = self ._bandwidth ** (dim / 2 )
133+ bw_norm = self ._bw_norm (dim )
131134 return 1 / bw_norm
132135
133136
134137class ExponentialKernel (Kernel ):
135138 def __call__ (self , x1 , x2 ):
136139 super ().__call__ (x1 , x2 )
137140 differences = x1 - x2
138- self .dim = differences .shape [- 1 ]
139- u = kernel_input (self .inv_bandwidth , differences , exp = 1 )
141+ u = compute_u (self .inv_bandwidth , differences , exp = 1 )
140142
141143 return torch .exp (- u )
142144
143145 def _compute_norm_constant (self , dim ):
144- if check_if_mat (self ._bandwidth ):
145- # When bandwidth is a matrix, include sqrt(det(bandwidth))
146- bw_norm = torch .sqrt (torch .det (self ._bandwidth ))
147- else :
148- # When bandwidth is a scalar, raise it to the dim/2
149- bw_norm = self ._bandwidth ** (dim / 2 )
146+ bw_norm = self ._bw_norm (dim )
150147 return 1 / (2 ** dim * bw_norm )
151148
152149
153- def kernel_input (inv_bandwidth , x , exp = 2 ):
150+ def compute_u (inv_bandwidth , x , exp = 2 ):
154151 """Compute the input to the kernel function."""
155152 if exp >= 2 :
156153 if check_if_mat (inv_bandwidth ):
@@ -167,12 +164,13 @@ def kernel_input(inv_bandwidth, x, exp=2):
167164class VonMisesFisherKernel (Kernel ):
168165 @Kernel .bandwidth .setter
169166 def bandwidth (self , bandwidth ):
167+ Kernel .bandwidth .fset (self , bandwidth )
170168 # For vMF, the bandwidth is directly the concentration parameter.
171- if type (bandwidth ) == torch .Tensor :
172- assert bandwidth .requires_grad == False , \
169+ if isinstance (bandwidth , torch .Tensor ) :
170+ assert not bandwidth .requires_grad , \
173171 "The bandwidth for the von Mises-Fisher kernel must not require gradients."
174- bandwidth = bandwidth .item ()
175- assert type (bandwidth ) == float or isinstance (bandwidth , torch .Tensor ) and bandwidth .dim () == 0 , \
172+ bandwidth = bandwidth .item () # input to iv function cannot handle tensors
173+ assert isinstance (bandwidth , float ) or isinstance (bandwidth , torch .Tensor ) and bandwidth .dim () == 0 , \
176174 "The bandwidth for the von Mises-Fisher kernel must be a scalar."
177175 self ._bandwidth = bandwidth
178176 self .inv_bandwidth = bandwidth
@@ -183,7 +181,6 @@ def __call__(self, x1, x2):
183181 assert torch .allclose (
184182 x_all .norm (dim = - 1 ), torch .ones_like (x_all [..., 0 ]), atol = 1e-5
185183 ), "The von Mises-Fisher kernel assumes all data to lie on the unit sphere. Please normalize data."
186- self .dim = x1 .shape [- 1 ]
187184
188185 return torch .exp (self ._bandwidth * (x1 * x2 ).sum (dim = - 1 ))
189186
0 commit comments