Skip to content

Commit 5a780a8

Browse files
committed
Update mixin
1 parent f40b6ed commit 5a780a8

File tree

1 file changed

+26
-42
lines changed

1 file changed

+26
-42
lines changed

segmentation_models_pytorch/base/hub_mixin.py

Lines changed: 26 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
```python
2727
import segmentation_models_pytorch as smp
2828
29-
model = smp.{{ model_name }}.from_pretrained("{{ save_directory | default("<save-directory-or-repo>", true)}}")
29+
model = smp.from_pretrained("<save-directory-or-this-repo>")
3030
```
3131
3232
## Model init parameters
@@ -61,23 +61,22 @@ def _format_parameters(parameters: dict):
6161

6262
class SMPHubMixin(PyTorchModelHubMixin):
6363
def generate_model_card(self, *args, **kwargs) -> ModelCard:
64-
model_parameters_json = _format_parameters(self._hub_mixin_config)
65-
directory = self._save_directory if hasattr(self, "_save_directory") else None
66-
repo_id = self._repo_id if hasattr(self, "_repo_id") else None
67-
repo_or_directory = repo_id if repo_id is not None else directory
68-
69-
metrics = self._metrics if hasattr(self, "_metrics") else None
70-
dataset = self._dataset if hasattr(self, "_dataset") else None
64+
model_parameters_json = _format_parameters(self.config)
65+
metrics = kwargs.get("metrics", None)
66+
dataset = kwargs.get("dataset", None)
7167

7268
if metrics is not None:
7369
metrics = json.dumps(metrics, indent=4)
7470
metrics = f"```json\n{metrics}\n```"
7571

72+
tags = self._hub_mixin_info.model_card_data.get("tags", []) or []
73+
tags.extend(["segmentation-models-pytorch", "semantic-segmentation", "pytorch"])
74+
7675
model_card_data = ModelCardData(
7776
languages=["python"],
7877
library_name="segmentation-models-pytorch",
7978
license="mit",
80-
tags=["semantic-segmentation", "pytorch", "segmentation-models-pytorch"],
79+
tags=tags,
8180
pipeline_tag="image-segmentation",
8281
)
8382
model_card = ModelCard.from_template(
@@ -86,64 +85,49 @@ def generate_model_card(self, *args, **kwargs) -> ModelCard:
8685
repo_url="https://github.com/qubvel/segmentation_models.pytorch",
8786
docs_url="https://smp.readthedocs.io/en/latest/",
8887
model_parameters=model_parameters_json,
89-
save_directory=repo_or_directory,
9088
model_name=self.__class__.__name__,
9189
metrics=metrics,
9290
dataset=dataset,
9391
)
9492
return model_card
9593

96-
def _set_attrs_from_kwargs(self, attrs, kwargs):
97-
for attr in attrs:
98-
if attr in kwargs:
99-
setattr(self, f"_{attr}", kwargs.pop(attr))
100-
101-
def _del_attrs(self, attrs):
102-
for attr in attrs:
103-
if hasattr(self, f"_{attr}"):
104-
delattr(self, f"_{attr}")
105-
10694
@wraps(PyTorchModelHubMixin.save_pretrained)
10795
def save_pretrained(
10896
self, save_directory: Union[str, Path], *args, **kwargs
10997
) -> Optional[str]:
110-
# set additional attributes to be used in generate_model_card
111-
self._save_directory = save_directory
112-
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
98+
model_card_kwargs = kwargs.pop("model_card_kwargs", {})
99+
if "dataset" in kwargs:
100+
model_card_kwargs["dataset"] = kwargs.pop("dataset")
101+
if "metrics" in kwargs:
102+
model_card_kwargs["metrics"] = kwargs.pop("metrics")
103+
kwargs["model_card_kwargs"] = model_card_kwargs
113104

114-
# set additional attribute to be used in from_pretrained
115-
self._hub_mixin_config["_model_class"] = self.__class__.__name__
105+
# set additional attribute to be ble to deserialize the model
106+
self.config["_model_class"] = self.__class__.__name__
116107

117108
try:
118109
# call the original save_pretrained
119110
result = super().save_pretrained(save_directory, *args, **kwargs)
120111
finally:
121-
# delete the additional attributes
122-
self._del_attrs(["save_directory", "metrics", "dataset"])
123-
self._hub_mixin_config.pop("_model_class", None)
112+
self.config.pop("_model_class", None)
124113

125114
return result
126115

127-
@wraps(PyTorchModelHubMixin.push_to_hub)
128-
def push_to_hub(self, repo_id: str, *args, **kwargs):
129-
self._repo_id = repo_id
130-
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
131-
result = super().push_to_hub(repo_id, *args, **kwargs)
132-
self._del_attrs(["repo_id", "metrics", "dataset"])
133-
return result
134-
135116
@property
136-
def config(self):
117+
def config(self) -> dict:
137118
return self._hub_mixin_config
138119

139120

140121
@wraps(PyTorchModelHubMixin.from_pretrained)
141122
def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
142-
config_path = hf_hub_download(
143-
pretrained_model_name_or_path,
144-
filename="config.json",
145-
revision=kwargs.get("revision", None),
146-
)
123+
config_path = Path(pretrained_model_name_or_path) / "config.json"
124+
if not config_path.exists():
125+
config_path = hf_hub_download(
126+
pretrained_model_name_or_path,
127+
filename="config.json",
128+
revision=kwargs.get("revision", None),
129+
)
130+
147131
with open(config_path, "r") as f:
148132
config = json.load(f)
149133
model_class_name = config.pop("_model_class")

0 commit comments

Comments
 (0)