Skip to content

Commit 4417e18

Browse files
vfdev-5fmassa
authored andcommitted
Fixed device when used in _gen_affine_grid and _perspective_grid (#2813)
1 parent 342a3d8 commit 4417e18

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

torchvision/transforms/functional_tensor.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -929,8 +929,10 @@ def _gen_affine_grid(
929929

930930
d = 0.5
931931
base_grid = torch.empty(1, oh, ow, 3, dtype=theta.dtype, device=theta.device)
932-
base_grid[..., 0].copy_(torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow))
933-
base_grid[..., 1].copy_(torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh).unsqueeze_(-1))
932+
x_grid = torch.linspace(-ow * 0.5 + d, ow * 0.5 + d - 1, steps=ow, device=theta.device)
933+
base_grid[..., 0].copy_(x_grid)
934+
y_grid = torch.linspace(-oh * 0.5 + d, oh * 0.5 + d - 1, steps=oh, device=theta.device).unsqueeze_(-1)
935+
base_grid[..., 1].copy_(y_grid)
934936
base_grid[..., 2].fill_(1)
935937

936938
rescaled_theta = theta.transpose(1, 2) / torch.tensor([0.5 * w, 0.5 * h], dtype=theta.dtype, device=theta.device)
@@ -1065,8 +1067,10 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
10651067

10661068
d = 0.5
10671069
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
1068-
base_grid[..., 0].copy_(torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow))
1069-
base_grid[..., 1].copy_(torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh).unsqueeze_(-1))
1070+
x_grid = torch.linspace(d, ow * 1.0 + d - 1.0, steps=ow, device=device)
1071+
base_grid[..., 0].copy_(x_grid)
1072+
y_grid = torch.linspace(d, oh * 1.0 + d - 1.0, steps=oh, device=device).unsqueeze_(-1)
1073+
base_grid[..., 1].copy_(y_grid)
10701074
base_grid[..., 2].fill_(1)
10711075

10721076
rescaled_theta1 = theta1.transpose(1, 2) / torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device)

0 commit comments

Comments
 (0)