
Commit 7940146

Fix: pin_memory() preserves TVTensor class and metadata (#8921)
1 parent e239710 · commit 7940146

2 files changed: +12 −1 lines changed

test/test_tv_tensors.py

Lines changed: 5 additions & 0 deletions
@@ -162,6 +162,11 @@ def test_force_subclass_with_metadata(return_type):
     if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
 
+    if torch.cuda.is_available():
+        bbox = bbox.pin_memory()
+        if return_type == "TVTensor":
+            assert bbox.format, bbox.canvas_size == (format, canvas_size)
+
     assert not bbox.requires_grad
     bbox.requires_grad_(True)
     if return_type == "TVTensor":

torchvision/tv_tensors/_torch_function_helpers.py

Lines changed: 7 additions & 1 deletion
@@ -69,4 +69,10 @@ def _must_return_subclass():
 
 
 # For those ops we always want to preserve the original subclass instead of returning a pure Tensor
-_FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
+_FORCE_TORCHFUNCTION_SUBCLASS = {
+    torch.Tensor.clone,
+    torch.Tensor.to,
+    torch.Tensor.detach,
+    torch.Tensor.requires_grad_,
+    torch.Tensor.pin_memory,
+}
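
The practical effect of the second change: with torch.Tensor.pin_memory added to _FORCE_TORCHFUNCTION_SUBCLASS, pinning a TVTensor now returns the same subclass with its metadata intact instead of a plain torch.Tensor. Below is a minimal illustrative sketch (not part of the commit) using the public tv_tensors.BoundingBoxes constructor; it assumes a CUDA-capable machine, since pin_memory() needs one, and the example values (XYXY boxes, a 32x32 canvas) are arbitrary.

import torch
from torchvision import tv_tensors

# A BoundingBoxes TVTensor carrying `format` and `canvas_size` metadata.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    format="XYXY",
    canvas_size=(32, 32),
)

if torch.cuda.is_available():
    pinned = boxes.pin_memory()
    # Before this fix, `pinned` came back as a plain torch.Tensor;
    # now the subclass and its metadata survive the call.
    print(type(pinned).__name__)   # BoundingBoxes
    print(pinned.format)           # BoundingBoxFormat.XYXY
    print(pinned.canvas_size)      # (32, 32)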
