Skip to content

Commit 249326b

Browse files
Zhitao Yufacebook-github-bot
authored andcommitted
Docstring Fix for PILToTensor in Torchvision (pytorch#9254)
Summary: pytorch#9221 identifies a confusion around image shape conventions for ToTensor and PILToTensor classes. The docstring has the following statement: Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). This is confusing since PIL Image shape is not (H x W x C) but rather PIL Images expose their size as (W, H) via the size attribute, not as a shape tuple. Proposed Docstring Update Convert a PIL Image with H height, W width, and C channels to a Tensor of shape (C x H x W). Reviewed By: AntoineSimoulin Differential Revision: D85779518
1 parent 218d2ab commit 249326b

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

torchvision/transforms/transforms.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,15 @@ class PILToTensor:
145145
146146
This transform does not support torchscript.
147147
148-
Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
148+
Convert a PIL Image with H height, W width, and C channels to a Tensor of shape (C x H x W).
149+
150+
Example:
151+
>>> from PIL import Image
152+
>>> import torchvision.transforms as T
153+
>>> img = Image.new("RGB", (320, 240)) # size (W=320, H=240)
154+
>>> tensor = T.PILToTensor()(img)
155+
>>> print(tensor.shape)
156+
torch.Size([3, 240, 320])
149157
"""
150158

151159
def __init__(self) -> None:

torchvision/transforms/v2/_type_conversion.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,15 @@ class PILToTensor(Transform):
1515
1616
This transform does not support torchscript.
1717
18-
Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W).
18+
Convert a PIL Image with H height, W width, and C channels to a Tensor of shape (C x H x W).
19+
20+
Example:
21+
>>> from PIL import Image
22+
>>> from torchvision.transforms import v2
23+
>>> img = Image.new("RGB", (320, 240)) # size (W=320, H=240)
24+
>>> tensor = v2.PILToTensor()(img)
25+
>>> print(tensor.shape)
26+
torch.Size([3, 240, 320])
1927
"""
2028

2129
_transformed_types = (PIL.Image.Image,)

0 commit comments

Comments
 (0)