Commit 230df3a

Merge pull request #93 from tsingcbx99/dev-ssl
Dev ssl
2 parents 5fefb10 + 06517c0 commit 230df3a

22 files changed: +278 -254 lines changed

docs/index.rst

Lines changed: 4 additions & 8 deletions
@@ -35,13 +35,6 @@ Transfer Learning
 
    talib/benchmarks/image_classification
 
-.. toctree::
-   :maxdepth: 2
-   :caption: Semi Supervised Learning Settings
-   :titlesonly:
-
-   ssllib/benchmarks/image_classification
-
 
 .. toctree::
    :maxdepth: 2
@@ -80,7 +73,10 @@ Transfer Learning
    :caption: Semi Supervised Learning Methods
    :titlesonly:
 
-   ssllib/semi_supervised_learning.rst
+   ssllib/consistency_regularization.rst
+   ssllib/contrastive_learning.rst
+   ssllib/holistic_methods.rst
+   ssllib/proxy_label.rst
 
 
 

docs/ssllib/benchmarks/image_classification.rst

Lines changed: 0 additions & 70 deletions
This file was deleted.

docs/ssllib/consistency_regularization.rst

Lines changed: 0 additions & 40 deletions

@@ -1,7 +1,4 @@
 =======================================
-Semi Supervised Learning
-=======================================
-
 Consistency Regularization
 =======================================
 
@@ -43,40 +40,3 @@ Unsupervised Data Augmentation (UDA)
 .. autoclass:: ssllib.uda.SupervisedUDALoss
 
 .. autoclass:: ssllib.uda.UnsupervisedUDALoss
-
-
-Pseudo Labels
-=======================================
-
-.. _PSEUDO:
-
-Pseudo Label
-------------------
-
-Given model predictions :math:`y` on unlabeled samples, we can directly utilize them to generate
-pseudo labels :math:`label=\mathop{\arg\max}\limits_{i}~y[i]`. Then we use these pseudo labels as supervision to train
-our model. Details can be found at `projects/self_tuning/pseudo_label.py`.
-
-
-Holistic Methods
-=======================================
-
-.. _FIXMATCH:
-
-FixMatch
-------------------
-
-.. autoclass:: ssllib.fix_match.FixMatchConsistencyLoss
-
-
-Contrastive Learning
-=======================================
-
-.. _SELF_TUNING:
-
-Self-Tuning
-------------------
-
-.. autoclass:: ssllib.self_tuning.Classifier
-
-.. autoclass:: ssllib.self_tuning.SelfTuning

docs/ssllib/contrastive_learning.rst

Lines changed: 12 additions & 0 deletions

@@ -0,0 +1,12 @@
+=======================================
+Contrastive Learning
+=======================================
+
+.. _SELF_TUNING:
+
+Self-Tuning
+------------------
+
+.. autoclass:: ssllib.self_tuning.Classifier
+
+.. autoclass:: ssllib.self_tuning.SelfTuning

docs/ssllib/holistic_methods.rst

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+=======================================
+Holistic Methods
+=======================================
+
+.. _FIXMATCH:
+
+FixMatch
+------------------
+
+.. autoclass:: ssllib.fix_match.FixMatchConsistencyLoss

docs/ssllib/proxy_label.rst

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+=======================================
+Proxy-Label Based Methods
+=======================================
+
+.. _PSEUDO:
+
+Pseudo Label
+------------------
+
+Given model predictions :math:`y` on unlabeled samples, we can directly utilize them to generate
+pseudo labels :math:`label=\mathop{\arg\max}\limits_{i}~y[i]`. Then we use these pseudo labels as supervision to train
+our model. Details can be found at `projects/self_tuning/pseudo_label.py`.
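
The pseudo-labeling rule described above maps directly to a few lines of PyTorch. The following is only a minimal illustrative sketch of that idea, not the code in `projects/self_tuning/pseudo_label.py`; the confidence threshold and the function name are assumptions introduced here.

```python
# Illustrative sketch of pseudo-labeling as described above; the threshold and
# helper name are assumptions, not the repository's pseudo_label.py.
import torch
import torch.nn.functional as F

def pseudo_label_loss(logits_unlabeled: torch.Tensor, threshold: float = 0.95) -> torch.Tensor:
    """Cross-entropy against arg-max pseudo labels on confident unlabeled samples."""
    probs = F.softmax(logits_unlabeled.detach(), dim=1)   # model predictions y
    confidence, pseudo_labels = probs.max(dim=1)          # label = argmax_i y[i]
    mask = (confidence >= threshold).float()              # keep only confident samples
    loss = F.cross_entropy(logits_unlabeled, pseudo_labels, reduction='none')
    return (loss * mask).mean()

# Usage: combine with the supervised loss on labeled data, e.g.
# total_loss = F.cross_entropy(logits_labeled, labels) + pseudo_label_loss(logits_unlabeled)
```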

projects/README.md

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
-Here are a few projects that are built on Trans-Learn.
-They are examples of how to use Trans-Learn as a library, to facilitate your own research.
+Here are a few projects that are built on Trans-Learn. They are examples of how to use Trans-Learn as a library, to
+facilitate your own research.
 
 ## Projects by [THUML](https://github.com/thuml)
 

projects/self_tuning/README.md

Lines changed: 13 additions & 4 deletions
@@ -33,8 +33,9 @@ Supported methods include:
 ## Experiments and Results
 
 ### SSL with supervised pre-trained model
-The shell files give the script to reproduce our [results](/docs/ssllib/benchmarks/image_classification.rst#) with specified hyper-parameters.
-For example, if you want to run baseline on CUB200 with 15% labeled samples, use the following script
+
+The shell files give the script to reproduce our [results](benchmark.md) with specified hyper-parameters. For example,
+if you want to run baseline on CUB200 with 15% labeled samples, use the following script
 
 ```shell script
 # SSL with ResNet50 backbone on CUB200.
@@ -44,24 +45,32 @@ CUDA_VISIBLE_DEVICES=0 python baseline.py data/cub200 -d CUB200 -sr 15 --seed 0
 ```
 
 ### SSL with unsupervised pre-trained model
-Take MoCo as an example.
+
+Take MoCo as an example.
+
 1. Download MoCo pretrained checkpoints from https://github.com/facebookresearch/moco
-2. Convert the format of the MoCo checkpoints to the standard format of pytorch
+2. Convert the format of the MoCo checkpoints to the standard format of pytorch
+
 ```shell
 mkdir checkpoints
 python convert_moco_to_pretrained.py checkpoints/moco_v1_200ep_pretrain.pth.tar checkpoints/moco_v1_200ep_backbone.pth checkpoints/moco_v1_200ep_fc.pth
 ```
+
 3. Start training
+
 ```shell
 CUDA_VISIBLE_DEVICES=0 python baseline.py data/cub200 -d CUB200 -sr 15 --seed 0 --log logs/baseline_moco/cub200_15 \
 --pretrained checkpoints/moco_v1_200ep_backbone.pth
 ```
 
 ## TODO
+
 Support datasets: CIFAR10, CIFAR100, ImageNet
 
 ## Citation
+
 If you use these methods in your research, please consider citing.
+
 ```
 @inproceedings{pi-model,
 title={Temporal ensembling for semi-supervised learning},
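
As background for step 2 of the README above: MoCo checkpoints keep the query encoder's weights under a `module.encoder_q.` prefix, so converting to a standard torchvision-style state dict essentially means stripping that prefix and dropping the MoCo-specific projection head. The sketch below is only an illustration of that idea under those assumptions; the repository's actual logic lives in `convert_moco_to_pretrained.py` and may differ.

```python
# Hedged sketch of the MoCo -> torchvision key conversion idea. The checkpoint
# layout and function name are assumptions; see convert_moco_to_pretrained.py
# for the repository's real implementation.
import torch

def convert_moco_backbone(moco_path: str, backbone_path: str) -> None:
    checkpoint = torch.load(moco_path, map_location='cpu')
    state_dict = checkpoint['state_dict']
    backbone_state = {}
    for key, value in state_dict.items():
        # keep only the query encoder, strip the 'module.encoder_q.' prefix,
        # and drop MoCo's projection head (fc.*)
        if key.startswith('module.encoder_q.') and not key.startswith('module.encoder_q.fc'):
            backbone_state[key[len('module.encoder_q.'):]] = value
    torch.save(backbone_state, backbone_path)
```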

projects/self_tuning/baseline.py

Lines changed: 11 additions & 1 deletion
@@ -83,8 +83,9 @@ def main(args: argparse.Namespace):
     pool_layer = nn.Identity() if args.no_pool else None
     classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=not args.scratch).to(device)
 
-    # define optimizer
+    # define optimizer and lr scheduler
     optimizer = SGD(classifier.get_parameters(args.lr), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
+    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.milestones, gamma=args.lr_gamma)
 
     # resume from the best checkpoint
     if args.phase == 'test':
@@ -97,8 +98,15 @@ def main(args: argparse.Namespace):
     # start training
     best_acc1 = 0.0
     for epoch in range(args.epochs):
+        # print lr
+        print(lr_scheduler.get_lr())
+
         # train for one epoch
         train(labeled_train_iter, classifier, optimizer, epoch, args)
+
+        # update lr
+        lr_scheduler.step()
+
         # evaluate on validation set
         with torch.no_grad():
             acc1 = utils.validate(val_loader, classifier, args, device)
@@ -188,6 +196,8 @@ def train(labeled_train_iter: ForeverDataIterator, model, optimizer: SGD, epoch:
                         help='mini-batch size (default: 48)')
     parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                         metavar='LR', help='initial learning rate', dest='lr')
+    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')
+    parser.add_argument('--milestones', type=int, default=[5], nargs='+', help='epochs to decay lr')
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default:1e-4)')
     parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
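
For context, a minimal standalone sketch (not part of the commit) of how the new `--lr`, `--milestones`, and `--lr-gamma` flags drive `MultiStepLR`: with `--lr 0.1 --milestones 3 6 9 --lr-gamma 0.1`, as used in the MoCo scripts below, the learning rate is multiplied by 0.1 at epochs 3, 6, and 9. The tiny model here is only a stand-in for the real classifier.

```python
# Minimal sketch (not from the commit): how --lr, --milestones and --lr-gamma
# map onto torch.optim.lr_scheduler.MultiStepLR in the training loop above.
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

model = nn.Linear(10, 2)                      # stand-in for the real classifier
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = MultiStepLR(optimizer, milestones=[3, 6, 9], gamma=0.1)

for epoch in range(12):                       # --epochs 12 in the MoCo scripts
    # ... one epoch of training would run here ...
    print(epoch, optimizer.param_groups[0]['lr'])
    scheduler.step()                          # lr: 0.1 -> 0.01 -> 0.001 -> 0.0001
```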

projects/self_tuning/baseline.sh

Lines changed: 18 additions & 18 deletions
@@ -17,25 +17,25 @@ CUDA_VISIBLE_DEVICES=0 python baseline.py data/aircraft -d Aircraft -sr 50 --see
 
 # MoCo (Unsupervised Pretraining)
 # ResNet50, CUB200
-CUDA_VISIBLE_DEVICES=0 python baseline.py data/cub200 -d CUB200 -i 2000 -sr 15 --seed 0 --log logs/baseline_moco/cub200_15 \
-  --pretrained checkpoints/moco_v1_200ep_backbone.pth
-CUDA_VISIBLE_DEVICES=0 python baseline.py data/cub200 -d CUB200 -i 2000 -sr 30 --seed 0 --log logs/baseline_moco/cub200_30 \
-  --pretrained checkpoints/moco_v1_200ep_backbone.pth
-CUDA_VISIBLE_DEVICES=0 python baseline.py data/cub200 -d CUB200 -i 2000 -sr 50 --seed 0 --log logs/baseline_moco/cub200_50 \
-  --pretrained checkpoints/moco_v1_200ep_backbone.pth
+CUDA_VISIBLE_DEVICES=0 python baseline.py data/cub200 -d CUB200 --lr 0.1 --epochs 12 --milestones 3 6 9 \
+  -i 2000 -sr 15 --seed 0 --log logs/baseline_moco/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
+CUDA_VISIBLE_DEVICES=0 python baseline.py data/cub200 -d CUB200 --lr 0.1 --epochs 12 --milestones 3 6 9 \
+  -i 2000 -sr 30 --seed 0 --log logs/baseline_moco/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
+CUDA_VISIBLE_DEVICES=0 python baseline.py data/cub200 -d CUB200 --lr 0.1 --epochs 12 --milestones 3 6 9 \
+  -i 2000 -sr 50 --seed 0 --log logs/baseline_moco/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
 
 # ResNet50, StanfordCars
-CUDA_VISIBLE_DEVICES=0 python baseline.py data/stanford_cars -d StanfordCars -i 2000 -sr 15 --seed 0 --log logs/baseline_moco/car_15 \
-  --pretrained checkpoints/moco_v1_200ep_backbone.pth
-CUDA_VISIBLE_DEVICES=0 python baseline.py data/stanford_cars -d StanfordCars -i 2000 -sr 30 --seed 0 --log logs/baseline_moco/car_30 \
-  --pretrained checkpoints/moco_v1_200ep_backbone.pth
-CUDA_VISIBLE_DEVICES=0 python baseline.py data/stanford_cars -d StanfordCars -i 2000 -sr 50 --seed 0 --log logs/baseline_moco/car_50 \
-  --pretrained checkpoints/moco_v1_200ep_backbone.pth
+CUDA_VISIBLE_DEVICES=0 python baseline.py data/stanford_cars -d StanfordCars --lr 0.1 --epochs 12 --milestones 3 6 9 \
+  -i 2000 -sr 15 --seed 0 --log logs/baseline_moco/car_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
+CUDA_VISIBLE_DEVICES=0 python baseline.py data/stanford_cars -d StanfordCars --lr 0.1 --epochs 12 --milestones 3 6 9 \
+  -i 2000 -sr 30 --seed 0 --log logs/baseline_moco/car_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
+CUDA_VISIBLE_DEVICES=0 python baseline.py data/stanford_cars -d StanfordCars --lr 0.1 --epochs 12 --milestones 3 6 9 \
+  -i 2000 -sr 50 --seed 0 --log logs/baseline_moco/car_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
 
 # ResNet50, Aircraft
-CUDA_VISIBLE_DEVICES=0 python baseline.py data/aircraft -d Aircraft -i 2000 -sr 15 --seed 0 --log logs/baseline_moco/aircraft_15 \
-  --pretrained checkpoints/moco_v1_200ep_backbone.pth
-CUDA_VISIBLE_DEVICES=0 python baseline.py data/aircraft -d Aircraft -i 2000 -sr 30 --seed 0 --log logs/baseline_moco/aircraft_30 \
-  --pretrained checkpoints/moco_v1_200ep_backbone.pth
-CUDA_VISIBLE_DEVICES=0 python baseline.py data/aircraft -d Aircraft -i 2000 -sr 50 --seed 0 --log logs/baseline_moco/aircraft_50 \
-  --pretrained checkpoints/moco_v1_200ep_backbone.pth
+CUDA_VISIBLE_DEVICES=0 python baseline.py data/aircraft -d Aircraft --lr 0.1 --epochs 12 --milestones 3 6 9 \
+  -i 2000 -sr 15 --seed 0 --log logs/baseline_moco/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
+CUDA_VISIBLE_DEVICES=0 python baseline.py data/aircraft -d Aircraft --lr 0.1 --epochs 12 --milestones 3 6 9 \
+  -i 2000 -sr 30 --seed 0 --log logs/baseline_moco/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
+CUDA_VISIBLE_DEVICES=0 python baseline.py data/aircraft -d Aircraft --lr 0.1 --epochs 12 --milestones 3 6 9 \
+  -i 2000 -sr 50 --seed 0 --log logs/baseline_moco/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
