
Commit 33c3bb3
WIP: unit tests
1 parent b71ba6e

File tree: 2 files changed (+192, -9 lines)

src/vitessce/data_utils/spatialdata_points_zorder.py (47 additions, 9 deletions)
@@ -156,7 +156,7 @@ def sdata_morton_sort_points(sdata, element):
 
     return sdata
 
-def sdata_morton_query_rect(sdata, element, orig_rect):
+def sdata_morton_query_rect_aux(sdata, element, orig_rect):
     #orig_rect = [[50, 50], [100, 150]] # [[x0, y0], [x1, y1]]
     #norm_rect = [
     #    orig_coord_to_norm_coord(orig_rect[0], orig_x_min=0, orig_x_max=100, orig_y_min=0, orig_y_max=200),
@@ -190,15 +190,34 @@ def sdata_morton_query_rect(sdata, element, orig_rect):
         stop_level = None,
         merge = True,
     )
+
+    return morton_intervals
+
+
+def sdata_morton_query_rect(sdata, element, orig_rect):
+    sorted_ddf = sdata.points[element]
+    morton_intervals = sdata_morton_query_rect_aux(sdata, element, orig_rect)
+
     # Get morton code column as a list of integers
     morton_sorted = sorted_ddf["morton_code_2d"].compute().values.tolist()
 
     # Get a list of row ranges that match the morton intervals.
     # (This uses binary searches internally to find the matching row indices).
     # [ (row_start, row_end), ... ]
     matching_row_ranges = zquery_rows(morton_sorted, morton_intervals, merge = True)
+
     return matching_row_ranges
 
+def sdata_morton_query_rect_debug(sdata, element, orig_rect):
+    # This is the same as the above sdata_morton_query_rect function,
+    # but it also returns the list of row indices that were checked
+    # during the binary searches.
+    sorted_ddf = sdata.points[element]
+    morton_intervals = sdata_morton_query_rect_aux(sdata, element, orig_rect)
+    morton_sorted = sorted_ddf["morton_code_2d"].compute().values.tolist()
+    matching_row_ranges, rows_checked = zquery_rows_aux(morton_sorted, morton_intervals, merge = True)
+    return matching_row_ranges, rows_checked
+
 # --------------------------
 # Functions for rectangle queries.
 # --------------------------
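Note on the refactor above: sdata_morton_query_rect_aux computes the Morton intervals covering the rectangle, sdata_morton_query_rect turns them into row ranges, and sdata_morton_query_rect_debug additionally reports which keys the binary searches touched. The encoding of morton_code_2d itself is not shown in this hunk; a minimal illustrative sketch of 2-D Morton (Z-order) bit interleaving, using a hypothetical helper name rather than the module's own implementation, might look like:

# Illustrative sketch only (not the module's implementation): build a 2-D
# Morton code by interleaving the bits of two normalized unsigned integers.
def interleave_bits_2d(x_uint: int, y_uint: int, bits: int = 16) -> int:
    code = 0
    for i in range(bits):
        code |= ((x_uint >> i) & 1) << (2 * i)       # bit i of x -> even bit 2*i
        code |= ((y_uint >> i) & 1) << (2 * i + 1)   # bit i of y -> odd bit 2*i + 1
    return code

# Nearby points get nearby codes, so sorting rows by this column lets a
# rectangle query scan a few contiguous row ranges instead of the whole table.
assert interleave_bits_2d(0b11, 0b01) == 0b0111  # x=3, y=1 -> z=7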
@@ -334,23 +353,42 @@ def zcover_rectangle(rx0:int, ry0:int, rx1:int, ry1:int, bits:int, stop_level: O
 # Morton intervals -> row ranges in a Morton-sorted column
 # --------------------------
 
-def zquery_rows(morton_sorted: List[int], intervals: List[Tuple[int,int]], merge: bool = True) -> List[Tuple[int,int]]:
+def zquery_rows_aux(morton_sorted: List[int], intervals: List[Tuple[int,int]], merge: bool = True) -> Tuple[List[Tuple[int,int]], List[int]]:
     """
     For each Z-interval [zlo, zhi], binary-search in the sorted Morton column
     and return row index half-open ranges [i, j) to scan.
     """
+
+    # Keep track of which keys were looked at during the binary searches.
+    # This is used for analysis / debugging, for instance, to enable
+    # evaluating how many HTTP requests would be needed in network-based case
+    # (which will also depend on Arrow row group size).
+    recorded_keys = []
+    def record_key_check(k: int) -> int:
+        # TODO: Does recorded_keys need to be marked as a global here?
+        recorded_keys.append(k)
+        return k
+
     ranges: List[Tuple[int,int]] = []
+    # TODO: can these multiple binary searches be optimized?
+    # Since we are doing many searches in the same array, and in each search we learn where more elements are located.
     for zlo, zhi in intervals:
-        i = bisect_left(morton_sorted, zlo)
-        j = bisect_right(morton_sorted, zhi)
+        i = bisect_left(morton_sorted, zlo, key=record_key_check)
+        # TODO: use lo=i in bisect_right to limit the search range?
+        # TODO: can the second binary search be further optimized since we just did a binary search via bisect_left?
+        j = bisect_right(morton_sorted, zhi, key=record_key_check)
         if i < j:
             ranges.append((i, j))
 
-    # TODO: record exactly which rows were queried,
-    # to enable evaluating how many HTTP requests would be needed in network-based case
-    # (will also depend on Arrow row group size)
-
-    return merge_adjacent(ranges) if merge else ranges
+    result = merge_adjacent(ranges) if merge else ranges
+    return result, recorded_keys
+
+def zquery_rows(morton_sorted: List[int], intervals: List[Tuple[int,int]], merge: bool = True) -> List[Tuple[int,int]]:
+    """
+    For each Z-interval [zlo, zhi], binary-search in the sorted Morton column
+    and return row index half-open ranges [i, j) to scan.
+    """
+    return zquery_rows_aux(morton_sorted, intervals, merge=merge)[0]
 
 
 def row_ranges_to_row_indices(intervals: List[Tuple[int,int]]) -> List[int]:
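A note on the key= arguments added above: since Python 3.10, bisect_left and bisect_right apply key to the elements of the list being searched (not to the probe value), so record_key_check is called once for every Morton code examined during the binary searches; this is what rows_checked counts in sdata_morton_query_rect_debug. Below is a small self-contained sketch of the row-range logic, using local stand-ins for zquery_rows and merge_adjacent (whose implementations are not part of this diff):

from bisect import bisect_left, bisect_right

morton_sorted = [0, 1, 4, 5, 6, 9, 12, 13]  # sorted Morton codes, one per row
intervals = [(4, 6), (9, 9)]                # Z-intervals covering a query rectangle

# For each Z-interval, binary-search for the half-open row range [i, j).
ranges = []
for zlo, zhi in intervals:
    i = bisect_left(morton_sorted, zlo)     # first row with code >= zlo
    j = bisect_right(morton_sorted, zhi)    # one past the last row with code <= zhi
    if i < j:
        ranges.append((i, j))

# Merge ranges that touch or overlap (a stand-in for merge_adjacent).
merged = []
for start, end in ranges:
    if merged and start <= merged[-1][1]:
        merged[-1] = (merged[-1][0], max(merged[-1][1], end))
    else:
        merged.append((start, end))

assert ranges == [(2, 5), (5, 6)]
assert merged == [(2, 6)]  # rows 2..5 hold codes 4, 5, 6, 9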

tests/test_sdata_points_zorder.py (145 additions, 0 deletions)
@@ -0,0 +1,145 @@
+import pytest
+from os.path import join
+
+from spatialdata import read_zarr
+
+from vitessce.data_utils.spatialdata_points_zorder import (
+    # Function for computing codes and sorting
+    sdata_morton_sort_points,
+    # Functions for querying
+    sdata_morton_query_rect_debug,
+    row_ranges_to_row_indices,
+    orig_coord_to_norm_coord,
+)
+
+def is_sorted(l):
+    return all(l[i] <= l[i + 1] for i in range(len(l) - 1))
+
+def get_sdata():
+    data_dir = join("docs", "notebooks", "data")
+    spatialdata_filepath = join(data_dir, "xenium_rep1_io.spatialdata.zarr")
+
+    sdata = read_zarr(spatialdata_filepath)
+    return sdata
+
+@pytest.mark.skip(reason="Temporarily disable")
+def test_zorder_sorting():
+    # TODO: use fixture here
+    sdata = get_sdata()
+
+    sdata_morton_sort_points(sdata, "transcripts")
+
+    # Check that the morton codes are sorted
+    sorted_ddf = sdata.points["transcripts"]
+    morton_sorted = sorted_ddf["morton_code_2d"].compute().values.tolist()
+
+    assert is_sorted(morton_sorted)
+
+
+def test_zorder_query():
+    sdata = get_sdata()
+
+    sdata_morton_sort_points(sdata, "transcripts")
+
+    # Query a rectangle that should return some points
+    orig_rect = [[50.0, 50.0], [100.0, 150.0]] # x0, y0, x1, y1
+    matching_row_ranges, rows_checked = sdata_morton_query_rect_debug(sdata, "transcripts", orig_rect)
+    rect_row_indices = row_ranges_to_row_indices(matching_row_ranges)
+
+    # Cannot use df.iloc on a dask dataframe, so convert it to pandas first
+    ddf = sdata.points["transcripts"]
+    df = ddf.compute()
+    df = df.reset_index(drop=True)
+    estimated_row_indices = df.iloc[rect_row_indices].index.tolist()
+
+    assert df.shape[0] == 42638083
+
+    # Do the same query the "dumb" way, by checking all points
+    in_rect = (
+        (df["x"] >= orig_rect[0][0])
+        & (df["x"] <= orig_rect[1][0])
+        & (df["y"] >= orig_rect[0][1])
+        & (df["y"] <= orig_rect[1][1])
+    )
+    dumb_df_subset = df.loc[in_rect]
+    # Get the row indices of the points in the rectangle
+    # (these are the indices in the original dataframe)
+    exact_row_indices = dumb_df_subset.index.tolist()
+
+    # Check that the estimated rows 100% contain the exact rows.
+    # A.issubset(B) checks that all elements of A are in B ("A is a subset of B").
+    assert set(exact_row_indices).issubset(set(estimated_row_indices))
+    assert len(exact_row_indices) == 614
+    assert len(estimated_row_indices) <= 631
+
+    # Check that the number of rows checked is less than the total number of points
+    assert len(rows_checked) <= 45237
+    assert len(matching_row_ranges) == 24 # Kind of an implementation detail.
+
+    # Do a second check, this time against x_uint/y_uint (the normalized coordinates)
+    # TODO: does this ensure that estimated == exact?
+
+    bounding_box = ddf.attrs["bounding_box"]
+    x_min = bounding_box["x_min"]
+    x_max = bounding_box["x_max"]
+    y_min = bounding_box["y_min"]
+    y_max = bounding_box["y_max"]
+    norm_rect = [
+        orig_coord_to_norm_coord(orig_rect[0], orig_x_min=x_min, orig_x_max=x_max, orig_y_min=y_min, orig_y_max=y_max),
+        orig_coord_to_norm_coord(orig_rect[1], orig_x_min=x_min, orig_x_max=x_max, orig_y_min=y_min, orig_y_max=y_max)
+    ]
+
+    in_rect_norm = (
+        (df["x_uint"] >= norm_rect[0][0])
+        & (df["x_uint"] <= norm_rect[1][0])
+        & (df["y_uint"] >= norm_rect[0][1])
+        & (df["y_uint"] <= norm_rect[1][1])
+    )
+    dumb_df_subset_norm = df.loc[in_rect_norm]
+    # Get the row indices of the points in the rectangle
+    # (these are the indices in the original dataframe)
+    exact_row_indices_norm = dumb_df_subset_norm.index.tolist()
+    assert set(exact_row_indices_norm).issubset(set(estimated_row_indices))
+    assert len(exact_row_indices_norm) == 617
+    assert len(estimated_row_indices) <= 631
+
+
+
+"""
+# ========= Another query ==========
+orig_rect = [[500, 500], [600, 600]] # x0, y0, x1, y1
+
+# Query using z-order
+matching_row_ranges, rows_checked = sdata_morton_query_rect_debug(sdata, "transcripts", orig_rect)
+rect_row_indices = row_ranges_to_row_indices(matching_row_ranges)
+estimated_row_indices = df.iloc[rect_row_indices].index.tolist()
+
+# Query the "dumb" way
+in_rect = (
+    (df["x"] >= orig_rect[0][0])
+    & (df["x"] <= orig_rect[1][0])
+    & (df["y"] >= orig_rect[0][1])
+    & (df["y"] <= orig_rect[1][1])
+)
+dumb_df_subset = df.loc[in_rect]
+exact_row_indices = dumb_df_subset.index.tolist()
+
+diff_rows = set(estimated_row_indices) - set(exact_row_indices)
+# print("Rows in estimated but not exact:", diff_rows)
+print(df.iloc[list(diff_rows)])
+raise NotImplementedError("Debugging")
+
+# Check that the estimated rows contain all of the exact rows.
+assert len(set(exact_row_indices).intersection(set(estimated_row_indices))) == 0
+assert len(exact_row_indices) <= 1123 # TODO: update
+assert len(estimated_row_indices) <= 1163 # TODO: update
+
+"""
+
+
+
+
+
+
+
+
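test_zorder_query asserts containment (issubset) rather than equality because the Morton-interval cover of a rectangle can include rows whose coordinates fall just outside it; the candidate rows still need a final x/y filter. A toy illustration of that check, assuming row_ranges_to_row_indices simply expands the half-open ranges documented by zquery_rows:

# Hypothetical toy values, for illustrating the assertion pattern only.
matching_row_ranges = [(2, 6)]  # half-open row ranges from the Morton query

# Equivalent inline expansion of row_ranges_to_row_indices on this toy input:
rect_row_indices = [i for start, end in matching_row_ranges for i in range(start, end)]
assert rect_row_indices == [2, 3, 4, 5]

estimated_row_indices = rect_row_indices  # candidate rows (superset)
exact_row_indices = [3, 4]                # rows whose x/y actually fall inside the rectangle

# The Z-order query may return extra rows, but must never miss an exact hit.
assert set(exact_row_indices).issubset(set(estimated_row_indices))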
