11import numpy as np
22import torch .utils .data as data
3- from quantities import hertz , ms , second
3+ from .. utils . units import hertz , ms , second , wunits
44
55
66class ParityTask (data .IterableDataset ):
@@ -11,6 +11,7 @@ class ParityTask(data.IterableDataset):
1111 bits are encoded both temporally and spatially if sequential=True, otherwise only spatially
1212 """
1313
14+ @wunits (None , (None , None , hertz , hertz , second , None , None , second , second , None , None , None ))
1415 def __init__ (
1516 self ,
1617 seed = 0x1B ,
@@ -27,35 +28,35 @@ def __init__(
2728 ):
2829 self .seed = seed
2930 self .max_iter = max_iter
30- self .gen = np .random .RandomState (seed = self .seed )
31- self .low_freq = float (( low_freq * dt ). simplified )
32- self .high_freq = float (( high_freq * dt ). simplified )
33- self .sample_duration = int (( sample_duration / dt ). simplified )
34- self .duration_per_bit = int (( duration_per_bit / dt ). simplified )
31+ self .rand = np .random .RandomState (seed = self .seed )
32+ self .low_freq = low_freq * dt
33+ self .high_freq = high_freq * dt
34+ self .sample_duration = int (sample_duration / dt )
35+ self .duration_per_bit = int (duration_per_bit / dt )
3536 self .number_of_bits = number_of_bits
3637 self .features_per_bit = features_per_bit
3738 self .as_recarray = as_recarray
3839 self .sequential = sequential
3940 if sequential :
4041 assert (
41- duration_per_bit * number_of_bits > = sample_duration
42+ duration_per_bit * number_of_bits < = sample_duration
4243 ), "Sample duration is not enough to contain every bits"
4344
4445 def __iter__ (self ):
4546 worker_info = data .get_worker_info ()
4647 m = 1
4748 if worker_info is not None : # multi-process data loading, re-seed the iterator
48- self .gen = np .random .RandomState (seed = worker_info .id + self .seed )
49+ self .rand = np .random .RandomState (seed = worker_info .id + self .seed )
4950 m = worker_info .num_workers
5051
5152 i = 0
5253 while i < self .max_iter / m :
5354 i += 1
5455
55- bits = self .gen .randint (0 , 2 , size = self .number_of_bits )
56+ bits = self .rand .randint (0 , 2 , size = self .number_of_bits )
5657 y = np .sum (bits ) % 2
5758
58- spike_train = self .gen .poisson (
59+ spike_train = self .rand .poisson (
5960 lam = self .low_freq , size = (self .number_of_bits * self .features_per_bit , self .sample_duration )
6061 )
6162
@@ -68,7 +69,7 @@ def __iter__(self):
6869 )
6970 spike_train [
7071 b * self .features_per_bit : (b + 1 ) * self .features_per_bit , time_pos
71- ] = self .gen .poisson (lam = self .high_freq , size = (self .features_per_bit , self .duration_per_bit ))
72+ ] = self .rand .poisson (lam = self .high_freq , size = (self .features_per_bit , self .duration_per_bit ))
7273
7374 if self .as_recarray :
7475 ts , addr = np .nonzero (spike_train .T )
0 commit comments