Commit 6d30efe

feat(data): support custom dataset cache (Megvii-BaseDetection#1584)
feat(data): support custom dataset cache
1 parent f15f193 commit 6d30efe

11 files changed: +408 -302 lines changed

README.md

Lines changed: 7 additions & 4 deletions
@@ -122,9 +122,9 @@ python -m yolox.tools.train -n yolox-s -d 8 -b 64 --fp16 -o [--cache]
 * -d: number of gpu devices
 * -b: total batch size, the recommended number for -b is num-gpu * 8
 * --fp16: mixed precision training
-* --cache: caching imgs into RAM to accelarate training, which need large system RAM.
+* --cache: caching imgs into RAM to accelarate training, which need large system RAM.
+
 
-
 
 When using -f, the above commands are equivalent to:
 ```shell
@@ -140,7 +140,8 @@ We also support multi-nodes training. Just add the following args:
 * --num\_machines: num of your total training nodes
 * --machine\_rank: specify the rank of each node
 
-Suppose you want to train YOLOX on 2 machines, and your master machines's IP is 123.123.123.123, use port 12312 and TCP.
+Suppose you want to train YOLOX on 2 machines, and your master machines's IP is 123.123.123.123, use port 12312 and TCP.
+
 On master machine, run
 ```shell
 python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num_machines 2 --machine_rank 0
@@ -163,7 +164,8 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb w
 
 An example wandb dashboard is available [here](https://wandb.ai/manan-goel/yolox-nano/runs/3pzfeom0)
 
-**Others**
+**Others**
+
 See more information with the following command:
 ```shell
 python -m yolox.tools.train --help
@@ -202,6 +204,7 @@ python -m yolox.tools.eval -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --f
 <summary>Tutorials</summary>
 
 * [Training on custom data](docs/train_custom_data.md)
+* [Caching for custom data](docs/cache.md)
 * [Manipulating training image size](docs/manipulate_training_image_size.md)
 * [Freezing model](docs/freeze_module.md)

docs/cache.md

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# Cache Custom Data

The caching feature is intended for users with ample memory. Caching to disk is also supported, but disk performance varies and may not deliver the same speedup. Caching a custom dataset in RAM is also simpler to set up than caching it to disk. With a few small modifications, training can run at nearly double the speed of a non-cached dataset.
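
Whether a dataset fits in RAM can be judged with a rough estimate. The sketch below assumes each cached image is held as an uncompressed `uint8` array of shape `(height, width, 3)`; that storage layout and the helper name `estimate_cache_bytes` are assumptions made for illustration and are not taken from the YOLOX source.

```python
def estimate_cache_bytes(num_imgs: int, height: int, width: int, channels: int = 3) -> int:
    """Rough size of a RAM cache, assuming one uint8 byte per channel value."""
    return num_imgs * height * width * channels


if __name__ == "__main__":
    total = estimate_cache_bytes(num_imgs=1000, height=640, width=640)
    print(f"{total / 2**30:.2f} GiB")  # 1.14 GiB for 1000 images at 640x640
```

Scaling the image count up to the size of your own dataset gives a first idea of whether RAM caching is feasible or whether disk caching is the safer choice.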

This page explains how to cache your own custom data with YOLOX.

## 0. Before you start

**Step1** Clone this repo and follow the [README](../README.md) to install YOLOX.

**Step2** Read the [Training on custom data](./train_custom_data.md) tutorial to understand how to prepare your custom data.

## 1. Inherit from `CacheDataset`

**Step1** Create a custom dataset that inherits from the `CacheDataset` class. Whether you inherit from `Dataset` or `CacheDataset`, the `__init__()` method of your custom dataset should take the following keyword arguments: `input_dimension`, `cache`, and `cache_type`. Also call `super().__init__()` and pass in `input_dimension`, `num_imgs`, `cache`, and `cache_type`, where `num_imgs` is the size of the dataset.

**Step2** Implement the abstract function `read_img(self, index, use_cache=True)` of the parent class and decorate it with `@cache_read_img`. This function takes an `index` as input and returns an `image`, and the returned image is what gets cached. It is recommended to put all repetitive, fixed post-processing operations on the image in this function to reduce per-image post-processing time during training.

```python
# CustomDataset.py
# CacheDataset is exported by the package __init__ changed in this commit;
# cache_read_img is imported from the datasets_wrapper module directly.
from yolox.data.datasets.datasets_wrapper import CacheDataset, cache_read_img

class CustomDataset(CacheDataset):
    def __init__(self, input_dimension, cache, cache_type, *args, **kwargs):
        # Get the required keyword arguments of super().__init__()
        # num_imgs is a placeholder: compute the dataset size in your own
        # loading code before this call.
        super().__init__(
            input_dimension=input_dimension,
            num_imgs=num_imgs,
            cache=cache,
            cache_type=cache_type
        )
        # ...

    @cache_read_img
    def read_img(self, index, use_cache=True):
        # get image ...
        # (optional) repetitive and fixed post-processing operations for image
        return image
```
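
To make the role of the decorator concrete, here is a self-contained toy that mimics the pattern without depending on YOLOX. The names `toy_cache_read_img` and `ToyCacheDataset` are hypothetical stand-ins; they do not reproduce YOLOX's `cache_read_img` or `CacheDataset`, and the dict-based cache is only an assumption about how such a decorator could work.

```python
import functools


def toy_cache_read_img(read_img_fn):
    """Hypothetical stand-in for a caching decorator (not the YOLOX implementation)."""
    @functools.wraps(read_img_fn)
    def wrapper(self, index, use_cache=True):
        if self.cache and use_cache:
            if index not in self._ram_cache:
                # First access: run the real loader once and keep the result.
                self._ram_cache[index] = read_img_fn(self, index, use_cache)
            return self._ram_cache[index]
        # Caching disabled: always fall through to the real loader.
        return read_img_fn(self, index, use_cache)
    return wrapper


class ToyCacheDataset:
    def __init__(self, num_imgs, cache):
        self.num_imgs = num_imgs
        self.cache = cache
        self._ram_cache = {}
        self.loads = 0  # counts how often the slow path runs

    @toy_cache_read_img
    def read_img(self, index, use_cache=True):
        self.loads += 1
        # Stand-in for decoding and resizing an image file.
        return [[index] * 4 for _ in range(4)]


if __name__ == "__main__":
    ds = ToyCacheDataset(num_imgs=3, cache=True)
    ds.read_img(0)
    ds.read_img(0)
    print(ds.loads)  # 1: the second call is served from the cache
    ds.read_img(0, use_cache=False)
    print(ds.loads)  # 2: use_cache=False bypasses the cache
```

Because the cached value is whatever `read_img` returns, any fixed post-processing placed inside it is paid once per image instead of on every access.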

## 2. Create your Exp file and return your custom dataset

**Step1** Create a new class that inherits from the `Exp` class provided by `yolox_base.py`. Override the `get_dataset()` and `get_eval_dataset()` methods to return an instance of your custom dataset.

**Step2** Implement your own `get_evaluator` method to return an instance of your custom evaluator.

```python
# CustomExp.py
from yolox.exp import Exp as MyExp
# CustomDataset and CustomEvaluator are your own classes;
# import them from wherever you defined them.

class Exp(MyExp):
    def get_dataset(self, cache, cache_type: str = "ram"):
        return CustomDataset(
            input_dimension=self.input_size,
            cache=cache,
            cache_type=cache_type
        )

    # Accept **kwargs, as the VOC example exp in this commit does.
    def get_eval_dataset(self, **kwargs):
        return CustomDataset(
            input_dimension=self.input_size,
        )

    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
        return CustomEvaluator(
            dataloader=self.get_eval_loader(batch_size, is_distributed, testdev=testdev, legacy=legacy),
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
            testdev=testdev,
        )
```

**(Optional)** `get_data_loader` and `get_eval_loader` now have default implementations in `yolox_base.py` and generally do not need to be changed. If you do have to override `get_data_loader`, add the following code at the beginning.

```python
# CustomExp.py
from yolox.exp import Exp as MyExp
from yolox.utils import wait_for_the_master

class Exp(MyExp):
    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None):
        # tools/train.py builds exp.dataset up front when --cache is given,
        # so this branch is only reached without caching.
        if self.dataset is None:
            with wait_for_the_master():
                assert cache_img is None
                self.dataset = self.get_dataset(cache=False, cache_type=cache_img)
        # ...

```

## 3. Cache to Disk
The `cache_type` can be `"ram"` or `"disk"`, depending on where you want to cache your dataset. If you choose `"disk"`, you need to pass additional parameters to `super().__init__()` of `CustomDataset`: `data_dir`, `cache_dir_name`, and `path_filename`.

- `data_dir`: the root directory of the dataset, e.g. `/path/to/COCO`.
- `cache_dir_name`: the name of the directory to cache to. For example, with `"custom_cache"` the cached files are saved under `/path/to/COCO/custom_cache`.
- `path_filename`: a list of paths to the data relative to `data_dir`. For example, if you have `/path/to/COCO/train/1.jpg` and `/path/to/COCO/train/2.jpg`, then `path_filename = ['train/1.jpg', 'train/2.jpg']`.
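
The relative paths for `path_filename` can be derived from absolute image paths rather than typed by hand. The helper below is a minimal sketch; the name `build_path_filename` is hypothetical, and normalizing separators to forward slashes is an assumption made only to match the example above.

```python
import os


def build_path_filename(data_dir, image_paths):
    """Return image paths relative to data_dir, using forward slashes."""
    return [os.path.relpath(p, data_dir).replace(os.sep, "/") for p in image_paths]


if __name__ == "__main__":
    data_dir = "/path/to/COCO"
    images = ["/path/to/COCO/train/1.jpg", "/path/to/COCO/train/2.jpg"]
    print(build_path_filename(data_dir, images))  # ['train/1.jpg', 'train/2.jpg']
    # Per the description above, the cache directory would then be:
    print(os.path.join(data_dir, "custom_cache"))
```

The resulting list, together with `data_dir` and `cache_dir_name`, is what you would forward to `super().__init__()` when `cache_type` is `"disk"`.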
Lines changed: 14 additions & 92 deletions
@@ -1,9 +1,6 @@
 # encoding: utf-8
 import os
 
-import torch
-import torch.distributed as dist
-
 from yolox.data import get_yolox_datadir
 from yolox.exp import Exp as MyExp
 
@@ -24,115 +21,40 @@ def __init__(self):
 
         self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
 
-    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
-        from yolox.data import (
-            VOCDetection,
-            TrainTransform,
-            YoloBatchSampler,
-            DataLoader,
-            InfiniteSampler,
-            MosaicDetection,
-            worker_init_reset_seed,
-        )
-        from yolox.utils import (
-            wait_for_the_master,
-            get_local_rank,
-        )
-        local_rank = get_local_rank()
+    def get_dataset(self, cache: bool, cache_type: str = "ram"):
+        from yolox.data import VOCDetection, TrainTransform
 
-        with wait_for_the_master(local_rank):
-            dataset = VOCDetection(
-                data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
-                image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
-                img_size=self.input_size,
-                preproc=TrainTransform(
-                    max_labels=50,
-                    flip_prob=self.flip_prob,
-                    hsv_prob=self.hsv_prob),
-                cache=cache_img,
-            )
-
-        dataset = MosaicDetection(
-            dataset,
-            mosaic=not no_aug,
+        return VOCDetection(
+            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
+            image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
             img_size=self.input_size,
             preproc=TrainTransform(
-                max_labels=120,
+                max_labels=50,
                 flip_prob=self.flip_prob,
                 hsv_prob=self.hsv_prob),
-            degrees=self.degrees,
-            translate=self.translate,
-            mosaic_scale=self.mosaic_scale,
-            mixup_scale=self.mixup_scale,
-            shear=self.shear,
-            enable_mixup=self.enable_mixup,
-            mosaic_prob=self.mosaic_prob,
-            mixup_prob=self.mixup_prob,
-        )
-
-        self.dataset = dataset
-
-        if is_distributed:
-            batch_size = batch_size // dist.get_world_size()
-
-        sampler = InfiniteSampler(
-            len(self.dataset), seed=self.seed if self.seed else 0
+            cache=cache,
+            cache_type=cache_type,
         )
 
-        batch_sampler = YoloBatchSampler(
-            sampler=sampler,
-            batch_size=batch_size,
-            drop_last=False,
-            mosaic=not no_aug,
-        )
-
-        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
-        dataloader_kwargs["batch_sampler"] = batch_sampler
-
-        # Make sure each process has different random seed, especially for 'fork' method
-        dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed
-
-        train_loader = DataLoader(self.dataset, **dataloader_kwargs)
-
-        return train_loader
-
-    def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False):
+    def get_eval_dataset(self, **kwargs):
         from yolox.data import VOCDetection, ValTransform
+        legacy = kwargs.get("legacy", False)
 
-        valdataset = VOCDetection(
+        return VOCDetection(
             data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
             image_sets=[('2007', 'test')],
             img_size=self.test_size,
             preproc=ValTransform(legacy=legacy),
         )
 
-        if is_distributed:
-            batch_size = batch_size // dist.get_world_size()
-            sampler = torch.utils.data.distributed.DistributedSampler(
-                valdataset, shuffle=False
-            )
-        else:
-            sampler = torch.utils.data.SequentialSampler(valdataset)
-
-        dataloader_kwargs = {
-            "num_workers": self.data_num_workers,
-            "pin_memory": True,
-            "sampler": sampler,
-        }
-        dataloader_kwargs["batch_size"] = batch_size
-        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
-
-        return val_loader
-
     def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
         from yolox.evaluators import VOCEvaluator
 
-        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy)
-        evaluator = VOCEvaluator(
-            dataloader=val_loader,
+        return VOCEvaluator(
+            dataloader=self.get_eval_loader(batch_size, is_distributed,
+                                            testdev=testdev, legacy=legacy),
             img_size=self.test_size,
             confthre=self.test_conf,
             nmsthre=self.nmsthre,
             num_classes=self.num_classes,
         )
-        return evaluator

tools/train.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ def main(exp: Exp, args):
     assert num_gpu <= get_num_devices()
 
     if args.cache is not None:
-        exp.create_cache_dataset(args.cache)
+        exp.dataset = exp.get_dataset(cache=True, cache_type=args.cache)
 
     dist_url = "auto" if args.dist_url is None else args.dist_url
     launch(
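
The `args.cache is not None` check suggests that `--cache` is an optional string flag whose value is forwarded as `cache_type`. A minimal sketch of an argument definition with that behavior follows; the exact `add_argument` call (`nargs="?"`, `const="ram"`) is an assumption for illustration and is not shown in this diff.

```python
import argparse


def make_toy_parser():
    parser = argparse.ArgumentParser("toy train parser")
    # Assumed definition: bare `--cache` means "ram", `--cache disk` means "disk",
    # and omitting the flag leaves args.cache as None (caching disabled).
    parser.add_argument("--cache", type=str, nargs="?", const="ram", default=None)
    return parser


if __name__ == "__main__":
    parser = make_toy_parser()
    print(parser.parse_args([]).cache)                   # None
    print(parser.parse_args(["--cache"]).cache)          # ram
    print(parser.parse_args(["--cache", "disk"]).cache)  # disk
```

Under this assumption, the dataset is built before `launch` only when the flag is present, and the string it carries selects RAM or disk caching.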

yolox/data/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,6 +4,6 @@
 
 from .coco import COCODataset
 from .coco_classes import COCO_CLASSES
-from .datasets_wrapper import ConcatDataset, Dataset, MixConcatDataset
+from .datasets_wrapper import CacheDataset, ConcatDataset, Dataset, MixConcatDataset
 from .mosaicdetection import MosaicDetection
 from .voc import VOCDetection
