```python
import segmentation_models_pytorch as smp

- model = smp.{{ model_name }}.from_pretrained("{{ save_directory | default("<save-directory-or-repo>", true)}}")
+ model = smp.from_pretrained("<save-directory-or-this-repo>")
```
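For reference, once a checkpoint has been saved with this updated template, the rendered snippet can be used directly; a minimal sketch, assuming the package exposes `from_pretrained` at the top level as the template implies (the checkpoint path below is hypothetical):

```python
import segmentation_models_pytorch as smp

# Load a previously saved checkpoint; "./unet-checkpoint" is a hypothetical
# local directory (a Hub repo id such as "<user>/<repo>" would also work).
model = smp.from_pretrained("./unet-checkpoint")
model.eval()
```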
## Model init parameters
@@ -61,23 +61,22 @@ def _format_parameters(parameters: dict):
class SMPHubMixin(PyTorchModelHubMixin):
    def generate_model_card(self, *args, **kwargs) -> ModelCard:
-        model_parameters_json = _format_parameters(self._hub_mixin_config)
-        directory = self._save_directory if hasattr(self, "_save_directory") else None
-        repo_id = self._repo_id if hasattr(self, "_repo_id") else None
-        repo_or_directory = repo_id if repo_id is not None else directory
-
-        metrics = self._metrics if hasattr(self, "_metrics") else None
-        dataset = self._dataset if hasattr(self, "_dataset") else None
+        model_parameters_json = _format_parameters(self.config)
+        metrics = kwargs.get("metrics", None)
+        dataset = kwargs.get("dataset", None)

        if metrics is not None:
            metrics = json.dumps(metrics, indent=4)
            metrics = f"```json\n{metrics}\n```"

+        tags = self._hub_mixin_info.model_card_data.get("tags", []) or []
+        tags.extend(["segmentation-models-pytorch", "semantic-segmentation", "pytorch"])
+
        model_card_data = ModelCardData(
            languages=["python"],
            library_name="segmentation-models-pytorch",
            license="mit",
-            tags=["semantic-segmentation", "pytorch", "segmentation-models-pytorch"],
+            tags=tags,
            pipeline_tag="image-segmentation",
        )
        model_card = ModelCard.from_template(
@@ -86,64 +85,49 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard:
            repo_url="https://github.com/qubvel/segmentation_models.pytorch",
            docs_url="https://smp.readthedocs.io/en/latest/",
            model_parameters=model_parameters_json,
-            save_directory=repo_or_directory,
            model_name=self.__class__.__name__,
            metrics=metrics,
            dataset=dataset,
        )
        return model_card

-    def _set_attrs_from_kwargs(self, attrs, kwargs):
-        for attr in attrs:
-            if attr in kwargs:
-                setattr(self, f"_{attr}", kwargs.pop(attr))
-
-    def _del_attrs(self, attrs):
-        for attr in attrs:
-            if hasattr(self, f"_{attr}"):
-                delattr(self, f"_{attr}")
-
    @wraps(PyTorchModelHubMixin.save_pretrained)
    def save_pretrained(
        self, save_directory: Union[str, Path], *args, **kwargs
    ) -> Optional[str]:
-        # set additional attributes to be used in generate_model_card
-        self._save_directory = save_directory
-        self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
+        model_card_kwargs = kwargs.pop("model_card_kwargs", {})
+        if "dataset" in kwargs:
+            model_card_kwargs["dataset"] = kwargs.pop("dataset")
+        if "metrics" in kwargs:
+            model_card_kwargs["metrics"] = kwargs.pop("metrics")
+        kwargs["model_card_kwargs"] = model_card_kwargs

-        # set additional attribute to be used in from_pretrained
-        self._hub_mixin_config["_model_class"] = self.__class__.__name__
+        # set additional attribute to be able to deserialize the model
+        self.config["_model_class"] = self.__class__.__name__

        try:
            # call the original save_pretrained
            result = super().save_pretrained(save_directory, *args, **kwargs)
        finally:
-            # delete the additional attributes
-            self._del_attrs(["save_directory", "metrics", "dataset"])
-            self._hub_mixin_config.pop("_model_class", None)
+            self.config.pop("_model_class", None)

        return result

-    @wraps(PyTorchModelHubMixin.push_to_hub)
-    def push_to_hub(self, repo_id: str, *args, **kwargs):
-        self._repo_id = repo_id
-        self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
-        result = super().push_to_hub(repo_id, *args, **kwargs)
-        self._del_attrs(["repo_id", "metrics", "dataset"])
-        return result
-
    @property
-    def config(self):
+    def config(self) -> dict:
        return self._hub_mixin_config


@wraps(PyTorchModelHubMixin.from_pretrained)
def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
-    config_path = hf_hub_download(
-        pretrained_model_name_or_path,
-        filename="config.json",
-        revision=kwargs.get("revision", None),
-    )
+    config_path = Path(pretrained_model_name_or_path) / "config.json"
+    if not config_path.exists():
+        config_path = hf_hub_download(
+            pretrained_model_name_or_path,
+            filename="config.json",
+            revision=kwargs.get("revision", None),
+        )
+
    with open(config_path, "r") as f:
        config = json.load(f)
    model_class_name = config.pop("_model_class")
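Taken together, the new code path routes `metrics` and `dataset` through `model_card_kwargs` rather than temporary instance attributes, and `from_pretrained` checks for a local `config.json` before falling back to `hf_hub_download`. A rough usage sketch under assumed values (the save directory, metric numbers, and dataset name are illustrative only, and `encoder_weights=None` is used just to avoid downloading pretrained encoder weights):

```python
import segmentation_models_pytorch as smp

model = smp.Unet(encoder_name="resnet34", encoder_weights=None, classes=2)

# "metrics" and "dataset" are popped from kwargs and forwarded to
# generate_model_card via model_card_kwargs; the values here are made up.
model.save_pretrained(
    "./unet-checkpoint",
    metrics={"mIoU": 0.87},
    dataset="example-dataset",
)

# Loading reads config.json from the local directory first, and only
# falls back to hf_hub_download for a Hub repo id.
restored = smp.from_pretrained("./unet-checkpoint")
```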