Skip to content

Commit 667906e

Browse files
committed
Add example to save load share
1 parent b6d711f commit 667906e

File tree

1 file changed

+250
-0
lines changed

1 file changed

+250
-0
lines changed
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import segmentation_models_pytorch as smp"
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"## Save to local directory and load back"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": 2,
22+
"metadata": {},
23+
"outputs": [
24+
{
25+
"name": "stdout",
26+
"output_type": "stream",
27+
"text": [
28+
"Loading weights from local directory\n"
29+
]
30+
}
31+
],
32+
"source": [
33+
"model = smp.Unet()\n",
34+
"\n",
35+
"# save the model\n",
36+
"model.save_pretrained(\"saved-model-dir/unet/\")\n",
37+
"\n",
38+
"# load the model\n",
39+
"restored_model = smp.from_pretrained(\"saved-model-dir/unet/\")"
40+
]
41+
},
42+
{
43+
"cell_type": "markdown",
44+
"metadata": {},
45+
"source": [
46+
"## Save model with additional metadata"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": 6,
52+
"metadata": {},
53+
"outputs": [],
54+
"source": [
55+
"model = smp.Unet()\n",
56+
"\n",
57+
"# save the model\n",
58+
"model.save_pretrained(\n",
59+
" \"saved-model-dir/unet-with-metadata/\",\n",
60+
"\n",
61+
" # additional information to be saved with the model\n",
62+
" # only \"dataset\" and \"metrics\" are supported\n",
63+
" dataset=\"PASCAL VOC\", # only string name is supported\n",
64+
" metrics={ # should be a dictionary with metric name as key and metric value as value\n",
65+
" \"mIoU\": 0.95,\n",
66+
" \"accuracy\": 0.96\n",
67+
" }\n",
68+
")"
69+
]
70+
},
71+
{
72+
"cell_type": "code",
73+
"execution_count": 7,
74+
"metadata": {},
75+
"outputs": [
76+
{
77+
"name": "stdout",
78+
"output_type": "stream",
79+
"text": [
80+
"---\n",
81+
"library_name: segmentation-models-pytorch\n",
82+
"license: mit\n",
83+
"pipeline_tag: image-segmentation\n",
84+
"tags:\n",
85+
"- semantic-segmentation\n",
86+
"- pytorch\n",
87+
"- segmentation-models-pytorch\n",
88+
"languages:\n",
89+
"- python\n",
90+
"---\n",
91+
"# Unet Model Card\n",
92+
"\n",
93+
"Table of Contents:\n",
94+
"- [Load trained model](#load-trained-model)\n",
95+
"- [Model init parameters](#model-init-parameters)\n",
96+
"- [Model metrics](#model-metrics)\n",
97+
"- [Dataset](#dataset)\n",
98+
"\n",
99+
"## Load trained model\n",
100+
"```python\n",
101+
"import segmentation_models_pytorch as smp\n",
102+
"\n",
103+
"model = smp.from_pretrained(\"<save-directory-or-this-repo>\")\n",
104+
"```\n",
105+
"\n",
106+
"## Model init parameters\n",
107+
"```python\n",
108+
"model_init_params = {\n",
109+
" \"encoder_name\": \"resnet34\",\n",
110+
" \"encoder_depth\": 5,\n",
111+
" \"encoder_weights\": \"imagenet\",\n",
112+
" \"decoder_use_batchnorm\": True,\n",
113+
" \"decoder_channels\": (256, 128, 64, 32, 16),\n",
114+
" \"decoder_attention_type\": None,\n",
115+
" \"in_channels\": 3,\n",
116+
" \"classes\": 1,\n",
117+
" \"activation\": None,\n",
118+
" \"aux_params\": None\n",
119+
"}\n",
120+
"```\n",
121+
"\n",
122+
"## Model metrics\n",
123+
"```json\n",
124+
"{\n",
125+
" \"mIoU\": 0.95,\n",
126+
" \"accuracy\": 0.96\n",
127+
"}\n",
128+
"```\n",
129+
"\n",
130+
"## Dataset\n",
131+
"Dataset name: PASCAL VOC\n",
132+
"\n",
133+
"## More Information\n",
134+
"- Library: https://github.com/qubvel/segmentation_models.pytorch\n",
135+
"- Docs: https://smp.readthedocs.io/en/latest/\n",
136+
"\n",
137+
"This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin)"
138+
]
139+
}
140+
],
141+
"source": [
142+
"!cat \"saved-model-dir/unet-with-metadata/README.md\""
143+
]
144+
},
145+
{
146+
"cell_type": "markdown",
147+
"metadata": {},
148+
"source": [
149+
"## Share model with HF Hub"
150+
]
151+
},
152+
{
153+
"cell_type": "code",
154+
"execution_count": 5,
155+
"metadata": {},
156+
"outputs": [
157+
{
158+
"data": {
159+
"application/vnd.jupyter.widget-view+json": {
160+
"model_id": "075ae026811542bdb4030e53b943efc7",
161+
"version_major": 2,
162+
"version_minor": 0
163+
},
164+
"text/plain": [
165+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
166+
]
167+
},
168+
"metadata": {},
169+
"output_type": "display_data"
170+
}
171+
],
172+
"source": [
173+
"from huggingface_hub import notebook_login\n",
174+
"\n",
175+
"# You only need to run this once on the machine,\n",
176+
"# the token will be stored for later use\n",
177+
"notebook_login()"
178+
]
179+
},
180+
{
181+
"cell_type": "code",
182+
"execution_count": 8,
183+
"metadata": {},
184+
"outputs": [
185+
{
186+
"data": {
187+
"application/vnd.jupyter.widget-view+json": {
188+
"model_id": "2921a81d7fd747939b4a425cc17d6104",
189+
"version_major": 2,
190+
"version_minor": 0
191+
},
192+
"text/plain": [
193+
"model.safetensors: 0%| | 0.00/97.8M [00:00<?, ?B/s]"
194+
]
195+
},
196+
"metadata": {},
197+
"output_type": "display_data"
198+
},
199+
{
200+
"data": {
201+
"text/plain": [
202+
"CommitInfo(commit_url='https://huggingface.co/qubvel-hf/unet-with-metadata/commit/9f821c7bc3a12db827c0da96a31f354ec6ba5253', commit_message='Push model using huggingface_hub.', commit_description='', oid='9f821c7bc3a12db827c0da96a31f354ec6ba5253', pr_url=None, pr_revision=None, pr_num=None)"
203+
]
204+
},
205+
"execution_count": 8,
206+
"metadata": {},
207+
"output_type": "execute_result"
208+
}
209+
],
210+
"source": [
211+
"model = smp.Unet()\n",
212+
"\n",
213+
"# save the model and share it on the HF Hub (https://huggingface.co/models)\n",
214+
"model.save_pretrained(\n",
215+
" \"qubvel-hf/unet-with-metadata/\",\n",
216+
" push_to_hub=True, # <---------- push the model to the hub\n",
217+
" private=False, # <---------- make the model private or or public\n",
218+
" dataset=\"PASCAL VOC\",\n",
219+
" metrics={\n",
220+
" \"mIoU\": 0.95,\n",
221+
" \"accuracy\": 0.96\n",
222+
" }\n",
223+
")\n",
224+
"\n",
225+
"# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
226+
]
227+
}
228+
],
229+
"metadata": {
230+
"kernelspec": {
231+
"display_name": ".venv",
232+
"language": "python",
233+
"name": "python3"
234+
},
235+
"language_info": {
236+
"codemirror_mode": {
237+
"name": "ipython",
238+
"version": 3
239+
},
240+
"file_extension": ".py",
241+
"mimetype": "text/x-python",
242+
"name": "python",
243+
"nbconvert_exporter": "python",
244+
"pygments_lexer": "ipython3",
245+
"version": "3.10.12"
246+
}
247+
},
248+
"nbformat": 4,
249+
"nbformat_minor": 2
250+
}

0 commit comments

Comments
 (0)