1414 "yolox_l" ,
1515 "yolox_x" ,
1616 "yolov3" ,
17+ "yolox_custom"
1718]
1819
1920_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
2829}
2930
3031
31- def create_yolox_model (
32- name : str , pretrained : bool = True , num_classes : int = 80 , device = None
33- ) -> nn .Module :
32+ def create_yolox_model (name : str , pretrained : bool = True , num_classes : int = 80 , device = None ,
33+ exp_path : str = None , ckpt_path : str = None ) -> nn .Module :
3434 """creates and loads a YOLOX model
3535
3636 Args:
37- name (str): name of model. for example, "yolox-s", "yolox-tiny".
37+ name (str): name of model. for example, "yolox-s", "yolox-tiny" or "yolox_custom"
38+ if you want to load your own model.
3839 pretrained (bool): load pretrained weights into the model. Default to True.
39- num_classes (int): number of model classes. Defalut to 80.
40- device (str): default device to for model. Defalut to None.
40+ device (str): default device to for model. Default to None.
41+ num_classes (int): number of model classes. Default to 80.
42+ exp_path (str): path to your own experiment file. Required if name="yolox_custom"
43+ ckpt_path (str): path to your own ckpt. Required if name="yolox_custom" and you want to
44+ load a pretrained model
45+
4146
4247 Returns:
4348 YOLOX model (nn.Module)
@@ -48,44 +53,59 @@ def create_yolox_model(
4853 device = "cuda:0" if torch .cuda .is_available () else "cpu"
4954 device = torch .device (device )
5055
51- assert name in _CKPT_FULL_PATH , f"user should use one of value in { _CKPT_FULL_PATH .keys ()} "
52- exp : Exp = get_exp (exp_name = name )
53- exp .num_classes = num_classes
54- yolox_model = exp .get_model ()
55- if pretrained and num_classes == 80 :
56- weights_url = _CKPT_FULL_PATH [name ]
57- ckpt = load_state_dict_from_url (weights_url , map_location = "cpu" )
58- if "model" in ckpt :
59- ckpt = ckpt ["model" ]
60- yolox_model .load_state_dict (ckpt )
56+ assert name in _CKPT_FULL_PATH or name == "yolox_custom" , \
57+ f"user should use one of value in { _CKPT_FULL_PATH .keys ()} or \" yolox_custom\" "
58+ if name in _CKPT_FULL_PATH :
59+ exp : Exp = get_exp (exp_name = name )
60+ exp .num_classes = num_classes
61+ yolox_model = exp .get_model ()
62+ if pretrained and num_classes == 80 :
63+ weights_url = _CKPT_FULL_PATH [name ]
64+ ckpt = load_state_dict_from_url (weights_url , map_location = "cpu" )
65+ if "model" in ckpt :
66+ ckpt = ckpt ["model" ]
67+ yolox_model .load_state_dict (ckpt )
68+ else :
69+ assert exp_path is not None , "for a \" yolox_custom\" model exp_path must be provided"
70+ exp : Exp = get_exp (exp_file = exp_path )
71+ yolox_model = exp .get_model ()
72+ if ckpt_path :
73+ ckpt = torch .load (ckpt_path , map_location = "cpu" )
74+ if "model" in ckpt :
75+ ckpt = ckpt ["model" ]
76+ yolox_model .load_state_dict (ckpt )
6177
6278 yolox_model .to (device )
6379 return yolox_model
6480
6581
66- def yolox_nano (pretrained = True , num_classes = 80 , device = None ):
82+ def yolox_nano (pretrained : bool = True , num_classes : int = 80 , device : str = None ) -> nn . Module :
6783 return create_yolox_model ("yolox-nano" , pretrained , num_classes , device )
6884
6985
70- def yolox_tiny (pretrained = True , num_classes = 80 , device = None ):
86+ def yolox_tiny (pretrained : bool = True , num_classes : int = 80 , device : str = None ) -> nn . Module :
7187 return create_yolox_model ("yolox-tiny" , pretrained , num_classes , device )
7288
7389
74- def yolox_s (pretrained = True , num_classes = 80 , device = None ):
90+ def yolox_s (pretrained : bool = True , num_classes : int = 80 , device : str = None ) -> nn . Module :
7591 return create_yolox_model ("yolox-s" , pretrained , num_classes , device )
7692
7793
78- def yolox_m (pretrained = True , num_classes = 80 , device = None ):
94+ def yolox_m (pretrained : bool = True , num_classes : int = 80 , device : str = None ) -> nn . Module :
7995 return create_yolox_model ("yolox-m" , pretrained , num_classes , device )
8096
8197
82- def yolox_l (pretrained = True , num_classes = 80 , device = None ):
98+ def yolox_l (pretrained : bool = True , num_classes : int = 80 , device : str = None ) -> nn . Module :
8399 return create_yolox_model ("yolox-l" , pretrained , num_classes , device )
84100
85101
86- def yolox_x (pretrained = True , num_classes = 80 , device = None ):
102+ def yolox_x (pretrained : bool = True , num_classes : int = 80 , device : str = None ) -> nn . Module :
87103 return create_yolox_model ("yolox-x" , pretrained , num_classes , device )
88104
89105
90- def yolov3 (pretrained = True , num_classes = 80 , device = None ):
91- return create_yolox_model ("yolox-tiny" , pretrained , num_classes , device )
106+ def yolov3 (pretrained : bool = True , num_classes : int = 80 , device : str = None ) -> nn .Module :
107+ return create_yolox_model ("yolov3" , pretrained , num_classes , device )
108+
109+
110+ def yolox_custom (ckpt_path : str = None , exp_path : str = None , device : str = None ) -> nn .Module :
111+ return create_yolox_model ("yolox_custom" , ckpt_path = ckpt_path , exp_path = exp_path , device = device )
0 commit comments