
Commit a45b43a

Merge pull request #4 from tihbe/parity_task

Added generated bit parity task

2 parents: 51780a5 + b034337

File tree

README.md
ebdataset/audio/__init__.py
ebdataset/generated/__init__.py
ebdataset/generated/parity.py
ebdataset/vision/__init__.py

5 files changed: +75 -1 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ You can install the latest version of this package with:
 pip install git+https://github.com/tihbe/python-ebdataset.git
 ```
 
-# Usage
+# Getting started
 
 In the code:
 ```python

ebdataset/audio/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
+"""This subpackage regroups audio-based spiking datasets"""
 from .ntidigits import NTidigits

ebdataset/generated/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+"""This subpackage regroups self-generated spiking datasets"""
+from .parity import ParityTask

ebdataset/generated/parity.py

Lines changed: 70 additions & 0 deletions
import numpy as np
import torch.utils.data as data
from quantities import hertz, ms, second


class ParityTask(data.IterableDataset):
    """Spike-based, population-encoded parity (or n-bit XOR) task of max_iter samples.

    If max_iter is not specified or is np.inf, this dataset keeps generating
    samples forever. Each HIGH bit is encoded with high_freq Poisson-sampled
    spikes of shape features_per_bit x duration_per_bit; LOW bits and
    background noise are encoded with low_freq Poisson-sampled spikes for the
    remainder of the sample_duration. Bits are therefore encoded both
    temporally and spatially.
    """

    def __init__(
        self,
        seed=0x1B,
        low_freq=2 * hertz,
        high_freq=20 * hertz,
        sample_duration=2 * second,
        number_of_bits=2,
        features_per_bit=50,
        duration_per_bit=0.5 * second,
        dt=1 * ms,
        max_iter=np.inf,
        as_recarray=True,
    ):
        self.seed = seed
        self.max_iter = max_iter
        self.gen = np.random.RandomState(seed=self.seed)
        # Convert physical quantities to per-time-step (dimensionless) values
        self.low_freq = float((low_freq * dt).simplified)
        self.high_freq = float((high_freq * dt).simplified)
        self.sample_duration = int((sample_duration / dt).simplified)
        self.duration_per_bit = int((duration_per_bit / dt).simplified)
        self.number_of_bits = number_of_bits
        self.features_per_bit = features_per_bit
        self.as_recarray = as_recarray

    def __iter__(self):
        worker_info = data.get_worker_info()
        m = 1
        if worker_info is not None:  # multi-process data loading: re-seed the iterator
            self.gen = np.random.RandomState(seed=worker_info.id + self.seed)
            m = worker_info.num_workers

        i = 0
        while i < self.max_iter / m:
            i += 1

            # Draw a random bit string and compute its parity label
            bits = self.gen.randint(0, 2, size=self.number_of_bits)
            y = np.sum(bits) % 2

            # Low-frequency background activity over the whole sample
            spike_train = self.gen.poisson(
                lam=self.low_freq, size=(self.number_of_bits * self.features_per_bit, self.sample_duration)
            )

            # Overwrite the (feature, time) block of every HIGH bit with high-frequency activity
            for b in range(self.number_of_bits):
                if bits[b]:
                    spike_train[
                        b * self.features_per_bit : (b + 1) * self.features_per_bit,
                        b * self.duration_per_bit : (b + 1) * self.duration_per_bit,
                    ] = self.gen.poisson(lam=self.high_freq, size=(self.features_per_bit, self.duration_per_bit))

            if self.as_recarray:
                # Sparse event representation: (address, timestamp) pairs
                addr, ts = np.nonzero(spike_train)
                sample = np.recarray(shape=len(ts), dtype=[("addr", addr.dtype), ("ts", ts.dtype)])
                sample.addr = addr
                sample.ts = ts
                yield sample, y
            else:
                yield spike_train, y
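As a quick orientation for the new task, here is a minimal usage sketch; the parameter values, the islice cap, and the DataLoader settings are illustrative assumptions rather than part of the commit:

```python
from itertools import islice

import torch.utils.data as data
from ebdataset.generated import ParityTask

# Stream sparse event samples straight from the iterable dataset
dataset = ParityTask(number_of_bits=2, max_iter=100)
for sample, label in islice(dataset, 3):
    # With as_recarray=True, each sample is a recarray of (addr, ts) spike events
    print(sample.addr[:5], sample.ts[:5], "parity:", int(label))

# Dense spike trains (as_recarray=False) have a fixed shape, so they batch
# cleanly with a standard DataLoader; __iter__ re-seeds each worker and splits
# max_iter across them, so the two workers here each yield 32 of the 64 samples
loader = data.DataLoader(
    ParityTask(as_recarray=False, max_iter=64),
    batch_size=8,
    num_workers=2,
)
for spike_trains, labels in loader:
    print(spike_trains.shape, labels.shape)  # torch.Size([8, 100, 2000]), torch.Size([8])
```

With the defaults (2 bits x 50 features per bit, 2 s at 1 ms steps), each dense sample has shape (100, 2000).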

ebdataset/vision/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+"""This subpackage regroups vision-based spiking datasets"""
 from .ibm_gesture import IBMGesture, H5IBMGesture
 from .ini_ucf50 import INIUCF50
 from .ncaltech101 import NCaltech101
