Commit 7be960f

alykhantejani authored and fmassa committed

Add docs to FiveCrop + TenCrop (#294)

* add docs to FiveCrop + TenCrop
* fix typo in docstring
1 parent 174dbbd commit 7be960f

File tree

1 file changed: +83 -35 lines changed

torchvision/transforms.py

Lines changed: 83 additions & 35 deletions
@@ -312,16 +312,28 @@ def five_crop(img, size):
     """Crop the given PIL Image into four corners and the central crop.
 
     .. Note::
-         This transform returns a tuple of images and there may be a
-         mismatch in the number of inputs and targets your ``Dataset`` returns.
+         This transform returns a tuple of images and there may be a mismatch in the number of
+         inputs and targets your ``Dataset`` returns. See below for an example of how to deal
+         with this.
 
     Args:
-        size (sequence or int): Desired output size of the crop. If size is an
-            int instead of sequence like (h, w), a square crop (size, size) is
-            made.
+        img (PIL Image): Image to be cropped.
+        size (sequence or int): Desired output size of the crop. If size is an ``int``
+            instead of a sequence like (h, w), a square crop of size (size, size) is made.
+
     Returns:
         tuple: tuple (tl, tr, bl, br, center) corresponding top left,
             top right, bottom left, bottom right and center crop.
+
+    Example:
+        >>> def transform(img):
+        >>>     crops = five_crop(img, size)  # this is a tuple of PIL Images
+        >>>     return torch.stack([to_tensor(crop) for crop in crops])  # returns a 4D tensor
+        >>> # In your test loop you can do the following:
+        >>> input, target = batch  # input is a 5D tensor, target is 2D
+        >>> bs, ncrops, c, h, w = input.size()
+        >>> result = model(input.view(-1, c, h, w))  # fuse batch size and ncrops
+        >>> result_avg = result.view(bs, ncrops, -1).mean(1)  # avg over crops
     """
     if isinstance(size, numbers.Number):
         size = (int(size), int(size))
@@ -342,24 +354,35 @@ def five_crop(img, size):
 
 
 def ten_crop(img, size, vertical_flip=False):
-    """Crop the given PIL Image into four corners and the central crop plus the
-    flipped version of these (horizontal flipping is used by default).
+    """Crop the given PIL Image into four corners and the central crop plus the flipped
+    version of these (horizontal flipping is used by default).
 
     .. Note::
-         This transform returns a tuple of images and there may be a
-         mismatch in the number of inputs and targets your ``Dataset`` returns.
+         This transform returns a tuple of images and there may be a mismatch in the number of
+         inputs and targets your ``Dataset`` returns. See below for an example of how to deal
+         with this.
 
-    Args:
-        size (sequence or int): Desired output size of the crop. If size is an
-            int instead of sequence like (h, w), a square crop (size, size) is
-            made.
-        vertical_flip (bool): Use vertical flipping instead of horizontal
+    Args:
+        img (PIL Image): Image to be cropped.
+        size (sequence or int): Desired output size of the crop. If size is an ``int``
+            instead of a sequence like (h, w), a square crop of size (size, size) is made.
+        vertical_flip (bool): Use vertical flipping instead of horizontal.
 
-    Returns:
-        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
-            br_flip, center_flip) corresponding top left, top right,
-            bottom left, bottom right and center crop and same for the
-            flipped image.
+    Returns:
+        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip,
+            br_flip, center_flip) corresponding top left, top right,
+            bottom left, bottom right and center crop and same for the
+            flipped image.
+
+    Example:
+        >>> def transform(img):
+        >>>     crops = ten_crop(img, size)  # this is a tuple of PIL Images
+        >>>     return torch.stack([to_tensor(crop) for crop in crops])  # returns a 4D tensor
+        >>> # In your test loop you can do the following:
+        >>> input, target = batch  # input is a 5D tensor, target is 2D
+        >>> bs, ncrops, c, h, w = input.size()
+        >>> result = model(input.view(-1, c, h, w))  # fuse batch size and ncrops
+        >>> result_avg = result.view(bs, ncrops, -1).mean(1)  # avg over crops
     """
     if isinstance(size, numbers.Number):
         size = (int(size), int(size))
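
The functional example in the new docstrings above is abbreviated. Below is a minimal, self-contained sketch of the same pattern. It assumes the module-level five_crop and to_tensor helpers defined in this file at the time of this commit (later torchvision versions expose them under torchvision.transforms.functional); the blank image and toy linear classifier are stand-ins for real data and a real model:

import torch
import torch.nn as nn
from PIL import Image
from torchvision.transforms import five_crop, to_tensor  # import path as of this commit

img = Image.new('RGB', (256, 256))                   # stand-in for a real test image
crops = five_crop(img, (224, 224))                   # 5 crops: four corners + center
input = torch.stack([to_tensor(c) for c in crops])   # 4D tensor: (ncrops, C, H, W)
input = input.unsqueeze(0)                           # a DataLoader would add the batch dim: (1, 5, 3, 224, 224)

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 10))  # toy classifier
bs, ncrops, c, h, w = input.size()
result = model(input.view(-1, c, h, w))              # fuse batch size and ncrops: (bs * ncrops, 10)
result_avg = result.view(bs, ncrops, -1).mean(1)     # average predictions over crops: (bs, 10)
print(result_avg.shape)                              # torch.Size([1, 10])

Averaging the per-crop predictions is the standard way to turn the fused (bs * ncrops) output back into one prediction per image at test time.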
@@ -906,15 +929,27 @@ def __init__(self, *args, **kwargs):
 
 
 class FiveCrop(object):
-    """Crop the given PIL Image into four corners and the central crop.abs
+    """Crop the given PIL Image into four corners and the central crop.
 
-    Note: this transform returns a tuple of images and there may be a mismatch in the number of
-    inputs and targets your `Dataset` returns.
+    .. Note::
+         This transform returns a tuple of images and there may be a mismatch in the number of
+         inputs and targets your ``Dataset`` returns. See below for an example of how to deal
+         with this.
 
-    Args:
-        size (sequence or int): Desired output size of the crop. If size is an
-            int instead of sequence like (h, w), a square crop (size, size) is
-            made.
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an ``int``
+            instead of a sequence like (h, w), a square crop of size (size, size) is made.
+
+    Example:
+        >>> transform = Compose([
+        >>>     FiveCrop(size),  # this is a tuple of PIL Images
+        >>>     Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops]))  # returns a 4D tensor
+        >>> ])
+        >>> # In your test loop you can do the following:
+        >>> input, target = batch  # input is a 5D tensor, target is 2D
+        >>> bs, ncrops, c, h, w = input.size()
+        >>> result = model(input.view(-1, c, h, w))  # fuse batch size and ncrops
+        >>> result_avg = result.view(bs, ncrops, -1).mean(1)  # avg over crops
     """
 
     def __init__(self, size):
@@ -930,17 +965,30 @@ def __call__(self, img):
 
 
 class TenCrop(object):
-    """Crop the given PIL Image into four corners and the central crop plus the
-    flipped version of these (horizontal flipping is used by default)
+    """Crop the given PIL Image into four corners and the central crop plus the flipped
+    version of these (horizontal flipping is used by default).
 
-    Note: this transform returns a tuple of images and there may be a mismatch in the number of
-    inputs and targets your `Dataset` returns.
+    .. Note::
+         This transform returns a tuple of images and there may be a mismatch in the number of
+         inputs and targets your ``Dataset`` returns. See below for an example of how to deal
+         with this.
 
-    Args:
-        size (sequence or int): Desired output size of the crop. If size is an
-            int instead of sequence like (h, w), a square crop (size, size) is
-            made.
-        vertical_flip(bool): Use vertical flipping instead of horizontal
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an ``int``
+            instead of a sequence like (h, w), a square crop of size (size, size) is made.
+        vertical_flip (bool): Use vertical flipping instead of horizontal.
+
+    Example:
+        >>> transform = Compose([
+        >>>     TenCrop(size),  # this is a tuple of PIL Images
+        >>>     Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops]))  # returns a 4D tensor
+        >>> ])
+        >>> # In your test loop you can do the following:
+        >>> input, target = batch  # input is a 5D tensor, target is 2D
+        >>> bs, ncrops, c, h, w = input.size()
+        >>> result = model(input.view(-1, c, h, w))  # fuse batch size and ncrops
+        >>> result_avg = result.view(bs, ncrops, -1).mean(1)  # avg over crops
     """
 
     def __init__(self, size, vertical_flip=False):
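
The class-based pattern in the FiveCrop and TenCrop docstrings composes the crop with a Lambda that stacks the crops, so each dataset sample becomes a 4D tensor that the default collate function batches into the 5D input shown above. A minimal runnable sketch of that pipeline, with a blank 256x256 image standing in for a dataset sample:

import torch
from PIL import Image
from torchvision.transforms import Compose, Lambda, TenCrop, ToTensor

transform = Compose([
    TenCrop(224),  # 10 crops: four corners + center, plus their flipped versions
    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])),
])

img = Image.new('RGB', (256, 256))  # stand-in for a dataset image
sample = transform(img)
print(sample.shape)                 # torch.Size([10, 3, 224, 224])

Pass vertical_flip=True to TenCrop to flip vertically instead of horizontally; a DataLoader then stacks these samples into the (batch_size, 10, C, H, W) input that the docstring's averaging snippet expects.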
