Skip to content

Commit ee6c3d0

Browse files
Support v2.functional.gaussian_blur backprop (#8486)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 15a69ca commit ee6c3d0

File tree

1 file changed

+3
-3
lines changed
  • torchvision/transforms/v2/functional

1 file changed

+3
-3
lines changed

torchvision/transforms/v2/functional/_misc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[Li
8484

8585

8686
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
87-
lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
87+
lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0))
8888
x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
89-
kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0)
89+
kernel1d = torch.softmax(x.div(sigma).pow(2).neg(), dim=0)
9090
return kernel1d
9191

9292

@@ -119,7 +119,7 @@ def gaussian_blur_image(
119119
if isinstance(sigma, (list, tuple)):
120120
length = len(sigma)
121121
if length == 1:
122-
s = float(sigma[0])
122+
s = sigma[0]
123123
sigma = [s, s]
124124
elif length != 2:
125125
raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")

0 commit comments

Comments
 (0)