Skip to content

Commit 2721f21

Browse files
committed
Add zorder functions to data_utils
1 parent 09eb63f commit 2721f21

File tree

2 files changed

+367
-0
lines changed

2 files changed

+367
-0
lines changed

src/vitessce/data_utils/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,9 @@
1717
from .multivec import (
1818
adata_to_multivec_zarr,
1919
)
20+
from .spatialdata_points_zorder import (
21+
# Function for computing codes and sorting
22+
sdata_morton_sort_points,
23+
# Functions for querying
24+
sdata_morton_query_rect,
25+
)
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
from typing import Tuple, List, Optional
2+
3+
from bisect import bisect_left, bisect_right
4+
import pandas as pd
5+
import numpy as np
6+
7+
8+
from spatialdata import get_element_annotators
9+
import dask.dataframe as dd
10+
11+
### Perform rectangle range queries
12+
13+
# Given a query rectangle region (x_0, y_0) to (x_1, y_1):
14+
# Steps:
15+
# 0. Knowledge of MORTON_CODE_NUM_BITS (and by extension MORTON_CODE_VALUE_MIN/MORTON_CODE_VALUE_MAX).
16+
# 1. `O(2)`: Get the first and last rows of the dataframe, to identify (x_min, y_min) and (x_max, y_max) respectively. These values are needed for normalization.
17+
# - Update: this does not work. The first point is only guaranteed to have y_min, but its X coordinate may be any value. The last row seems to not be guaranteed to have either x_max or y_max.
18+
# 2. `O(?)`: Get the morton code value intervals covering the query rectangle region (in morton code space).
19+
# 3. `O(num_intervals * log(N))`: For each morton code (morton_start, morton_end) interval, perform binary search to identify the corresponding table row ranges.
20+
# 4. Concatenate the table rows in the resulting table row ranges.
21+
# 5. If the rectangle covering is loose (as opposed to exact), filter the resulting rows to only those in the query rectangle region.
22+
23+
MORTON_CODE_NUM_BITS = 32 # Resulting morton codes will be stored as uint32.
24+
MORTON_CODE_VALUE_MIN = 0
25+
MORTON_CODE_VALUE_MAX = 2**(MORTON_CODE_NUM_BITS/2) - 1
26+
27+
# --------------------------
28+
# Functions for computing Morton codes for SpatialData points (2D).
29+
# --------------------------
30+
31+
def norm_series_to_uint(series, v_min, v_max):
32+
"""
33+
Scale numeric Series (int or float) to integer grid [0, 2^bits-1], handling NaNs.
34+
"""
35+
# Cast to float64
36+
series_f64 = series.astype("float64")
37+
# Normalize the array values to be between 0.0 and 1.0
38+
norm_series_f64 = (series_f64 - v_min) / (v_max - v_min)
39+
# Clip to ensure no values are outside 0/1 range
40+
clipped_norm_series_f64 = np.clip(norm_series_f64, 0.0, 1.0)
41+
# Multiply by the morton code max-value to scale from [0,1] to [0,65535]
42+
out = (clipped_norm_series_f64 * MORTON_CODE_VALUE_MAX).astype(np.uint32)
43+
# Set NaNs to 0.
44+
out = out.fillna(0)
45+
return out
46+
47+
def norm_ddf_to_uint(ddf):
48+
[x_min, x_max, y_min, y_max] = [ddf["x"].min().compute(), ddf["x"].max().compute(), ddf["y"].min().compute(), ddf["y"].max().compute()]
49+
ddf["x_uint"] = norm_series_to_uint(ddf["x"], x_min, x_max)
50+
ddf["y_uint"] = norm_series_to_uint(ddf["y"], y_min, y_max)
51+
52+
# Insert the bounding box as metadata for the sdata.points[element] Points element dataframe.
53+
# TODO: does anything special need to be done to ensure this is saved to disk?
54+
ddf.attrs["bounding_box"] = {
55+
"x_min": float(x_min),
56+
"x_max": float(x_max),
57+
"y_min": float(y_min),
58+
"y_max": float(y_max),
59+
}
60+
61+
return ddf
62+
63+
def _part1by1_16(x):
64+
"""
65+
Spread each 16-bit value into 32 bits by inserting zeros between bits.
66+
Input: uint32 array (values must fit in 16 bits)
67+
Output: uint32 array (bit-spread)
68+
"""
69+
70+
assert x.dtype.name == 'uint32'
71+
72+
# Mask away any bits above 16 (just in case input wasn't clean).
73+
x = x & np.uint32(0x0000FFFF)
74+
75+
# First spread: shift left by 8 bits, OR with original, then mask.
76+
# After this, groups of 8 bits are separated by 8 zeros.
77+
x = (x | np.left_shift(x, 8)) & np.uint32(0x00FF00FF)
78+
79+
# Spread further: now groups of 4 bits separated by 4 zeros.
80+
x = (x | np.left_shift(x, 4)) & np.uint32(0x0F0F0F0F)
81+
82+
# Spread further: groups of 2 bits separated by 2 zeros.
83+
x = (x | np.left_shift(x, 2)) & np.uint32(0x33333333)
84+
85+
# Final spread: single bits separated by a zero bit.
86+
# Now each original bit is in every other position (positions 0,2,4,...).
87+
x = (x | np.left_shift(x, 1)) & np.uint32(0x55555555)
88+
89+
return x
90+
91+
def _part1by1_32(x):
92+
"""
93+
Spread each 32-bit value into 64 bits by inserting zeros between bits.
94+
Input: uint64 array (values must fit in 32 bits)
95+
Output: uint64 array (bit-spread)
96+
"""
97+
98+
assert x.dtype.name == 'uint64'
99+
100+
# Mask away any bits above 32 (safety).
101+
x = x.astype(np.uint64) & np.uint64(0x00000000FFFFFFFF)
102+
103+
# First spread: separate into 16-bit chunks spaced out.
104+
x = (x | np.left_shift(x, 16)) & np.uint64(0x0000FFFF0000FFFF)
105+
106+
# Spread further: each 8-bit chunk separated.
107+
x = (x | np.left_shift(x, 8)) & np.uint64(0x00FF00FF00FF00FF)
108+
109+
# Spread further: each 4-bit nibble separated.
110+
x = (x | np.left_shift(x, 4)) & np.uint64(0x0F0F0F0F0F0F0F0F)
111+
112+
# Spread further: 2-bit groups separated.
113+
x = (x | np.left_shift(x, 2)) & np.uint64(0x3333333333333333)
114+
115+
# Final spread: single bits separated by zeros.
116+
# Now each original bit occupies every other position (0,2,4,...).
117+
x = (x | np.left_shift(x, 1)) & np.uint64(0x5555555555555555)
118+
119+
return x
120+
121+
def morton_interleave(ddf):
122+
"""
123+
Vectorized Morton interleave for integer arrays xi, yi
124+
already scaled to [0, 2^bits - 1].
125+
Returns Morton codes as uint32 (if bits<=16) or uint64 (if bits<=32).
126+
"""
127+
128+
xi = ddf["x_uint"]
129+
yi = ddf["y_uint"]
130+
131+
# Spread x and y bits into even (x) and odd (y) positions.
132+
xs = _part1by1_16(xi)
133+
ys = _part1by1_16(yi)
134+
135+
# Interleave: shift y bits left by 1 so they go into odd positions,
136+
# then OR with x bits in even positions.
137+
code = np.left_shift(ys.astype(np.uint64), 1) | xs.astype(np.uint64)
138+
139+
# Fits in 32 bits since we only had 16+16 input bits.
140+
return code.astype(np.uint32)
141+
142+
def sdata_morton_sort_points(sdata, element):
143+
ddf = sdata.points[element]
144+
145+
# Compute morton codes
146+
ddf = norm_ddf_to_uint(ddf)
147+
ddf["morton_code_2d"] = morton_interleave(ddf)
148+
149+
if "z" in ddf.columns:
150+
num_unique_z = ddf["z"].unique().shape[0].compute()
151+
if num_unique_z < 100:
152+
# Heuristic for interpreting the 3D data as 2.5D
153+
# Reference: https://github.com/scverse/spatialdata/issues/961
154+
sorted_ddf = ddf.sort_values(by=["z", "morton_code_2d"], ascending=True)
155+
else:
156+
# TODO: include z as a dimension in the morton code in the 3D case?
157+
158+
# For now, just return the data sorted by 2D code.
159+
sorted_ddf = ddf.sort_values(by="morton_code_2d", ascending=True)
160+
else:
161+
sorted_ddf = ddf.sort_values(by="morton_code_2d", ascending=True)
162+
sdata.points[element] = sorted_ddf
163+
164+
annotating_tables = get_element_annotators(sdata, element)
165+
166+
# TODO: Sort any annotating table(s) as well.
167+
168+
return sdata
169+
170+
def sdata_morton_query_rect(sdata, element, orig_rect):
171+
#orig_rect = [[50, 50], [100, 150]] # [[x0, y0], [x1, y1]]
172+
#norm_rect = [
173+
# orig_coord_to_norm_coord(orig_rect[0], orig_x_min=0, orig_x_max=100, orig_y_min=0, orig_y_max=200),
174+
# orig_coord_to_norm_coord(orig_rect[1], orig_x_min=0, orig_x_max=100, orig_y_min=0, orig_y_max=200)
175+
#]
176+
177+
sorted_ddf = sdata.points[element]
178+
179+
# TODO: fail if no morton_code_2d column
180+
# TODO: fail if not sorted as expected
181+
# TODO: fail if no bounding box metadata
182+
183+
bounding_box = sorted_ddf.attrs["bounding_box"]
184+
x_min = bounding_box["x_min"]
185+
x_max = bounding_box["x_max"]
186+
y_min = bounding_box["y_min"]
187+
y_max = bounding_box["y_max"]
188+
189+
190+
norm_rect = [
191+
orig_coord_to_norm_coord(orig_rect[0], orig_x_min=x_min, orig_x_max=x_max, orig_y_min=y_min, orig_y_max=y_max),
192+
orig_coord_to_norm_coord(orig_rect[1], orig_x_min=x_min, orig_x_max=x_max, orig_y_min=y_min, orig_y_max=y_max)
193+
]
194+
195+
# Get a list of morton code intervals that cover this rectangle region
196+
# [ (morton_start, morton_end), ... ]
197+
morton_intervals = zcover_rectangle(
198+
rx0 = norm_rect[0][0], ry0 = norm_rect[0][1],
199+
rx1 = norm_rect[1][0], ry1 = norm_rect[1][1],
200+
bits = 16,
201+
stop_level = None,
202+
merge = True,
203+
)
204+
# Get morton code column as a list of integers
205+
morton_sorted = sorted_ddf["morton_code_2d"].compute().values.tolist()
206+
207+
# Get a list of row ranges that match the morton intervals.
208+
# (This uses binary searches internally to find the matching row indices).
209+
# [ (row_start, row_end), ... ]
210+
matching_row_ranges = zquery_rows(morton_sorted, morton_intervals, merge = True)
211+
return matching_row_ranges
212+
213+
# --------------------------
214+
# Functions for rectangle queries.
215+
# --------------------------
216+
217+
# Convert a coordinate from the normalized [0, 65535] space to the original space.
218+
def norm_coord_to_orig_coord(norm_coord, orig_x_min, orig_x_max, orig_y_min, orig_y_max):
219+
[norm_x, norm_y] = norm_coord
220+
orig_x_range = orig_x_max - orig_x_min
221+
orig_y_range = orig_y_max - orig_y_min
222+
return [
223+
(orig_x_min + (norm_x / MORTON_CODE_VALUE_MAX) * orig_x_range),
224+
(orig_y_min + (norm_y / MORTON_CODE_VALUE_MAX) * orig_y_range),
225+
]
226+
227+
# Convert a coordinate from the original space to the [0, 65535] normalized space.
228+
def orig_coord_to_norm_coord(orig_coord, orig_x_min, orig_x_max, orig_y_min, orig_y_max):
229+
[orig_x, orig_y] = orig_coord
230+
orig_x_range = orig_x_max - orig_x_min
231+
orig_y_range = orig_y_max - orig_y_min
232+
return [
233+
((orig_x - orig_x_min) / orig_x_range) * MORTON_CODE_VALUE_MAX,
234+
((orig_y - orig_y_min) / orig_y_range) * MORTON_CODE_VALUE_MAX,
235+
]
236+
237+
# --------------------------
238+
# Quadtree / Z-interval helpers
239+
# --------------------------
240+
def intersects(ax0:int, ay0:int, ax1:int, ay1:int,
241+
bx0:int, by0:int, bx1:int, by1:int) -> bool:
242+
"""Axis-aligned box intersection (inclusive integer bounds)."""
243+
return not (ax1 < bx0 or bx1 < ax0 or ay1 < by0 or by1 < ay0)
244+
245+
def contained(ix0:int, iy0:int, ix1:int, iy1:int,
246+
ox0:int, oy0:int, ox1:int, oy1:int) -> bool:
247+
"""Is inner box entirely inside outer box? (inclusive integer bounds)"""
248+
return (ox0 <= ix0 <= ix1 <= ox1) and (oy0 <= iy0 <= iy1 <= oy1)
249+
250+
def point_inside(x:int, y:int, rx0:int, ry0:int, rx1:int, ry1:int) -> bool:
251+
return (rx0 <= x <= rx1) and (ry0 <= y <= ry1)
252+
253+
def cell_range(prefix: int, level: int, bits: int) -> Tuple[int, int]:
254+
"""
255+
All Morton codes in a quadtree cell share the same prefix (2*level bits).
256+
Fill the remaining lower bits with 0s (lo) or 1s (hi).
257+
"""
258+
shift = 2 * (bits - level)
259+
lo = prefix << shift
260+
hi = ((prefix + 1) << shift) - 1
261+
return lo, hi
262+
263+
def merge_adjacent(intervals: List[Tuple[int,int]]) -> List[Tuple[int,int]]:
264+
"""Merge overlapping or directly adjacent intervals."""
265+
if not intervals:
266+
return []
267+
intervals.sort(key=lambda t: t[0])
268+
merged = [intervals[0]]
269+
for lo, hi in intervals[1:]:
270+
mlo, mhi = merged[-1]
271+
if lo <= mhi + 1:
272+
merged[-1] = (mlo, max(mhi, hi))
273+
else:
274+
merged.append((lo, hi))
275+
return merged
276+
277+
# --------------------------
278+
# Rectangle -> list of Morton intervals
279+
# --------------------------
280+
281+
def zcover_rectangle(rx0:int, ry0:int, rx1:int, ry1:int, bits:int, stop_level: Optional[int] = None, merge: bool = True) -> List[Tuple[int,int]]:
282+
"""
283+
Compute a (near-)minimal set of Morton code ranges covering the rectangle
284+
[rx0..rx1] x [ry0..ry1] on an integer grid [0..2^bits-1]^2.
285+
286+
- If stop_level is None: exact cover (descend to exact containment).
287+
- If stop_level is set (0..bits): stop descending at that level, adding
288+
partially-overlapping cells as whole ranges (superset cover).
289+
"""
290+
if not (0 <= rx0 <= rx1 <= (1<<bits)-1 and 0 <= ry0 <= ry1 <= (1<<bits)-1):
291+
raise ValueError("Rectangle out of bounds for given bits.")
292+
293+
intervals: List[Tuple[int,int]] = []
294+
295+
# stack entries: (prefix, level, xmin, ymin, xmax, ymax)
296+
stack = [(0, 0, 0, 0, (1<<bits)-1, (1<<bits)-1)]
297+
298+
while stack:
299+
prefix, level, xmin, ymin, xmax, ymax = stack.pop()
300+
301+
if not intersects(xmin, ymin, xmax, ymax, rx0, ry0, rx1, ry1):
302+
continue
303+
304+
# If we stop at this level for a loose cover, add full cell range.
305+
if stop_level is not None and level == stop_level:
306+
intervals.append(cell_range(prefix, level, bits))
307+
continue
308+
309+
# Fully contained: add full cell range.
310+
if contained(xmin, ymin, xmax, ymax, rx0, ry0, rx1, ry1):
311+
intervals.append(cell_range(prefix, level, bits))
312+
continue
313+
314+
# Leaf cell: single lattice point (only happens when level==bits)
315+
if level == bits:
316+
if point_inside(xmin, ymin, rx0, ry0, rx1, ry1):
317+
intervals.append(cell_range(prefix, level, bits))
318+
continue
319+
320+
# Otherwise, split into 4 children (Morton order: 00,01,10,11)
321+
midx = (xmin + xmax) // 2
322+
midy = (ymin + ymax) // 2
323+
324+
# q0: (x<=midx, y<=midy) -> child code 0b00
325+
stack.append(((prefix << 2) | 0,
326+
level+1,
327+
xmin, ymin, midx, midy))
328+
# q1: (x>midx, y<=midy) -> child code 0b01
329+
stack.append(((prefix << 2) | 1,
330+
level+1,
331+
midx+1, ymin, xmax, midy))
332+
# q2: (x<=midx, y>midy) -> child code 0b10
333+
stack.append(((prefix << 2) | 2,
334+
level+1,
335+
xmin, midy+1, midx, ymax))
336+
# q3: (x>midx, y>midy) -> child code 0b11
337+
stack.append(((prefix << 2) | 3,
338+
level+1,
339+
midx+1, midy+1, xmax, ymax))
340+
341+
return merge_adjacent(intervals) if merge else intervals
342+
343+
344+
# --------------------------
345+
# Morton intervals -> row ranges in a Morton-sorted column
346+
# --------------------------
347+
348+
def zquery_rows(morton_sorted: List[int], intervals: List[Tuple[int,int]], merge: bool = True) -> List[Tuple[int,int]]:
349+
"""
350+
For each Z-interval [zlo, zhi], binary-search in the sorted Morton column
351+
and return row index half-open ranges [i, j) to scan.
352+
"""
353+
ranges: List[Tuple[int,int]] = []
354+
for zlo, zhi in intervals:
355+
i = bisect_left(morton_sorted, zlo)
356+
j = bisect_right(morton_sorted, zhi)
357+
if i < j:
358+
ranges.append((i, j))
359+
return merge_adjacent(ranges) if merge else ranges
360+
361+

0 commit comments

Comments
 (0)