
Commit 1732770

train model with gpu instances

1 parent eac174e commit 1732770

File tree

1 file changed: +299 -0 lines changed

train_colab.ipynb

Lines changed: 299 additions & 0 deletions
@@ -0,0 +1,299 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "crowd_analysis.ipynb",
      "provenance": [],
      "toc_visible": true,
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/tempdata73/crowd_analysis/blob/master/train_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "imABng0dsrwP",
        "colab_type": "text"
      },
      "source": [
        "# CSRNet: Congested Scene Recognition Network\n",
        "The purpose of this notebook is to train CSRNet on GPU instances."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KNdQNhT2xRhW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install PyDrive -qq"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gJ5dZqsqZGN7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!git clone https://github.com/tempdata73/crowd_analysis.git\n",
        "%cd crowd_analysis"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ha9l1tnSxkCg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "from pydrive.auth import GoogleAuth\n",
        "from pydrive.drive import GoogleDrive\n",
        "from google.colab import auth\n",
        "from oauth2client.client import GoogleCredentials"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rJ6-XWCrx2-d",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "auth.authenticate_user()\n",
        "gauth = GoogleAuth()\n",
        "gauth.credentials = GoogleCredentials.get_application_default()\n",
        "drive = GoogleDrive(gauth)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DoeThDOYyQ_F",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Google Drive file ID of the zipped Shanghai dataset\n",
        "FILE_ID = '1FKTNDuK8-IlvufVMW8ymP_onPYhJNYL1'\n",
        "FILENAME = 'shanghai_dataset.zip'\n",
        "DST_DIR = 'Shanghai'\n",
        "\n",
        "download = drive.CreateFile({'id': FILE_ID})\n",
        "download.GetContentFile(FILENAME)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qKDBHQW3yQ2C",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from zipfile import ZipFile\n",
        "\n",
        "ZipFile(FILENAME).extractall(DST_DIR)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Cm4BY3gRXZuM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!rm shanghai_dataset.zip"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qC3OGuvAcLUa",
        "colab_type": "text"
      },
      "source": [
        "## Density map generation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4IFILD5JcLDx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import time\n",
        "\n",
        "start_time = time.time()\n",
        "!python make_dataset.py Shanghai/ Shanghai_A\n",
        "end_time = time.time()\n",
        "print('Density map generation took {:0.2f} secs'.format(end_time - start_time))"
      ],
      "execution_count": 0,
      "outputs": []
    },
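    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "`make_dataset.py` does the real work above. As a rough reference only, the sketch below shows the usual way CSRNet-style ground-truth density maps are built: a unit impulse at every annotated head, blurred with a Gaussian so the map still sums (approximately) to the head count. The function name, the fixed `sigma`, and the `(N, 2)` point format are illustrative assumptions, not the repository's implementation."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Illustrative sketch (assumed API), not make_dataset.py itself\n",
        "import numpy as np\n",
        "from scipy.ndimage import gaussian_filter\n",
        "\n",
        "def density_map(points, height, width, sigma=15):\n",
        "    # One impulse per annotated (x, y) head position...\n",
        "    density = np.zeros((height, width), dtype=np.float32)\n",
        "    for x, y in points:\n",
        "        if 0 <= int(y) < height and 0 <= int(x) < width:\n",
        "            density[int(y), int(x)] = 1.0\n",
        "    # ...then a Gaussian blur spreads each impulse while preserving its mass\n",
        "    return gaussian_filter(density, sigma)"
      ],
      "execution_count": 0,
      "outputs": []
    },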
    {
      "cell_type": "code",
      "metadata": {
        "id": "Er_o7vyiXODE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!python train.py part_A/trainval.json part_A/test.json"
      ],
      "execution_count": 0,
      "outputs": []
    },
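    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The split files passed to `train.py` above are assumed here to be flat JSON lists of image paths (an assumption about the repo's format). If they ever needed to be regenerated, a sketch that also guesses at the extracted directory layout would be:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Sketch only: assumes flat JSON lists of image paths\n",
        "# and guesses the extracted directory layout\n",
        "import glob\n",
        "import json\n",
        "import os\n",
        "\n",
        "train_imgs = sorted(glob.glob('Shanghai/part_A/train_data/images/*.jpg'))\n",
        "test_imgs = sorted(glob.glob('Shanghai/part_A/test_data/images/*.jpg'))\n",
        "\n",
        "os.makedirs('part_A', exist_ok=True)\n",
        "with open('part_A/trainval.json', 'w') as outfile:\n",
        "    json.dump(train_imgs, outfile)\n",
        "with open('part_A/test.json', 'w') as outfile:\n",
        "    json.dump(test_imgs, outfile)"
      ],
      "execution_count": 0,
      "outputs": []
    },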
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PWtG-sopZcS4",
        "colab_type": "text"
      },
      "source": [
        "## Model tests"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "M4fQJgI0ZUTx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import os\n",
        "\n",
        "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
        "CKPT_DIR = 'ckpts'\n",
        "models = []\n",
        "\n",
        "# Collect every checkpoint (except 'model.pth.tar') with its stored loss\n",
        "for model_fn in os.listdir(CKPT_DIR):\n",
        "    if model_fn == 'model.pth.tar':\n",
        "        continue\n",
        "    model_path = os.path.join(CKPT_DIR, model_fn)\n",
        "    ckpt = torch.load(model_path, map_location=torch.device(device))\n",
        "    models.append((model_path, ckpt['loss']))\n",
        "\n",
        "# Keep the checkpoint with the lowest loss\n",
        "models = sorted(models, key=lambda item: item[1])\n",
        "model_path = models[0][0]\n",
        "print(model_path)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pAs71HOhKiiH",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!python validation.py $model_path part_A/test.json"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5QXfW5lSr_V5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import json\n",
        "\n",
        "with open('data/metrics.json') as infile:\n",
        "    data = json.load(infile)\n",
        "\n",
        "data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "etscb0CHAIxs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# The last epoch did not necessarily yield the best model,\n",
        "# so truncate the loss history at the chosen checkpoint's epoch\n",
        "\n",
        "ckpt = torch.load(model_path, map_location=torch.device(device))\n",
        "last_epoch = ckpt['epoch']\n",
        "\n",
        "with open('data/loss_data.json') as infile:\n",
        "    loss = json.load(infile)\n",
        "\n",
        "last_ckpt_loss_data = {\n",
        "    'train_mae': loss['train_mae'][:last_epoch],\n",
        "    'val_mae': loss['val_mae'][:last_epoch]\n",
        "}\n",
        "\n",
        "with open('data/loss_data.json', 'w') as outfile:\n",
        "    json.dump(last_ckpt_loss_data, outfile)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fok2K6zE1KwN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Download the best checkpoint and the metric/loss data\n",
        "from google.colab import files\n",
        "\n",
        "files.download(model_path)\n",
        "files.download('data/metrics.json')\n",
        "files.download('data/loss_data.json')"
      ],
      "execution_count": 0,
      "outputs": []
    },
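    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "`files.download` can be unreliable for larger checkpoints. Since the PyDrive client was already authenticated above, a minimal alternative sketch is to upload the artifacts straight to Drive (root folder; picking a target folder is left out):"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Sketch: upload artifacts to Drive with the 'drive' client\n",
        "# authenticated earlier in this notebook\n",
        "for path in [model_path, 'data/metrics.json', 'data/loss_data.json']:\n",
        "    artifact = drive.CreateFile({'title': os.path.basename(path)})\n",
        "    artifact.SetContentFile(path)\n",
        "    artifact.Upload()"
      ],
      "execution_count": 0,
      "outputs": []
    }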
  ]
}
