Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/test_backbone_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
from copy import deepcopy
from itertools import chain
from typing import Mapping, Sequence

Expand Down Expand Up @@ -322,3 +323,14 @@ def forward(self, x):
out = model(self.inp)
# And backward
out["leaf_module"].float().mean().backward()

def test_deepcopy(self):
# Non-regression test for https://github.com/pytorch/vision/issues/8634
model = models.efficientnet_b3(weights=None)
extractor = create_feature_extractor(model=model, return_nodes={"classifier.0": "out"})

extractor.eval()
extractor.train()
extractor = deepcopy(extractor)
extractor.eval()
extractor.train()
37 changes: 36 additions & 1 deletion torchvision/models/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import inspect
import math
import re
Expand All @@ -10,7 +11,7 @@
import torch
import torchvision
from torch import fx, nn
from torch.fx.graph_module import _copy_attr
from torch.fx.graph_module import _CodeOnlyModule, _copy_attr, _USER_PRESERVED_ATTRIBUTES_KEY


__all__ = ["create_feature_extractor", "get_graph_node_names"]
Expand Down Expand Up @@ -330,6 +331,40 @@ def train(self, mode=True):
self.graph = self.eval_graph
return super().train(mode=mode)

def _deepcopy_init(self):
# See __deepcopy__ below
return DualGraphModule.__init__

def __deepcopy__(self, memo):
# Same as the base class' __deepcopy__ from pytorch, with minor
# modification to account for train_graph and eval_graph
# https://github.com/pytorch/pytorch/blob/f684dbd0026f98f8fa291cab74dbc4d61ba30580/torch/fx/graph_module.py#L875
#
# This is using a bunch of private stuff from torch, so if that breaks,
# we'll likely have to remove this, along with the associated
# non-regression test.
res = type(self).__new__(type(self))
memo[id(self)] = res
fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["train_graph"], fake_mod.__dict__["eval_graph"])

extra_preserved_attrs = [
"_state_dict_hooks",
"_load_state_dict_pre_hooks",
"_load_state_dict_post_hooks",
"_replace_hook",
"_create_node_hooks",
"_erase_node_hooks",
]
for attr in extra_preserved_attrs:
if attr in self.__dict__:
setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
res.meta = copy.deepcopy(getattr(self, "meta", {}), memo)
if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
setattr(res, attr_name, attr)
return res


def create_feature_extractor(
model: nn.Module,
Expand Down