Skip to content

Commit 9a7e06d

Browse files
committed
first commit
0 parents  commit 9a7e06d

17 files changed

+243207
-0
lines changed

README.md

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Code for 'Dynamic MLP for Fine-Grained Image Classification by Leveraging Geographical and Temporal Information'
2+
3+
<p align="center"> <img src="figs/structure.svg" width="100%"></p>
4+
Dynamic MLP is parameterized by the learned embeddings of variable locations and dates to help fine-grained image classification.
5+
6+
## Requirements
7+
8+
Experiment Environment
9+
- python 3.6
10+
- pytorch 1.7.1+cu101
11+
- torchvision 0.8.2
12+
13+
Get the pretrained models for SK-Res2Net by following the instructions [here](checkpoints/README.md).
14+
Get the datasets by following the instructions [here](datasets/README.md).
15+
16+
## Train the model
17+
### 1. Train image-only model
18+
Specify ```--image_only``` for training image-only models.
19+
- ResNet-50 (67.924% Top-1 acc)
20+
```bash
21+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 train.py \
22+
--name res50_image_only \
23+
--data 'inat21_mini' \
24+
--data_dir 'path/to/your/data' \
25+
--model_file 'resnet' \
26+
--model_name 'resnet50' \
27+
--pretrained \
28+
--batch_size 512 \
29+
--start_lr 0.04 \
30+
--image_only
31+
```
32+
33+
- SK-Res2Net-101 (76.102% Top-1 acc)
34+
```bash
35+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 train.py \
36+
--name sk2_image_only \
37+
--data 'inat21_mini' \
38+
--data_dir 'path/to/your/data' \
39+
--model_file 'sk2res2net' \
40+
--model_name 'sk2res2net101' \
41+
--pretrained \
42+
--batch_size 512 \
43+
--start_lr 0.04 \
44+
--image_only
45+
```
46+
47+
### 2. Train dynamic MLP model
48+
- ResNet-50 (78.751% Top-1 acc)
49+
```bash
50+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 train.py \
51+
--name res50_dynamic_mlp \
52+
--data 'inat21_mini' \
53+
--data_dir 'path/to/your/data' \
54+
--model_file 'resnet_dynamic_mlp' \
55+
--model_name 'resnet50' \
56+
--pretrained \
57+
--batch_size 512 \
58+
--start_lr 0.04
59+
```
60+
61+
- SK-Res2Net-101 (84.694% Top-1 acc)
62+
```bash
63+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 train.py \
64+
--name sk2_dynamic_mlp \
65+
--data 'inat21_mini' \
66+
--data_dir 'path/to/your/data' \
67+
--model_file 'sk2res2net_dynamic_mlp' \
68+
--model_name 'sk2res2net101' \
69+
--pretrained \
70+
--batch_size 512 \
71+
--start_lr 0.04
72+
```
73+
74+
## Test the model
75+
Specify ```--resume``` and ```--evaluate``` for inference and ```--image_only``` for testing image-only models.
76+
```bash
77+
python3 train.py \
78+
--name sk2_dynamic_mlp \
79+
--data 'inat21_mini' \
80+
--data_dir 'path/to/your/data' \
81+
--model_file 'sk2res2net_dynamic_mlp' \
82+
--model_name 'sk2res2net101' \
83+
--resume 'path/to/your/checkpoint' \
84+
--evaluate
85+
```
86+
87+
## Model Zoo
88+
### iNaturalist 2021 mini (90 epochs)
89+
90+
| Backbone | Size | Acc@1 | Log | Download |
91+
| -------------- | :---: | :--------: | :---------------------------------------------------------------------------: | :-------: |
92+
| ResNet-50 | 224 | 67.924 | [log](logs/log_inat21-mini_90epoch_r50_image-only_67.924_top1_acc.txt) | [model]() |
93+
| + Dynamic MLP | 224 | **78.751** | [log](logs/log_inat21-mini_90epoch_r50_dynamic-mlp-c_78.751_top1_acc.txt) | [model]() |
94+
| SK-Res2Net-101 | 224 | 76.102 | [log](logs/log_inat21-mini_90epoch_sk2-101_image-only_76.102_top1_acc.txt) | [model]() |
95+
| + Dynamic MLP | 224 | **84.694** | [log](logs/log_inat21-mini_90epoch_sk2-101_dynamic-mlp-c_84.694_top1_acc.txt) | [model]() |

checkpoints/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Get pretrained SK-Res2Net model
2+
The model is trained on ImageNet-1k for 300 epochs.
3+
4+
Click the link to download:
5+
[[Google Drive]](https://drive.google.com/file/d/1CJzcta4GoYqH5I5hcHyWoBl4iU1Y4oqc/view?usp=sharing)
6+
[[Github]]()
7+
8+
|-- checkpoints
9+
&emsp;&emsp;|-- sk2res2net101_epoch_300.pth

dataset.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
#!/usr/bin/env python3
2+
import datetime
3+
import json
4+
import math
5+
import os
6+
7+
import numpy as np
8+
import torch
9+
import torchvision.transforms as transforms
10+
from PIL import Image
11+
from torch.utils.data import Dataset
12+
13+
# Per-channel mean/std normalisation shared by the train and val pipelines.
# NOTE(review): these look like the usual ImageNet RGB statistics, which
# would match the ImageNet-pretrained backbones -- confirm before changing.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
14+
15+
16+
class INatDataset(Dataset):
    """iNaturalist-style dataset yielding ``(image, label, loc)`` triples.

    The annotation JSON found under ``root`` is read once at construction
    time. Each sample keeps the image path, the integer category id and the
    raw ``date`` / ``latitude`` / ``longitude`` fields of the image record.

    Args:
        data: Dataset name. When ``train`` is true and the name contains
            ``'mini'``, ``train_mini.json`` is used instead of ``train.json``.
        root: Directory holding the annotation JSON files; image
            ``file_name`` entries are joined onto it.
        train: Selects the training split when true, ``val.json`` otherwise.
        transform: Optional callable applied to the loaded PIL image.
        args: Namespace read in ``__getitem__`` for ``metadata`` and
            ``mlp_cin``. It is not needed to construct the dataset.
    """

    def __init__(self, data, root, train, transform=None, args=None):
        self.transform = transform
        self.args = args

        if train:
            json_name = 'train_mini.json' if 'mini' in data else 'train.json'
        else:
            json_name = 'val.json'
        jpath = os.path.join(root, json_name)

        with open(jpath, 'r') as f:
            # json.load parses from a file object; json.loads only accepts
            # str/bytes and raises TypeError when handed the file itself.
            annotations = json.load(f)

        # NOTE(review): pairing by position assumes annotations['images'] and
        # annotations['annotations'] are listed in the same order -- confirm.
        samples = []
        for img, ann in zip(annotations['images'], annotations['annotations']):
            img_path = os.path.join(root, img['file_name'])
            label = ann['category_id']
            extra = {'date': img['date'], 'latitude': img['latitude'], 'longitude': img['longitude']}
            samples.append((img_path, int(label), extra))

        self.samples = samples

    def __len__(self):
        """Return the number of samples read from the annotation file."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(img, label, loc)``. ``loc`` is the sin/cos encoding of the
            metadata selected by ``args.metadata`` (``'geo'`` adds scaled
            latitude and longitude, ``'temporal'`` adds the scaled date).
            When any of date, latitude or longitude is missing, ``loc`` is
            a zero vector of length ``args.mlp_cin``.
        """
        img_path, label, extra = self.samples[idx]
        date = extra['date']  # capture time
        lat = extra['latitude']  # latitude, -90 ~ 90
        lng = extra['longitude']  # longitude, -180 ~ 180
        if (lat is not None) and (lng is not None) and (date is not None):
            # Only the leading 'YYYY-MM-DD' part of the timestamp is used.
            date_time = datetime.datetime.strptime(date[:10], '%Y-%m-%d')
            date = get_scaled_date_ratio(date_time)
            # Scale both coordinates into [-1, 1] before the sin/cos encoding.
            lat = float(lat) / 90
            lng = float(lng) / 180
            loc = []
            if 'geo' in self.args.metadata:
                loc += [lat, lng]
            if 'temporal' in self.args.metadata:
                loc += [date]
            loc = np.array(loc)
            loc = encode_loc_time(loc)
        else:
            # presumably mlp_cin equals the encoded feature length so that
            # both branches produce the same shape -- verify against caller
            loc = np.zeros(self.args.mlp_cin, float)
        # Force three channels: grayscale / CMYK / palette files would
        # otherwise not match the 3-channel normalisation downstream.
        img = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label, loc
66+
67+
68+
def encode_loc_time(loc_time):
    """Encode scaled location/date features as sines followed by cosines.

    Every input feature is expected to lie in [-1, 1]; it is multiplied by
    pi and passed through sin and cos. The function treats all entries
    alike, so the order of the features is whatever the caller supplies.

    Args:
        loc_time: 1-D numpy array of scaled features.

    Returns:
        Array of twice the input length: all sines, then all cosines.
    """
    angles = math.pi * loc_time
    sines = np.sin(angles)
    cosines = np.cos(angles)
    return np.concatenate((sines, cosines))
73+
74+
75+
def _is_leap_year(year):
76+
if year % 4 != 0 or (year % 100 == 0 and year % 400 != 0):
77+
return False
78+
return True
79+
80+
81+
def get_scaled_date_ratio(date_time):
    r'''
    Map a calendar date onto the interval (-1, 1] by its day of the year.

    With ``d`` the 1-based day of the year and ``N`` the number of days in
    that year (365, or 366 in leap years), the result is ``2 * d / N - 1``,
    so the last day of the year maps to exactly 1.

    ``date_time`` only needs ``year``, ``month`` and ``day`` attributes.
    '''
    leap = _is_leap_year(date_time.year)
    month_lengths = [31, 29 if leap else 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
    days_in_year = 366 if leap else 365

    month_index = date_time.month - 1
    day_of_month = date_time.day
    assert day_of_month <= month_lengths[month_index]

    # Days in all fully elapsed months plus the days into the current one.
    day_of_year = sum(month_lengths[:month_index]) + day_of_month
    assert 0 < day_of_year <= days_in_year

    return (day_of_year / days_in_year) * 2 - 1
99+
100+
101+
def load_train_dataset(args):
    """Build the shuffled training ``DataLoader`` for the dataset in ``args``.

    Side effect: sets ``args.num_classes`` according to ``args.data``.

    Args:
        args: Namespace providing ``data``, ``data_dir``, ``batch_size`` and
            ``num_workers``. It is also forwarded to ``INatDataset``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the training split.

    Raises:
        NotImplementedError: If ``args.data`` is not a supported dataset.
    """
    if args.data == 'inat17':
        args.num_classes = 5089
    elif args.data == 'inat18':
        args.num_classes = 8142
    # Membership test: the former `== 'inat21_mini' or 'inat21_full'` was
    # always truthy, so unknown names silently received 10000 classes.
    elif args.data in ('inat21_mini', 'inat21_full'):
        args.num_classes = 10000
    else:
        raise NotImplementedError(f'unsupported dataset: {args.data!r}')

    dataset = INatDataset(
        args.data,
        root=args.data_dir,
        train=True,
        transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        args=args,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    return train_loader
131+
132+
133+
def load_val_dataset(args):
    """Build the validation ``DataLoader`` for the dataset in ``args``.

    Side effect: sets ``args.num_classes`` according to ``args.data``.

    When ``args.tencrop`` is true, every image is resized to 256 and expanded
    by ``TenCrop(224)`` into a stacked tensor of crops; otherwise a single
    224 centre crop is taken.

    Args:
        args: Namespace providing ``data``, ``data_dir``, ``tencrop``,
            ``batch_size`` and ``num_workers``. It is also forwarded to
            ``INatDataset``.

    Returns:
        A non-shuffled ``torch.utils.data.DataLoader`` over the val split.

    Raises:
        NotImplementedError: If ``args.data`` is not a supported dataset.
    """
    if args.data == 'inat17':
        args.num_classes = 5089
    elif args.data == 'inat18':
        args.num_classes = 8142
    # Membership test: the former `== 'inat21_mini' or 'inat21_full'` was
    # always truthy, so unknown names silently received 10000 classes.
    elif args.data in ('inat21_mini', 'inat21_full'):
        args.num_classes = 10000
    else:
        raise NotImplementedError(f'unsupported dataset: {args.data!r}')

    if args.tencrop:
        to_tensor = transforms.ToTensor()
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            # TenCrop yields a tuple of PIL crops; convert and normalise each
            # one, stacking them along a new leading dimension.
            transforms.Lambda(lambda crops: torch.stack([to_tensor(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])),
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

    # Dataset and loader settings are the same for both evaluation modes;
    # only the transform differs.
    dataset = INatDataset(
        args.data,
        root=args.data_dir,
        train=False,
        transform=transform,
        args=args,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    return val_loader

datasets/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Get iNaturalist datasets
2+
3+
Download the iNaturalist datasets at:
4+
https://github.com/visipedia/inat_comp
5+
6+
|-- datasets/
7+
&emsp;&emsp;|-- inat21/
8+
&emsp;&emsp;&emsp;&emsp;|-- train_mini.json
9+
&emsp;&emsp;&emsp;&emsp;|-- train.json
10+
&emsp;&emsp;&emsp;&emsp;|-- val.json
11+
&emsp;&emsp;&emsp;&emsp;|-- train_mini/
12+
&emsp;&emsp;&emsp;&emsp;|-- train/
13+
&emsp;&emsp;&emsp;&emsp;|-- val/

figs/structure.svg

Lines changed: 1 addition & 0 deletions
Loading

0 commit comments

Comments
 (0)