Skip to content

Commit 57b4cc7

Browse files
feat(hub): loading a custom model with torch.hub.load (Megvii-BaseDetection#1396)
feat(hub): loading a custom model with `torch.hub.load`
1 parent ce4b996 commit 57b4cc7

File tree

2 files changed

+47
-24
lines changed

2 files changed

+47
-24
lines changed

hubconf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
Usage example:
66
import torch
77
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
8+
model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_custom",
9+
exp_path="exp.py", ckpt_path="ckpt.pth")
810
"""
911
dependencies = ["torch"]
1012

@@ -16,4 +18,5 @@
1618
yolox_l,
1719
yolox_x,
1820
yolov3,
21+
yolox_custom
1922
)

yolox/models/build.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"yolox_l",
1515
"yolox_x",
1616
"yolov3",
17+
"yolox_custom"
1718
]
1819

1920
_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
@@ -28,16 +29,20 @@
2829
}
2930

3031

31-
def create_yolox_model(
32-
name: str, pretrained: bool = True, num_classes: int = 80, device=None
33-
) -> nn.Module:
32+
def create_yolox_model(name: str, pretrained: bool = True, num_classes: int = 80, device=None,
33+
exp_path: str = None, ckpt_path: str = None) -> nn.Module:
3434
"""creates and loads a YOLOX model
3535
3636
Args:
37-
name (str): name of model. for example, "yolox-s", "yolox-tiny".
37+
name (str): name of model. for example, "yolox-s", "yolox-tiny" or "yolox_custom"
38+
if you want to load your own model.
3839
pretrained (bool): load pretrained weights into the model. Default to True.
39-
num_classes (int): number of model classes. Defalut to 80.
40-
device (str): default device to for model. Defalut to None.
40+
device (str): default device to for model. Default to None.
41+
num_classes (int): number of model classes. Default to 80.
42+
exp_path (str): path to your own experiment file. Required if name="yolox_custom"
43+
ckpt_path (str): path to your own ckpt. Required if name="yolox_custom" and you want to
44+
load a pretrained model
45+
4146
4247
Returns:
4348
YOLOX model (nn.Module)
@@ -48,44 +53,59 @@ def create_yolox_model(
4853
device = "cuda:0" if torch.cuda.is_available() else "cpu"
4954
device = torch.device(device)
5055

51-
assert name in _CKPT_FULL_PATH, f"user should use one of value in {_CKPT_FULL_PATH.keys()}"
52-
exp: Exp = get_exp(exp_name=name)
53-
exp.num_classes = num_classes
54-
yolox_model = exp.get_model()
55-
if pretrained and num_classes == 80:
56-
weights_url = _CKPT_FULL_PATH[name]
57-
ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
58-
if "model" in ckpt:
59-
ckpt = ckpt["model"]
60-
yolox_model.load_state_dict(ckpt)
56+
assert name in _CKPT_FULL_PATH or name == "yolox_custom", \
57+
f"user should use one of value in {_CKPT_FULL_PATH.keys()} or \"yolox_custom\""
58+
if name in _CKPT_FULL_PATH:
59+
exp: Exp = get_exp(exp_name=name)
60+
exp.num_classes = num_classes
61+
yolox_model = exp.get_model()
62+
if pretrained and num_classes == 80:
63+
weights_url = _CKPT_FULL_PATH[name]
64+
ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
65+
if "model" in ckpt:
66+
ckpt = ckpt["model"]
67+
yolox_model.load_state_dict(ckpt)
68+
else:
69+
assert exp_path is not None, "for a \"yolox_custom\" model exp_path must be provided"
70+
exp: Exp = get_exp(exp_file=exp_path)
71+
yolox_model = exp.get_model()
72+
if ckpt_path:
73+
ckpt = torch.load(ckpt_path, map_location="cpu")
74+
if "model" in ckpt:
75+
ckpt = ckpt["model"]
76+
yolox_model.load_state_dict(ckpt)
6177

6278
yolox_model.to(device)
6379
return yolox_model
6480

6581

66-
def yolox_nano(pretrained=True, num_classes=80, device=None):
82+
def yolox_nano(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
6783
return create_yolox_model("yolox-nano", pretrained, num_classes, device)
6884

6985

70-
def yolox_tiny(pretrained=True, num_classes=80, device=None):
86+
def yolox_tiny(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
7187
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
7288

7389

74-
def yolox_s(pretrained=True, num_classes=80, device=None):
90+
def yolox_s(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
7591
return create_yolox_model("yolox-s", pretrained, num_classes, device)
7692

7793

78-
def yolox_m(pretrained=True, num_classes=80, device=None):
94+
def yolox_m(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
7995
return create_yolox_model("yolox-m", pretrained, num_classes, device)
8096

8197

82-
def yolox_l(pretrained=True, num_classes=80, device=None):
98+
def yolox_l(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
8399
return create_yolox_model("yolox-l", pretrained, num_classes, device)
84100

85101

86-
def yolox_x(pretrained=True, num_classes=80, device=None):
102+
def yolox_x(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
87103
return create_yolox_model("yolox-x", pretrained, num_classes, device)
88104

89105

90-
def yolov3(pretrained=True, num_classes=80, device=None):
91-
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
106+
def yolov3(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
107+
return create_yolox_model("yolov3", pretrained, num_classes, device)
108+
109+
110+
def yolox_custom(ckpt_path: str = None, exp_path: str = None, device: str = None) -> nn.Module:
111+
return create_yolox_model("yolox_custom", ckpt_path=ckpt_path, exp_path=exp_path, device=device)

0 commit comments

Comments
 (0)