@@ -61,15 +61,11 @@ def __init__(self, path: str, shuffle: bool = True):
6161
6262 # Read train trials file
6363 with open (os .path .join (path , self ._TRAIN_TRIALS_FILE ), "r" ) as f :
64- self ._TRAIN_FILES = map (
65- lambda d : os .path .join (path , d .rstrip ()), f .readlines ()
66- )
64+ self ._TRAIN_FILES = map (lambda d : os .path .join (path , d .rstrip ()), f .readlines ())
6765
6866 # Read test trials file
6967 with open (os .path .join (path , self ._TEST_TRIALS_FILE ), "r" ) as f :
70- self ._TEST_FILES = map (
71- lambda d : os .path .join (path , d .rstrip ()), f .readlines ()
72- )
68+ self ._TEST_FILES = map (lambda d : os .path .join (path , d .rstrip ()), f .readlines ())
7369
7470 self ._TRAIN_FILES = list (filter (lambda f : os .path .isfile (f ), self ._TRAIN_FILES ))
7571 self ._TEST_FILES = list (filter (lambda f : os .path .isfile (f ), self ._TEST_FILES ))
@@ -80,9 +76,7 @@ def __init__(self, path: str, shuffle: bool = True):
8076
8177 def _read_labels (self , file : str ) -> np .array :
8278 assert os .path .exists (file ), "File %s doesn't exist" % file
83- return np .genfromtxt (
84- file , delimiter = "," , skip_header = 1 , dtype = self ._LABELS_DTYPE
85- )
79+ return np .genfromtxt (file , delimiter = "," , skip_header = 1 , dtype = self ._LABELS_DTYPE )
8680
8781 def _parse_filename (self , file : str ) -> Tuple [str , str , str ]:
8882 trial = re .search (r"^user([0-9]+)_(.+)\.(aedat|csv)$" , file , re .IGNORECASE )
@@ -108,13 +102,9 @@ def _create_generator(self, files: List[str]):
108102 labels = self ._read_labels (file .replace (".aedat" , "_labels.csv" ))
109103 multilabel_spike_train = readAEDATv3 (file )
110104 for (label_id , start_time , end_time ) in labels :
111- event_mask = (multilabel_spike_train .ts >= start_time ) & (
112- multilabel_spike_train .ts < end_time
113- )
105+ event_mask = (multilabel_spike_train .ts >= start_time ) & (multilabel_spike_train .ts < end_time )
114106 ts = multilabel_spike_train .ts [event_mask ] - start_time
115- spike_train = DVSSpikeTrain (
116- ts .size , width = 128 , height = 128 , duration = end_time - start_time + 1
117- )
107+ spike_train = DVSSpikeTrain (ts .size , width = 128 , height = 128 , duration = end_time - start_time + 1 )
118108 spike_train .ts = ts
119109 spike_train .x = multilabel_spike_train .x [event_mask ]
120110 spike_train .y = multilabel_spike_train .y [event_mask ]
@@ -165,9 +155,7 @@ def __init__(self, path: str, is_train: bool = True):
165155 """
166156 _ , file_extension = os .path .splitext (path )
167157 if file_extension != ".h5" :
168- raise Exception (
169- "The dvs gesture must first be converted to a .h5 file. Please call H5DvsGesture.Convert"
170- )
158+ raise Exception ("The dvs gesture must first be converted to a .h5 file. Please call H5DvsGesture.Convert" )
171159
172160 self .indx = 0 if is_train else 1
173161 self .file_path = path
@@ -189,9 +177,7 @@ def convert(dvs_folder_path: str, h5_output_path: str, verbose=True):
189177 position_type = h5py .vlen_dtype (np .dtype ("uint16" ))
190178 time_type = h5py .vlen_dtype (np .dtype ("uint32" ))
191179
192- step_counter = tqdm (
193- total = sum (H5IBMGesture ._nb_of_samples ), disable = (not verbose )
194- )
180+ step_counter = tqdm (total = sum (H5IBMGesture ._nb_of_samples ), disable = (not verbose ))
195181
196182 with h5py .File (h5_output_path , "w-" ) as f :
197183 for (name , gen , length ) in zip (
@@ -223,9 +209,7 @@ def __getitem__(self, index):
223209 tos = file_hndl [name + "_tos" ][index ]
224210 label = file_hndl [name + "_label" ][index ]
225211
226- spike_train = DVSSpikeTrain (
227- tos .size , width = 128 , height = 128 , duration = tos .max () + 1
228- )
212+ spike_train = DVSSpikeTrain (tos .size , width = 128 , height = 128 , duration = tos .max () + 1 )
229213 spike_train .x = pos [0 ]
230214 spike_train .y = pos [1 ]
231215 spike_train .p = pos [2 ]
0 commit comments