1+ {
2+ "nbformat" : 4 ,
3+ "nbformat_minor" : 0 ,
4+ "metadata" : {
5+ "colab" : {
6+ "name" : " crowd_analysis.ipynb" ,
7+ "provenance" : [],
8+ "toc_visible" : true ,
9+ "include_colab_link" : true
10+ },
11+ "kernelspec" : {
12+ "name" : " python3" ,
13+ "display_name" : " Python 3"
14+ },
15+ "accelerator" : " GPU"
16+ },
17+ "cells" : [
18+ {
19+ "cell_type" : " markdown" ,
20+ "metadata" : {
21+ "id" : " view-in-github" ,
22+ "colab_type" : " text"
23+ },
24+ "source" : [
25+ " <a href=\" https://colab.research.google.com/github/tempdata73/crowd_analysis/blob/master/train_colab.ipynb\" target=\" _parent\" ><img src=\" https://colab.research.google.com/assets/colab-badge.svg\" alt=\" Open In Colab\" /></a>"
26+ ]
27+ },
28+ {
29+ "cell_type" : " markdown" ,
30+ "metadata" : {
31+ "id" : " imABng0dsrwP" ,
32+ "colab_type" : " text"
33+ },
34+ "source" : [
35+ " # CSRNet: Congested Scene Recognition Network\n " ,
36+ " The purpose of this notebook is to train a CSRNet with the help of GPU instances."
37+ ]
38+ },
39+ {
40+ "cell_type" : " code" ,
41+ "metadata" : {
42+ "id" : " KNdQNhT2xRhW" ,
43+ "colab_type" : " code" ,
44+ "colab" : {}
45+ },
46+ "source" : [
47+ " !pip install PyDrive -qq"
48+ ],
49+ "execution_count" : 0 ,
50+ "outputs" : []
51+ },
52+ {
53+ "cell_type" : " code" ,
54+ "metadata" : {
55+ "id" : " gJ5dZqsqZGN7" ,
56+ "colab_type" : " code" ,
57+ "colab" : {}
58+ },
59+ "source" : [
60+ " !git clone https://github.com/tempdata73/crowd_analysis.git\n " ,
61+ " %cd crowd_analysis"
62+ ],
63+ "execution_count" : 0 ,
64+ "outputs" : []
65+ },
66+ {
67+ "cell_type" : " code" ,
68+ "metadata" : {
69+ "id" : " Ha9l1tnSxkCg" ,
70+ "colab_type" : " code" ,
71+ "colab" : {}
72+ },
73+ "source" : [
74+ " import os\n " ,
75+ " from pydrive.auth import GoogleAuth\n " ,
76+ " from pydrive.drive import GoogleDrive\n " ,
77+ " from google.colab import auth\n " ,
78+ " from oauth2client.client import GoogleCredentials"
79+ ],
80+ "execution_count" : 0 ,
81+ "outputs" : []
82+ },
83+ {
84+ "cell_type" : " code" ,
85+ "metadata" : {
86+ "id" : " rJ6-XWCrx2-d" ,
87+ "colab_type" : " code" ,
88+ "colab" : {}
89+ },
90+ "source" : [
91+ " auth.authenticate_user()\n " ,
92+ " gauth = GoogleAuth()\n " ,
93+ " gauth.credentials = GoogleCredentials.get_application_default()\n " ,
94+ " drive = GoogleDrive(gauth)"
95+ ],
96+ "execution_count" : 0 ,
97+ "outputs" : []
98+ },
99+ {
100+ "cell_type" : " code" ,
101+ "metadata" : {
102+ "id" : " DoeThDOYyQ_F" ,
103+ "colab_type" : " code" ,
104+ "colab" : {}
105+ },
106+ "source" : [
107+ " FILE_ID = '1FKTNDuK8-IlvufVMW8ymP_onPYhJNYL1'\n " ,
108+ " FILENAME = 'shanghai_dataset.zip'\n " ,
109+ " DST_DIR = 'Shanghai'\n " ,
110+ " \n " ,
111+ " download = drive.CreateFile({'id': FILE_ID})\n " ,
112+ " download.GetContentFile(FILENAME)"
113+ ],
114+ "execution_count" : 0 ,
115+ "outputs" : []
116+ },
117+ {
118+ "cell_type" : " code" ,
119+ "metadata" : {
120+ "id" : " qKDBHQW3yQ2C" ,
121+ "colab_type" : " code" ,
122+ "colab" : {}
123+ },
124+ "source" : [
125+ " from zipfile import ZipFile\n " ,
126+ " \n " ,
127+ " ZipFile(FILENAME).extractall(DST_DIR)"
128+ ],
129+ "execution_count" : 0 ,
130+ "outputs" : []
131+ },
132+ {
133+ "cell_type" : " code" ,
134+ "metadata" : {
135+ "id" : " Cm4BY3gRXZuM" ,
136+ "colab_type" : " code" ,
137+ "colab" : {}
138+ },
139+ "source" : [
140+ " !rm shanghai_dataset.zip"
141+ ],
142+ "execution_count" : 0 ,
143+ "outputs" : []
144+ },
145+ {
146+ "cell_type" : " markdown" ,
147+ "metadata" : {
148+ "id" : " qC3OGuvAcLUa" ,
149+ "colab_type" : " text"
150+ },
151+ "source" : [
152+ " ## Density map generation"
153+ ]
154+ },
155+ {
156+ "cell_type" : " code" ,
157+ "metadata" : {
158+ "id" : " 4IFILD5JcLDx" ,
159+ "colab_type" : " code" ,
160+ "colab" : {}
161+ },
162+ "source" : [
163+ " import time\n " ,
164+ " start_time = time.time()\n " ,
165+ " !python make_dataset.py Shanghai/ Shanghai_A\n " ,
166+ " end_time = time.time()\n " ,
167+ " print('Density map generation took {:0.2f} secs'.format(end_time - start_time))"
168+ ],
169+ "execution_count" : 0 ,
170+ "outputs" : []
171+ },
172+ {
173+ "cell_type" : " code" ,
174+ "metadata" : {
175+ "id" : " Er_o7vyiXODE" ,
176+ "colab_type" : " code" ,
177+ "colab" : {}
178+ },
179+ "source" : [
180+ " !python train.py part_A/trainval.json part_A/test.json"
181+ ],
182+ "execution_count" : 0 ,
183+ "outputs" : []
184+ },
185+ {
186+ "cell_type" : " markdown" ,
187+ "metadata" : {
188+ "id" : " PWtG-sopZcS4" ,
189+ "colab_type" : " text"
190+ },
191+ "source" : [
192+ " ## Model tests"
193+ ]
194+ },
195+ {
196+ "cell_type" : " code" ,
197+ "metadata" : {
198+ "id" : " M4fQJgI0ZUTx" ,
199+ "colab_type" : " code" ,
200+ "colab" : {}
201+ },
202+ "source" : [
203+ " import torch\n " ,
204+ " import os\n " ,
205+ " \n " ,
206+ " device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n " ,
207+ " CKPT_DIR = 'ckpts'\n " ,
208+ " models = []\n " ,
209+ " \n " ,
210+ " for model_fn in os.listdir(CKPT_DIR):\n " ,
211+ " if model_fn == 'model.pth.tar':\n " ,
212+ " continue\n " ,
213+ " model_path = os.path.join(CKPT_DIR, model_fn)\n " ,
214+ " ckpt = torch.load(model_path, map_location=torch.device(device))\n " ,
215+ " models.append((model_path, ckpt['loss']))\n " ,
216+ " \n " ,
217+ " models = sorted(models, key=lambda loss: loss[1], reverse=False)\n " ,
218+ " model_path = models[0][0]\n " ,
219+ " print(model_path)"
220+ ],
221+ "execution_count" : 0 ,
222+ "outputs" : []
223+ },
224+ {
225+ "cell_type" : " code" ,
226+ "metadata" : {
227+ "id" : " pAs71HOhKiiH" ,
228+ "colab_type" : " code" ,
229+ "colab" : {}
230+ },
231+ "source" : [
232+ " !python validation.py $model_path part_A/test.json"
233+ ],
234+ "execution_count" : 0 ,
235+ "outputs" : []
236+ },
237+ {
238+ "cell_type" : " code" ,
239+ "metadata" : {
240+ "id" : " 5QXfW5lSr_V5" ,
241+ "colab_type" : " code" ,
242+ "colab" : {}
243+ },
244+ "source" : [
245+ " import json\n " ,
246+ " \n " ,
247+ " with open('data/metrics.json') as infile:\n " ,
248+ " data = json.load(infile)\n " ,
249+ " \n " ,
250+ " data"
251+ ],
252+ "execution_count" : 0 ,
253+ "outputs" : []
254+ },
255+ {
256+ "cell_type" : " code" ,
257+ "metadata" : {
258+ "id" : " etscb0CHAIxs" ,
259+ "colab_type" : " code" ,
260+ "colab" : {}
261+ },
262+ "source" : [
263+ " # Last epoch did not necessarily yield best model\n " ,
264+ " # Fix data according to that\n " ,
265+ " \n " ,
266+ " ckpt = torch.load(model_path, map_location=torch.device(device))\n " ,
267+ " last_epoch = ckpt['epoch']\n " ,
268+ " \n " ,
269+ " with open('data/loss_data.json') as infile:\n " ,
270+ " loss = json.load(infile)\n " ,
271+ " \n " ,
272+ " last_ckpt_loss_data = {'train_mae': loss['train_mae'][:last_epoch], 'val_mae': loss['val_mae'][:last_epoch]}\n " ,
273+ " \n " ,
274+ " with open('data/loss_data.json', 'w') as outfile:\n " ,
275+ " json.dump(last_ckpt_loss_data, outfile)"
276+ ],
277+ "execution_count" : 0 ,
278+ "outputs" : []
279+ },
280+ {
281+ "cell_type" : " code" ,
282+ "metadata" : {
283+ "id" : " fok2K6zE1KwN" ,
284+ "colab_type" : " code" ,
285+ "colab" : {}
286+ },
287+ "source" : [
288+ " # Download model and data\n " ,
289+ " from google.colab import files\n " ,
290+ " \n " ,
291+ " files.download(model_path)\n " ,
292+ " files.download('data/metrics.json')\n " ,
293+ " files.download('data/loss_data.json')"
294+ ],
295+ "execution_count" : 0 ,
296+ "outputs" : []
297+ }
298+ ]
299+ }
0 commit comments