Skip to content

Commit 300019d

Browse files
authored
Fix lint following #1695 (#1713)
* Fix lint following #1695 * V2 * V3
1 parent 06cbdb5 commit 300019d

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

torchvision/models/detection/faster_rcnn.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,16 +318,19 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
318318
Example::
319319
320320
>>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
321-
>>> images,boxes,labels = torch.rand(4,3,600,1200), torch.rand(4,11,4), torch.rand(4,11) # For Training
321+
>>> # For training
322+
>>> images, boxes = torch.rand(4, 3, 600, 1200), torch.rand(4, 11, 4)
323+
>>> labels = torch.randint(1, 91, (4, 11))
322324
>>> images = list(image for image in images)
323-
>>> targets = []
325+
>>> targets = []
324326
>>> for i in range(len(images)):
325327
>>> d = {}
326328
>>> d['boxes'] = boxes[i]
327-
>>> d['labels'] = labels[i].type(torch.int64)
329+
>>> d['labels'] = labels[i]
328330
>>> targets.append(d)
329-
>>> output = model(images,targets)
330-
>>> model.eval() # For inference
331+
>>> output = model(images, targets)
332+
>>> # For inference
333+
>>> model.eval()
331334
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
332335
>>> predictions = model(x)
333336

0 commit comments

Comments
 (0)