Skip to content

Commit 2a17422

Browse files
prajjwal1fmassa
authored andcommitted
Added Training Sample code for fasterrcnn_resnet50_fpn (#1695)
1 parent d2c763e commit 2a17422

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

torchvision/models/detection/faster_rcnn.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,16 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
318318
Example::
319319
320320
>>> model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
321-
>>> model.eval()
321+
>>> images,boxes,labels = torch.rand(4,3,600,1200), torch.rand(4,11,4), torch.rand(4,11) # For Training
322+
>>> images = list(image for image in images)
323+
>>> targets = []
324+
>>> for i in range(len(images)):
325+
>>> d = {}
326+
>>> d['boxes'] = boxes[i]
327+
>>> d['labels'] = labels[i].type(torch.int64)
328+
>>> targets.append(d)
329+
>>> output = model(images,targets)
330+
>>> model.eval() # For inference
322331
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
323332
>>> predictions = model(x)
324333

0 commit comments

Comments
 (0)