#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
-
+import copy
import os
+import random
+from multiprocessing.pool import ThreadPool
+import psutil
from loguru import logger
+from tqdm import tqdm

import cv2
import numpy as np
@@ -45,6 +49,7 @@ def __init__(
        img_size=(416, 416),
        preproc=None,
        cache=False,
+        cache_type="ram",
    ):
        """
        COCO dataset initialization. Annotation data are read into memory by COCO API.
@@ -64,74 +69,95 @@ def __init__(
        self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
        remove_useless_info(self.coco)
        self.ids = self.coco.getImgIds()
+        self.num_imgs = len(self.ids)
        self.class_ids = sorted(self.coco.getCatIds())
        self.cats = self.coco.loadCats(self.coco.getCatIds())
        self._classes = tuple([c["name"] for c in self.cats])
-        self.imgs = None
        self.name = name
        self.img_size = img_size
        self.preproc = preproc
        self.annotations = self._load_coco_annotations()
-        if cache:
+        self.imgs = None
+        self.cache = cache
+        self.cache_type = cache_type
+
+        if self.cache:
            self._cache_images()

-    def __len__(self):
-        return len(self.ids)
+    def _cache_images(self):
+        mem = psutil.virtual_memory()
+        mem_required = self.cal_cache_ram()
+        gb = 1 << 30

-    def __del__(self):
-        del self.imgs
+        if self.cache_type == "ram" and mem_required > mem.available:
+            # not enough free RAM: disable caching and load images on the fly
+            self.cache = False
+            logger.warning(
+                f"{mem_required / gb:.1f}GB RAM required, but only "
+                f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available. "
+                f"RAM caching is disabled; images will be loaded from disk instead."
+            )
+            return
+        else:
+            logger.info(
+                f"{mem_required / gb:.1f}GB RAM required, "
+                f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available. "
+                f"Caching runs before training, so there is no guarantee that the "
+                f"remaining RAM will be sufficient for training itself."
+            )

-    def _load_coco_annotations(self):
-        return [self.load_anno_from_ids(_ids) for _ids in self.ids]
+        if self.cache and self.imgs is None:
+            if self.cache_type == 'ram':
+                self.imgs = [None] * self.num_imgs
+                logger.info("You are using cached images in RAM to accelerate training!")
+            else:  # 'disk'
+                self.cache_dir = os.path.join(
+                    self.data_dir,
+                    f"{self.name}_cache{self.img_size[0]}x{self.img_size[1]}"
+                )
+                if not os.path.exists(self.cache_dir):
+                    os.mkdir(self.cache_dir)
+                    logger.warning(
+                        f"\n*******************************************************************\n"
+                        f"You are using cached images in DISK to accelerate training.\n"
+                        f"This requires large DISK space.\n"
+                        f"Make sure you have {mem_required / gb:.1f}GB "
+                        f"available DISK space for training COCO.\n"
+                        f"*******************************************************************\n"
+                    )
+                else:
+                    logger.info("Found disk cache!")
+                    return

-    def _cache_images(self):
-        logger.warning(
-            "\n********************************************************************************\n"
-            "You are using cached images in RAM to accelerate training.\n"
-            "This requires large system RAM.\n"
-            "Make sure you have 200G+ RAM and 136G available disk space for training COCO.\n"
-            "********************************************************************************\n"
-        )
-        max_h = self.img_size[0]
-        max_w = self.img_size[1]
-        cache_file = os.path.join(self.data_dir, f"img_resized_cache_{self.name}.array")
-        if not os.path.exists(cache_file):
        logger.info(
-            "Caching images for the first time. This might take about 20 minutes for COCO"
+            "Caching images for the first time. "
+            "This might take about 15 minutes for COCO"
        )
-        self.imgs = np.memmap(
-            cache_file,
-            shape=(len(self.ids), max_h, max_w, 3),
-            dtype=np.uint8,
-            mode="w+",
-        )
-        from tqdm import tqdm
-        from multiprocessing.pool import ThreadPool

-        NUM_THREADs = min(8, os.cpu_count())
-        loaded_images = ThreadPool(NUM_THREADs).imap(
-            lambda x: self.load_resized_img(x),
-            range(len(self.annotations)),
-        )
-        pbar = tqdm(enumerate(loaded_images), total=len(self.annotations))
-        for k, out in pbar:
-            self.imgs[k][: out.shape[0], : out.shape[1], :] = out.copy()
-        self.imgs.flush()
+        num_threads = min(8, max(1, os.cpu_count() - 1))
+        b = 0  # bytes cached so far
+        load_imgs = ThreadPool(num_threads).imap(self.load_resized_img, range(self.num_imgs))
+        pbar = tqdm(enumerate(load_imgs), total=self.num_imgs)
+        for i, x in pbar:  # x = self.load_resized_img(i)
+            if self.cache_type == 'ram':
+                self.imgs[i] = x
+            else:  # 'disk'
+                # annotations[i] is (label, origin_image_size, resized_info, filename)
+                cache_filename = f'{self.annotations[i][3].split(".")[0]}.npy'
+                np.save(os.path.join(self.cache_dir, cache_filename), x)
+            b += x.nbytes
+            pbar.desc = f'Caching images ({b / gb:.1f}/{mem_required / gb:.1f}GB {self.cache_type})'
        pbar.close()
-        else:
-            logger.warning(
-                "You are using cached imgs! Make sure your dataset is not changed!!\n"
-                "Everytime the self.input_size is changed in your exp file, you need to delete\n"
-                "the cached data and re-generate them.\n"
-            )

-        logger.info("Loading cached imgs...")
-        self.imgs = np.memmap(
-            cache_file,
-            shape=(len(self.ids), max_h, max_w, 3),
-            dtype=np.uint8,
-            mode="r+",
-        )
+    def cal_cache_ram(self):
+        # estimate the RAM needed to cache all images by sampling up to 32
+        # random resized images and extrapolating from their average size
+        cache_bytes = 0
+        num_samples = min(self.num_imgs, 32)
+        for _ in range(num_samples):
+            img = self.load_resized_img(random.randint(0, self.num_imgs - 1))
+            cache_bytes += img.nbytes
+        mem_required = cache_bytes * self.num_imgs / num_samples
+        return mem_required
+
+    def __len__(self):
+        return self.num_imgs
+
+    def __del__(self):
+        del self.imgs
+
+    def _load_coco_annotations(self):
+        return [self.load_anno_from_ids(_ids) for _ids in self.ids]

    def load_anno_from_ids(self, id_):
        im_ann = self.coco.loadImgs(id_)[0]
@@ -152,7 +178,6 @@ def load_anno_from_ids(self, id_):
        num_objs = len(objs)

        res = np.zeros((num_objs, 5))
-
        for ix, obj in enumerate(objs):
            cls = self.class_ids.index(obj["category_id"])
            res[ix, 0:4] = obj["clean_bbox"]
@@ -197,15 +222,16 @@ def load_image(self, index):

    def pull_item(self, index):
        id_ = self.ids[index]
+        label, origin_image_size, _, filename = self.annotations[index]

-        res, img_info, resized_info, _ = self.annotations[index]
-        if self.imgs is not None:
-            pad_img = self.imgs[index]
-            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
+        # read from the cache only when it was actually built;
+        # otherwise fall back to resizing the image on the fly
+        if self.cache and self.cache_type == 'ram':
+            img = self.imgs[index]
+        elif self.cache and self.cache_type == 'disk':
+            img = np.load(os.path.join(self.cache_dir, f"{filename.split('.')[0]}.npy"))
        else:
            img = self.load_resized_img(index)

-        return img, res.copy(), img_info, np.array([id_])
+        return copy.deepcopy(img), copy.deepcopy(label), origin_image_size, np.array([id_])

    @Dataset.mosaic_getitem
    def __getitem__(self, index):
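
A minimal usage sketch of the new `cache` / `cache_type` options (not part of the diff; the data paths, annotation files, and image size below are placeholder values, and `COCODataset` is assumed to be importable from `yolox.data.datasets` as elsewhere in the repo):

    from yolox.data.datasets import COCODataset

    # RAM cache: fastest, but needs roughly cal_cache_ram() bytes of free memory;
    # if that much RAM is not available, caching is silently disabled and images
    # are resized on the fly instead.
    train_ds = COCODataset(
        data_dir="datasets/COCO",
        json_file="instances_train2017.json",
        name="train2017",
        img_size=(640, 640),
        cache=True,
        cache_type="ram",
    )

    # Disk cache: resized images are written once as .npy files under
    # datasets/COCO/val2017_cache640x640 and reused on later runs.
    val_ds = COCODataset(
        data_dir="datasets/COCO",
        json_file="instances_val2017.json",
        name="val2017",
        img_size=(640, 640),
        cache=True,
        cache_type="disk",
    )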