Skip to content

Commit f137857

Browse files
committed
Circumvent deprecation warning
1 parent 9baf4c8 commit f137857

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

torchkde/kernels.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def bandwidth(self, bandwidth):
171171
if type(bandwidth) == torch.Tensor:
172172
assert bandwidth.requires_grad == False, \
173173
"The bandwidth for the von Mises-Fisher kernel must not require gradients."
174+
bandwidth = bandwidth.item()
174175
assert type(bandwidth) == float or isinstance(bandwidth, torch.Tensor) and bandwidth.dim() == 0, \
175176
"The bandwidth for the von Mises-Fisher kernel must be a scalar."
176177
self._bandwidth = bandwidth

0 commit comments

Comments
 (0)