Skip to content

Commit a40d6b9

Browse files
committed
Fix: pin_memory() preserves TVTensor class and metadata
1 parent fab1188 commit a40d6b9

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

test/test_tv_tensors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ def test_force_subclass_with_metadata(return_type):
162162
if return_type == "TVTensor":
163163
assert bbox.format, bbox.canvas_size == (format, canvas_size)
164164

165+
bbox = bbox.pin_memory()
166+
if return_type == "TVTensor":
167+
assert bbox.format, bbox.canvas_size == (format, canvas_size)
168+
165169
assert not bbox.requires_grad
166170
bbox.requires_grad_(True)
167171
if return_type == "TVTensor":

torchvision/tv_tensors/_torch_function_helpers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,10 @@ def _must_return_subclass():
6969

7070

7171
# For those ops we always want to preserve the original subclass instead of returning a pure Tensor
72-
_FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
72+
_FORCE_TORCHFUNCTION_SUBCLASS = {
73+
torch.Tensor.clone,
74+
torch.Tensor.to,
75+
torch.Tensor.detach,
76+
torch.Tensor.requires_grad_,
77+
torch.Tensor.pin_memory,
78+
}

0 commit comments

Comments
 (0)