import numpy as np
import torch.utils.data as data
from quantities import hertz, ms, second


class ParityTask(data.IterableDataset):
    """Spike-based, population-encoded parity (n-bit XOR) task of max_iter samples.

    If max_iter is not specified or is np.inf, the dataset keeps generating samples forever.
    Each HIGH bit is encoded as high_freq Poisson-sampled spikes over a block of shape
    features_per_bit x duration_per_bit; LOW bits and background noise are encoded as
    low_freq Poisson-sampled spikes over the rest of the sample_duration.
    Bits are therefore encoded both temporally and spatially.
    """

    def __init__(
        self,
        seed=0x1B,
        low_freq=2 * hertz,
        high_freq=20 * hertz,
        sample_duration=2 * second,
        number_of_bits=2,
        features_per_bit=50,
        duration_per_bit=0.5 * second,
        dt=1 * ms,
        max_iter=np.inf,
        as_recarray=True,
    ):
        self.seed = seed
        self.max_iter = max_iter
        self.gen = np.random.RandomState(seed=self.seed)
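        # Convert rates to expected spike counts per time step and durations to
        # integer numbers of time steps (dimensionless, via `quantities` simplification).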
        self.low_freq = float((low_freq * dt).simplified)
        self.high_freq = float((high_freq * dt).simplified)
        self.sample_duration = int((sample_duration / dt).simplified)
        self.duration_per_bit = int((duration_per_bit / dt).simplified)
        self.number_of_bits = number_of_bits
        self.features_per_bit = features_per_bit
        self.as_recarray = as_recarray

    def __iter__(self):
        worker_info = data.get_worker_info()
        m = 1
        if worker_info is not None:  # multi-process data loading, re-seed the iterator
            self.gen = np.random.RandomState(seed=worker_info.id + self.seed)
            m = worker_info.num_workers

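        # Each worker yields roughly max_iter / num_workers samples, so the total
        # across all workers stays close to max_iter.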
        i = 0
        while i < self.max_iter / m:
            i += 1

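            # Draw the hidden bits; the label is their parity (n-bit XOR).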
            bits = self.gen.randint(0, 2, size=self.number_of_bits)
            y = np.sum(bits) % 2

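            # Background activity: low-rate Poisson spike counts for every feature
            # and every time step of the sample.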
            spike_train = self.gen.poisson(
                lam=self.low_freq, size=(self.number_of_bits * self.features_per_bit, self.sample_duration)
            )

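            # For each HIGH bit, overwrite its feature block during its time slot
            # with high-rate Poisson counts, so bits are coded spatially and temporally.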
            for b in range(self.number_of_bits):
                if bits[b]:
                    spike_train[
                        b * self.features_per_bit : (b + 1) * self.features_per_bit,
                        b * self.duration_per_bit : (b + 1) * self.duration_per_bit,
                    ] = self.gen.poisson(lam=self.high_freq, size=(self.features_per_bit, self.duration_per_bit))

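            # Optionally flatten the dense count matrix into (addr, ts) spike events;
            # note that bins holding more than one spike collapse into a single event.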
            if self.as_recarray:
                addr, ts = np.nonzero(spike_train)
                sample = np.recarray(shape=len(ts), dtype=[("addr", addr.dtype), ("ts", ts.dtype)])
                sample.addr = addr
                sample.ts = ts
                yield sample, y
            else:
                yield spike_train, y
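

if __name__ == "__main__":
    # Usage sketch (illustrative only): iterate the dataset directly and inspect a
    # few (events, label) pairs. For training, the dataset can also be wrapped in a
    # torch.utils.data.DataLoader (batch_size and num_workers are up to the caller).
    dataset = ParityTask(max_iter=4, as_recarray=True)
    for events, label in dataset:
        # `events.addr` holds feature indices and `events.ts` the spike time steps
        # (in units of dt); `label` is the parity (XOR) of the hidden bits.
        print(f"{len(events)} spikes, label={label}")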