Skip to content

Commit a4bf64e

Browse files
author
Ismael Balafrej
committed
Changed units system from quantities to pint
1 parent 12715c6 commit a4bf64e

File tree

7 files changed

+44
-21
lines changed

7 files changed

+44
-21
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ In the code:
5151
```python
5252
from ebdataset.vision import NMnist
5353
from ebdataset.vision.transforms import ToDense
54-
from quantities import ms
54+
from ebdataset import ms
5555

5656
# With sparse representation:
5757
for spike_train, label in NMnist(path):

ebdataset/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,4 @@
1+
import ebdataset.vision as vision
2+
import ebdataset.audio as audio
3+
import ebdataset.generated as generated
4+
from .utils.units import *

ebdataset/generated/parity.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
import torch.utils.data as data
3-
from quantities import hertz, ms, second
3+
from ..utils.units import hertz, ms, second, wunits
44

55

66
class ParityTask(data.IterableDataset):
@@ -11,6 +11,7 @@ class ParityTask(data.IterableDataset):
1111
bits are encoded both temporally and spatially if sequential=True, otherwise only spatially
1212
"""
1313

14+
@wunits(None, (None, None, hertz, hertz, second, None, None, second, second, None, None, None))
1415
def __init__(
1516
self,
1617
seed=0x1B,
@@ -27,35 +28,35 @@ def __init__(
2728
):
2829
self.seed = seed
2930
self.max_iter = max_iter
30-
self.gen = np.random.RandomState(seed=self.seed)
31-
self.low_freq = float((low_freq * dt).simplified)
32-
self.high_freq = float((high_freq * dt).simplified)
33-
self.sample_duration = int((sample_duration / dt).simplified)
34-
self.duration_per_bit = int((duration_per_bit / dt).simplified)
31+
self.rand = np.random.RandomState(seed=self.seed)
32+
self.low_freq = low_freq * dt
33+
self.high_freq = high_freq * dt
34+
self.sample_duration = int(sample_duration / dt)
35+
self.duration_per_bit = int(duration_per_bit / dt)
3536
self.number_of_bits = number_of_bits
3637
self.features_per_bit = features_per_bit
3738
self.as_recarray = as_recarray
3839
self.sequential = sequential
3940
if sequential:
4041
assert (
41-
duration_per_bit * number_of_bits >= sample_duration
42+
duration_per_bit * number_of_bits <= sample_duration
4243
), "Sample duration is not enough to contain every bits"
4344

4445
def __iter__(self):
4546
worker_info = data.get_worker_info()
4647
m = 1
4748
if worker_info is not None: # multi-process data loading, re-seed the iterator
48-
self.gen = np.random.RandomState(seed=worker_info.id + self.seed)
49+
self.rand = np.random.RandomState(seed=worker_info.id + self.seed)
4950
m = worker_info.num_workers
5051

5152
i = 0
5253
while i < self.max_iter / m:
5354
i += 1
5455

55-
bits = self.gen.randint(0, 2, size=self.number_of_bits)
56+
bits = self.rand.randint(0, 2, size=self.number_of_bits)
5657
y = np.sum(bits) % 2
5758

58-
spike_train = self.gen.poisson(
59+
spike_train = self.rand.poisson(
5960
lam=self.low_freq, size=(self.number_of_bits * self.features_per_bit, self.sample_duration)
6061
)
6162

@@ -68,7 +69,7 @@ def __iter__(self):
6869
)
6970
spike_train[
7071
b * self.features_per_bit : (b + 1) * self.features_per_bit, time_pos
71-
] = self.gen.poisson(lam=self.high_freq, size=(self.features_per_bit, self.duration_per_bit))
72+
] = self.rand.poisson(lam=self.high_freq, size=(self.features_per_bit, self.duration_per_bit))
7273

7374
if self.as_recarray:
7475
ts, addr = np.nonzero(spike_train.T)

ebdataset/utils/units.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,15 @@
1+
from pint import UnitRegistry as _UR
2+
3+
_reg = _UR()
4+
5+
wunits = _reg.wraps
6+
7+
## Global time management namespace
8+
second = s = _reg.s
9+
millisecond = ms = _reg.ms
10+
microsecond = us = _reg.us
11+
nanosecond = ns = _reg.ns
12+
killosecond = ks = _reg.ks
13+
hertz = Hz = _reg.Hz
14+
millihertz = mhertz = mHz = _reg.mHz
15+
kilohertz = khertz = kHz = _reg.kHz

ebdataset/vision/ini_roshambo.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from .parsers.aedat import readAEDATv2_davies
66
from torch.utils.data.dataset import Dataset
77
from .type import DVSSpikeTrain
8-
from quantities import us
8+
from ..utils.units import us, wunits
99

1010

1111
class INIRoshambo(Dataset):
@@ -66,11 +66,12 @@ def convert(self, out_path, verbose=False):
6666

6767
return INIRoshambo(out_path, with_backgrounds=self.with_backgrounds, transforms=self.transforms)
6868

69+
@wunits(None, (None, None, us, None))
6970
def split_to_subsamples(self, out_path, duration_per_sample, verbose=False):
7071
if not (".h5" in out_path):
7172
out_path += ".h5"
7273

73-
duration_per_sample = int(duration_per_sample.rescale(us).magnitude)
74+
duration_per_sample = int(duration_per_sample)
7475

7576
with File(out_path, "w-", libver="latest") as f_hndl:
7677
for i, (sample, label) in enumerate(tqdm(self, disable=not verbose)):

ebdataset/vision/transforms.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
"""
44
import torch
55
import numpy as np
6-
import quantities as units
6+
from ..utils.units import second, us, ms, wunits
77
from torchvision.transforms import Compose
88

99

@@ -32,8 +32,9 @@ def __call__(self, sparse_spike_train):
3232
class MaxTime(object):
3333
"""Limit the time of a 2d sparse spike train"""
3434

35-
def __init__(self, max_time: units.UnitTime, dt: units.UnitTime = 1 * units.us):
36-
self.max = (max_time.rescale(dt.units) / dt).magnitude
35+
@wunits(None, (None, second, second))
36+
def __init__(self, max_time, dt=1 * us):
37+
self.max = max_time / dt
3738

3839
def __call__(self, sparse_spike_train):
3940
mask = sparse_spike_train.ts < self.max
@@ -50,14 +51,15 @@ class ToDense(object):
5051
"""Transform a sparse spike train to a dense torch tensor of shape (x, y, p, time)
5152
with time unit defined by dt. Time accumulation is done with a max function."""
5253

54+
@wunits(None, (None, second))
5355
def __init__(
5456
self,
55-
dt: units.UnitTime, # Time scale of dense tensor
57+
dt, # Time scale of dense tensor
5658
):
5759
self.dt = dt
5860

5961
def __call__(self, sparse_spike_train):
60-
time_scale = ((sparse_spike_train.time_scale * units.second).rescale(self.dt.units) / self.dt).magnitude
62+
time_scale = sparse_spike_train.time_scale / self.dt
6163
duration = np.ceil(sparse_spike_train.duration * time_scale).astype(int)
6264
dense_spike_train = torch.zeros((sparse_spike_train.width, sparse_spike_train.height, 2, duration))
6365

setup.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77

88
setup(
99
name="ebdataset",
10-
version="0.0.2",
10+
version="0.1.0",
1111
author="Ismael Balafrej - NECOTIS",
1212
author_email="[email protected]",
1313
description="An event based dataset loader under one common API.",
@@ -17,7 +17,7 @@
1717
packages=find_packages(),
1818
install_requires=[
1919
"numpy>=1.14.3",
20-
"quantities>=0.12.4",
20+
"pint>=0.17",
2121
"tqdm>=4.45.0",
2222
"torch>=1.4.0",
2323
"torchvision>=0.5.0",

0 commit comments

Comments
 (0)