Skip to content

Commit dedd91f

Browse files
Modify Auto Domain for RegularTimeseries (#109)
1 parent 0a8bdba commit dedd91f

4 files changed

Lines changed: 38 additions & 25 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
3333
### Changed
3434
- Change minimum python version to 3.10 ([#93](https://github.com/neuro-galaxy/temporaldata/pull/93))
3535
- Optimized performance of `Interval.coalesce()` ([#97](https://github.com/neuro-galaxy/temporaldata/pull/97))
36+
- New auto domain for `RegularTimeSeries` to have no impact when doing `rts.slice(rts.domain.start[0], rts.domain.end[-1])` ([#109](https://github.com/neuro-galaxy/temporaldata/pull/109))
37+
3638

3739
## [0.1.3] - 2025-03-21
3840
### Added

temporaldata/regular_ts.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969
)
7070
domain = Interval(
7171
start=np.array([domain_start]),
72-
end=np.array([domain_start + (len(self) - 1) / sampling_rate]),
72+
end=np.array([domain_start + len(self) / sampling_rate]),
7373
)
7474
self._domain = domain
7575

@@ -128,9 +128,7 @@ def _time_to_idx(
128128
# Determine index and reconstruct the actual timestamp of that sample
129129
idx = math.ceil(idx_float)
130130

131-
# For the end index, the reconstruction logic shifts by 1 sample
132-
recon_idx = idx if is_start else idx - 1
133-
actual_time = domain_start + (recon_idx / self.sampling_rate)
131+
actual_time = domain_start + (idx / self.sampling_rate)
134132

135133
return idx, actual_time
136134

@@ -285,7 +283,7 @@ def _maybe_first_dim(self):
285283
# this is because we are dealing with numerical noise
286284
# we know the domain and the sampling rate, we can infer the number of pts
287285
domain_length = self.domain.end[-1] - self.domain.start[0]
288-
return int(np.round(domain_length * self.sampling_rate)) + 1
286+
return int(np.round(domain_length * self.sampling_rate))
289287

290288
# otherwise nothing was loaded, return the first dim of the h5py dataset
291289
return self.__dict__[self.keys()[0]].shape[0]

tests/test_data.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
import pytest
2-
import os
31
import copy
2+
import logging
3+
import os
4+
import tempfile
5+
46
import h5py
57
import numpy as np
6-
import tempfile
7-
import logging
8+
import pytest
9+
810
from temporaldata import (
911
ArrayDict,
12+
Data,
13+
Interval,
1014
IrregularTimeSeries,
1115
RegularTimeSeries,
12-
Interval,
13-
Data,
1416
)
1517

1618

@@ -569,7 +571,7 @@ def test_data_auto_domain():
569571
)
570572

571573
assert np.allclose(data.domain.start, np.array([0, 5]))
572-
assert np.allclose(data.domain.end, np.array([3.996, 6]))
574+
assert np.allclose(data.domain.end, np.array([4, 6]))
573575

574576

575577
def test_data_save(tmp_path):

tests/test_regular_ts.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,38 +28,38 @@ def _test_regulartimeseries(data):
2828
assert len(data) == 100
2929

3030
assert data.domain.start[0] == 0.0
31-
assert data.domain.end[0] == 9.9
31+
assert data.domain.end[0] == 10.0
3232

3333
data_slice = data.slice(2.0, 8.0, reset_origin=False)
3434
assert np.allclose(data_slice.lfp, data.lfp[20:80])
3535
assert data_slice.domain.start[0] == 2.0
36-
assert data_slice.domain.end[0] == 7.9
36+
assert data_slice.domain.end[0] == 8.0
3737
assert np.allclose(data_slice.timestamps, np.arange(2.0, 8.0, 0.1))
3838

3939
data_slice = data.slice(2.0, 8.0, reset_origin=True)
4040
assert np.allclose(data_slice.lfp, data.lfp[20:80])
4141
assert data_slice.domain.start[0] == 0.0
42-
assert data_slice.domain.end[0] == 5.9
42+
assert data_slice.domain.end[0] == 6.0
4343
assert np.allclose(data_slice.timestamps, np.arange(0.0, 6.0, 0.1))
4444

4545
# try slicing with skewed start and end
4646
# the sampling frequency is 10
4747
data_slice = data.slice(2.03, 8.09, reset_origin=True)
4848
assert np.allclose(data_slice.lfp, data.lfp[21:81])
4949
assert np.allclose(data_slice.domain.start, np.array([0.07]))
50-
assert np.allclose(data_slice.domain.end, np.array([5.97]))
50+
assert np.allclose(data_slice.domain.end, np.array([6.07]))
5151
assert np.allclose(data_slice.timestamps, np.arange(0.07, 5.98, 0.1))
5252

5353
data_slice = data.slice(4.051, 12.0, reset_origin=True)
5454
assert np.allclose(data_slice.lfp, data.lfp[41:])
5555
assert np.allclose(data_slice.domain.start, np.array([0.049]))
56-
assert np.allclose(data_slice.domain.end, np.array([5.849]))
56+
assert np.allclose(data_slice.domain.end, np.array([5.949]))
5757
assert np.allclose(data_slice.timestamps, np.arange(0.049, 5.88, 0.1))
5858

5959
data_slice = data.slice(4.051, 12.0, reset_origin=False)
6060
assert np.allclose(data_slice.lfp, data.lfp[41:])
6161
assert np.allclose(data_slice.domain.start, np.array([4.1]))
62-
assert np.allclose(data_slice.domain.end, np.array([9.9]))
62+
assert np.allclose(data_slice.domain.end, np.array([10.0]))
6363
assert np.allclose(data_slice.timestamps, np.arange(4.1, 10.0, 0.1))
6464

6565
data_slice = data.slice(-10, 20, reset_origin=False)
@@ -68,13 +68,12 @@ def _test_regulartimeseries(data):
6868
assert np.allclose(data_slice.domain.end, data.domain.end)
6969
assert np.allclose(data_slice.timestamps, data.timestamps)
7070

71-
# TODO update when we update the domain convention
7271
domain_start, domain_end = data.domain.start[0], data.domain.end[-1]
7372
data_slice = data.slice(domain_start, domain_end, reset_origin=False)
74-
assert np.allclose(data_slice.lfp, data.lfp[:-1])
73+
assert np.allclose(data_slice.lfp, data.lfp)
7574
assert np.allclose(data_slice.domain.start, data.domain.start)
76-
assert np.allclose(data_slice.domain.end, data.domain.end - 0.1)
77-
assert np.allclose(data_slice.timestamps, data.timestamps[:-1])
75+
assert np.allclose(data_slice.domain.end, data.domain.end)
76+
assert np.allclose(data_slice.timestamps, data.timestamps)
7877

7978
data = RegularTimeSeries(
8079
lfp=np.random.random((100, 48)), sampling_rate=10, domain="auto"
@@ -103,7 +102,7 @@ def _test_regulartimeseries_with_domain_start(data):
103102
assert len(data) == 100
104103

105104
assert data.domain.start[0] == 1.0
106-
assert data.domain.end[0] == 10.9
105+
assert data.domain.end[0] == 11.0
107106

108107
data_slice = data.slice(3.0, 9.0)
109108
assert np.allclose(data_slice.lfp, data.lfp[20:80])
@@ -210,12 +209,12 @@ def test_lazy_regular_timeseries(test_filepath):
210209
# even if no other attribute is loaded
211210
assert len(data.timestamps) == 500
212211
assert data.domain.start[0] == 1.0
213-
assert data.domain.end[0] == 2.996
212+
assert data.domain.end[0] == 3.0
214213
assert np.allclose(data.timestamps, np.arange(1.0, 3.0, 1 / 250.0))
215214

216215
data = data.slice(1.0, 2.0, reset_origin=True)
217216
assert data.domain.start[0] == 0.0
218-
assert data.domain.end[0] == 0.996
217+
assert data.domain.end[0] == 1.0
219218
assert len(data.timestamps) == 250
220219
assert np.allclose(data.timestamps, np.arange(0.0, 1.0, 1 / 250.0))
221220

@@ -259,27 +258,35 @@ def test_slice_numerical_instability():
259258
end = 1.0 - eps
260259
sliced_ts = ts.slice(start, end, reset_origin=False)
261260
assert np.allclose(sliced_ts.timestamps, np.array([0.0, 0.25, 0.5, 0.75]))
261+
assert sliced_ts.domain.start[0] == 0.0
262+
assert sliced_ts.domain.end[-1] == 1.0
262263

263264
# `end` is infinitesimally larger than an exact timestamp (1.0000001 scenario).
264265
# As the end is larger due to numerical instability even if the interval is [start, end), 1.0 should still be EXCLUDED
265266
start = 0.25
266267
end = 1.0 + eps
267268
sliced_ts = ts.slice(start, end, reset_origin=False)
268269
assert np.allclose(sliced_ts.timestamps, np.array([0.25, 0.5, 0.75]))
270+
assert sliced_ts.domain.start[0] == 0.25
271+
assert sliced_ts.domain.end[-1] == 1.0
269272

270273
# `start` is computed slightly larger than an exact timestamp.
271274
# As the start is larger due to numerical instability 0.25 should be INCLUDED.
272275
start = 0.25 + eps
273276
end = 1.0
274277
sliced_ts = ts.slice(start, end, reset_origin=False)
275278
assert np.allclose(sliced_ts.timestamps, np.array([0.25, 0.5, 0.75]))
279+
assert sliced_ts.domain.start[0] == 0.25
280+
assert sliced_ts.domain.end[-1] == 1.0
276281

277282
# Maximum Precision Limits via np.nextafter
278283
# np.nextafter gives the very next representable float in memory.
279284
start = 0.5
280285
end = np.nextafter(1.0, 0.0) # The largest possible float strictly less than 1.0
281286
sliced_ts = ts.slice(start, end, reset_origin=False)
282287
assert np.allclose(sliced_ts.timestamps, np.array([0.5, 0.75]))
288+
assert sliced_ts.domain.start[0] == 0.5
289+
assert sliced_ts.domain.end[-1] == 1.0
283290

284291
# Should still treat `end` as 1.0 and EXCLUDE it
285292
start = 0.5
@@ -295,6 +302,8 @@ def test_slice_numerical_instability():
295302
end = 1.0
296303
sliced_ts = ts.slice(start, end, reset_origin=False)
297304
assert np.allclose(sliced_ts.timestamps, np.array([0.25, 0.5, 0.75]))
305+
assert sliced_ts.domain.start[0] == 0.25
306+
assert sliced_ts.domain.end[-1] == 1.0
298307

299308
ts = RegularTimeSeries(value=np.zeros((40)), sampling_rate=10, domain="auto")
300309
# Expected timestamps: [0.0, 0.1, 0.2, ...]
@@ -305,3 +314,5 @@ def test_slice_numerical_instability():
305314
end = start * 3 # 0.9000000000000001
306315
sliced_ts = ts.slice(start, end, reset_origin=False)
307316
assert np.allclose(sliced_ts.timestamps, np.array([0.3, 0.4, 0.5, 0.6, 0.7, 0.8]))
317+
assert sliced_ts.domain.start[0] == 0.3
318+
assert sliced_ts.domain.end[-1] == 0.9

0 commit comments

Comments
 (0)