Skip to content

Commit 806ae43

Browse files
authored
Merge pull request #404 from sergregory/t-sne-visualization
t-SNE visualization
2 parents 1fca803 + 279fe27 commit 806ae43

File tree

6 files changed

+421
-0
lines changed

6 files changed

+421
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Want to become an expert in AI? [AI Courses by OpenCV](https://opencv.org/course
1313

1414
| Blog Post | |
1515
| ------------- |:-------------|
16+
|[t-SNE for ResNet feature visualization](https://www.learnopencv.com/t-sne-for-resnet-feature-visualization/)|[Code](https://github.com/spmallick/learnopencv/tree/master/TSNE)|
1617
|[Multi-Label Image Classification with Pytorch](https://www.learnopencv.com/multi-label-image-classification-with-pytorch/)|[Code](https://github.com/spmallick/learnopencv/tree/master/PyTorch-Multi-Label-Image-Classification)|
1718
|[CNN Receptive Field Computation Using Backprop](https://www.learnopencv.com/cnn-receptive-field-computation-using-backprop/)|[Code](https://github.com/spmallick/learnopencv/tree/master/PyTorch-Receptive-Field-With-Backprop)|
1819
|[Augmented Reality using AruCo Markers in OpenCV(C++ and Python)](https://www.learnopencv.com/augmented-reality-using-aruco-markers-in-opencv-(c++-python)/) |[Code](https://github.com/spmallick/learnopencv/tree/master/AugmentedRealityWithArucoMarkers)|

TSNE/README.md

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Setup
2+
3+
This code was tested with python 3.7, however, it should work with any python 3.
4+
5+
1. Create and activate virtual environment for experiments with t-SNE.
6+
7+
```bash
8+
python3 -m venv venv
9+
source venv/bin/activate
10+
```
11+
12+
2. install the dependencies
13+
14+
```bash
15+
python3 -m pip install -r requirements.txt
16+
```
17+
18+
# Data downloading
19+
20+
Download data from Kaggle and unzip it.
21+
The easiest way is to use kaggle console API. To setup it, follow [this guide](https://www.kaggle.com/general/74235).
22+
However, you can download the data using your browser - results will be the same.
23+
24+
After that, execute the following commands:
25+
26+
```bash
27+
28+
kaggle datasets download alessiocorrado99/animals10
29+
30+
mkdir -p data
31+
32+
cd data
33+
34+
unzip ../animals10.zip
35+
36+
cd ..
37+
38+
```
39+
40+
# Executing the T-SNE visualization
41+
42+
```bash
43+
44+
python3 tsne.py
45+
46+
```
47+
48+
Additional options:
49+
50+
```bash
51+
python3 tsne.py -h
52+
53+
usage: tsne.py [-h] [--path PATH] [--batch BATCH] [--num_images NUM_IMAGES]
54+
55+
optional arguments:
56+
-h, --help show this help message and exit
57+
--path PATH
58+
--batch BATCH
59+
--num_images NUM_IMAGES
60+
61+
```
62+
63+
You can change the data directory with `--path` argument.
64+
65+
Tweak the `--num_images` to speed-up the process - by default it is 500, you can make it smaller.
66+
67+
Tweak the `--batch` to better utilize your PC's resources. The script uses GPU automatically if it available. You may
68+
want to increase the batch size to utilize the GPU better or decrease it if the default batch size does not fit your
69+
GPU.

TSNE/animals_dataset.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from os import path, listdir
2+
import torch
3+
from torchvision import transforms
4+
import random
5+
6+
from PIL import Image, ImageFile
7+
ImageFile.LOAD_TRUNCATED_IMAGES = True
8+
9+
10+
colors_per_class = {
11+
'dog' : [254, 202, 87],
12+
'horse' : [255, 107, 107],
13+
'elephant' : [10, 189, 227],
14+
'butterfly' : [255, 159, 243],
15+
'chicken' : [16, 172, 132],
16+
'cat' : [128, 80, 128],
17+
'cow' : [87, 101, 116],
18+
'sheep' : [52, 31, 151],
19+
'spider' : [0, 0, 0],
20+
'squirrel' : [100, 100, 255],
21+
}
22+
23+
24+
# processes Animals10 dataset: https://www.kaggle.com/alessiocorrado99/animals10
25+
class AnimalsDataset(torch.utils.data.Dataset):
26+
def __init__(self, data_path, num_images=1000):
27+
translation = {'cane' : 'dog',
28+
'cavallo' : 'horse',
29+
'elefante' : 'elephant',
30+
'farfalla' : 'butterfly',
31+
'gallina' : 'chicken',
32+
'gatto' : 'cat',
33+
'mucca' : 'cow',
34+
'pecora' : 'sheep',
35+
'ragno' : 'spider',
36+
'scoiattolo' : 'squirrel'}
37+
38+
self.classes = translation.values()
39+
40+
if not path.exists(data_path):
41+
raise Exception(data_path + ' does not exist!')
42+
43+
self.data = []
44+
45+
folders = listdir(data_path)
46+
for folder in folders:
47+
label = translation[folder]
48+
49+
full_path = path.join(data_path, folder)
50+
images = listdir(full_path)
51+
52+
current_data = [(path.join(full_path, image), label) for image in images]
53+
self.data += current_data
54+
55+
num_images = min(num_images, len(self.data))
56+
self.data = random.sample(self.data, num_images) # only use num_images images
57+
58+
# We use the transforms described in official PyTorch ResNet inference example:
59+
# https://pytorch.org/hub/pytorch_vision_resnet/.
60+
self.transform = transforms.Compose([
61+
transforms.Resize(256),
62+
transforms.CenterCrop(224),
63+
transforms.ToTensor(),
64+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
65+
])
66+
67+
68+
def __len__(self):
69+
return len(self.data)
70+
71+
72+
def __getitem__(self, index):
73+
image_path, label = self.data[index]
74+
75+
image = Image.open(image_path)
76+
77+
try:
78+
image = self.transform(image) # some images in the dataset cannot be processed - we'll skip them
79+
except Exception:
80+
return None
81+
82+
dict_data = {
83+
'image' : image,
84+
'label' : label,
85+
'image_path' : image_path
86+
}
87+
return dict_data
88+
89+
90+
# Skips empty samples in a batch
91+
def collate_skip_empty(batch):
92+
batch = [sample for sample in batch if sample] # check that sample is not None
93+
return torch.utils.data.dataloader.default_collate(batch)

TSNE/requirements.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
opencv-python>=3.4.1.15
2+
numpy>=1.18.1
3+
matplotlib>=3.2.0
4+
tqdm==4.45.0
5+
Pillow==7.0.0
6+
scikit-learn==0.22.2.post1
7+
scipy==1.4.1
8+
torch==1.4.0
9+
torchvision==0.5.0

TSNE/resnet.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import torch
2+
from torchvision import models
3+
from torch.hub import load_state_dict_from_url
4+
5+
6+
# Define the architecture by modifying resnet.
7+
# Original code is here
8+
# https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py
9+
class ResNet101(models.ResNet):
10+
def __init__(self, num_classes=1000, pretrained=True, **kwargs):
11+
# Start with standard resnet101 defined here
12+
# https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py
13+
super().__init__(block=models.resnet.Bottleneck, layers=[3, 4, 23, 3], num_classes=num_classes, **kwargs)
14+
if pretrained:
15+
state_dict = load_state_dict_from_url(models.resnet.model_urls['resnet101'], progress=True)
16+
self.load_state_dict(state_dict)
17+
18+
# Reimplementing forward pass.
19+
# Replacing the following code
20+
# https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py#L197-L213
21+
def _forward_impl(self, x):
22+
# Standard forward for resnet
23+
x = self.conv1(x)
24+
x = self.bn1(x)
25+
x = self.relu(x)
26+
x = self.maxpool(x)
27+
28+
x = self.layer1(x)
29+
x = self.layer2(x)
30+
x = self.layer3(x)
31+
x = self.layer4(x)
32+
33+
# Notice there is no forward pass through the original classifier.
34+
x = self.avgpool(x)
35+
x = torch.flatten(x, 1)
36+
37+
return x

0 commit comments

Comments
 (0)