Skip to content

Commit 9b6ea3c

Browse files
authored
Combine tallying of buckets for both passes of float32 sorting (#132)
1 parent c1a335e commit 9b6ea3c

File tree

2 files changed: 48 additions, 43 deletions

rust/spark-internal-rs/src/sort.rs

Lines changed: 24 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,9 @@ pub struct Sort32Buffers {
7676
/// output indices
7777
pub ordering: Vec<u32>,
7878
/// bucket counts / offsets (length == RADIX_BASE)
79-
pub buckets16: Vec<u32>,
79+
pub buckets16lo: Vec<u32>,
80+
/// bucket counts / offsets (length == RADIX_BASE)
81+
pub buckets16hi: Vec<u32>,
8082
/// scratch space for indices
8183
pub scratch: Vec<u32>,
8284
}
@@ -93,8 +95,11 @@ impl Sort32Buffers {
9395
if self.scratch.len() < max_splats {
9496
self.scratch.resize(max_splats, 0);
9597
}
96-
if self.buckets16.len() < RADIX_BASE {
97-
self.buckets16.resize(RADIX_BASE, 0);
98+
if self.buckets16lo.len() < RADIX_BASE {
99+
self.buckets16lo.resize(RADIX_BASE, 0);
100+
}
101+
if self.buckets16hi.len() < RADIX_BASE {
102+
self.buckets16hi.resize(RADIX_BASE, 0);
98103
}
99104
}
100105
}
@@ -109,20 +114,24 @@ pub fn sort32_internal(
109114
// make sure our buffers can hold `max_splats`
110115
buffers.ensure_size(max_splats);
111116

112-
let Sort32Buffers { readback, ordering, buckets16, scratch } = buffers;
117+
let Sort32Buffers { readback, ordering, buckets16lo, buckets16hi, scratch } = buffers;
113118
let keys = &readback[..num_splats];
114119

115-
// ——— Pass #1: bucket by inv(low 16 bits) ———
116-
buckets16.fill(0);
120+
// tally low and high buckets
121+
buckets16lo.fill(0);
122+
buckets16hi.fill(0);
117123
for &key in keys.iter() {
118124
if key < DEPTH_INFINITY_F32 {
119125
let inv = !key;
120-
buckets16[(inv & 0xFFFF) as usize] += 1;
126+
buckets16lo[(inv & 0xFFFF) as usize] += 1;
127+
buckets16hi[(inv >> 16) as usize] += 1;
121128
}
122129
}
130+
131+
// ——— Pass #1: bucket by inv(low 16 bits) ———
123132
// exclusive prefix‑sum → starting offsets
124133
let mut total: u32 = 0;
125-
for slot in buckets16.iter_mut() {
134+
for slot in buckets16lo.iter_mut() {
126135
let cnt = *slot;
127136
*slot = total;
128137
total = total.wrapping_add(cnt);
@@ -134,21 +143,15 @@ pub fn sort32_internal(
134143
if key < DEPTH_INFINITY_F32 {
135144
let inv = !key;
136145
let lo = (inv & 0xFFFF) as usize;
137-
scratch[buckets16[lo] as usize] = i as u32;
138-
buckets16[lo] += 1;
146+
scratch[buckets16lo[lo] as usize] = i as u32;
147+
buckets16lo[lo] += 1;
139148
}
140149
}
141150

142151
// ——— Pass #2: bucket by inv(high 16 bits) ———
143-
buckets16.fill(0);
144-
for &idx in scratch.iter().take(active_splats as usize) {
145-
let key = keys[idx as usize];
146-
let inv = !key;
147-
buckets16[(inv >> 16) as usize] += 1;
148-
}
149152
// exclusive prefix‑sum again
150153
let mut sum: u32 = 0;
151-
for slot in buckets16.iter_mut() {
154+
for slot in buckets16hi.iter_mut() {
152155
let cnt = *slot;
153156
*slot = sum;
154157
sum = sum.wrapping_add(cnt);
@@ -158,16 +161,16 @@ pub fn sort32_internal(
158161
let key = keys[idx as usize];
159162
let inv = !key;
160163
let hi = (inv >> 16) as usize;
161-
ordering[buckets16[hi] as usize] = idx;
162-
buckets16[hi] += 1;
164+
ordering[buckets16hi[hi] as usize] = idx;
165+
buckets16hi[hi] += 1;
163166
}
164167

165168
// sanity‑check: last bucket should have consumed all entries
166-
if buckets16[RADIX_BASE - 1] != active_splats {
169+
if buckets16hi[RADIX_BASE - 1] != active_splats {
167170
return Err(anyhow!(
168171
"Expected {} active splats but got {}",
169172
active_splats,
170-
buckets16[RADIX_BASE - 1]
173+
buckets16hi[RADIX_BASE - 1]
171174
));
172175
}
173176

src/worker.ts

Lines changed: 24 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -527,7 +527,8 @@ function sortDoubleSplats({
527527
}
528528

529529
const DEPTH_INFINITY_F32 = 0x7f800000;
530-
let bucket16: Uint32Array | null = null;
530+
let bucket16lo: Uint32Array | null = null;
531+
let bucket16hi: Uint32Array | null = null;
531532
let scratchSplats: Uint32Array | null = null;
532533

533534
// two-pass radix sort (base 65536) of 32-bit keys in readback,
@@ -546,29 +547,36 @@ function sort32Splats({
546547
const BASE = 1 << 16; // 65536
547548

548549
// allocate once
549-
if (!bucket16) {
550-
bucket16 = new Uint32Array(BASE);
550+
if (!bucket16lo) {
551+
bucket16lo = new Uint32Array(BASE);
552+
}
553+
if (!bucket16hi) {
554+
bucket16hi = new Uint32Array(BASE);
551555
}
552556
if (!scratchSplats || scratchSplats.length < maxSplats) {
553557
scratchSplats = new Uint32Array(maxSplats);
554558
}
555559

556-
//
557-
// ——— Pass #1: bucket by inv(lo 16 bits) ———
558-
//
559-
bucket16.fill(0);
560+
// tally low and high buckets
561+
bucket16lo.fill(0);
562+
bucket16hi.fill(0);
560563
for (let i = 0; i < numSplats; ++i) {
561564
const key = readback[i];
562565
if (key < DEPTH_INFINITY_F32) {
563566
const inv = ~key >>> 0;
564-
bucket16[inv & 0xffff] += 1;
567+
bucket16lo[inv & 0xffff] += 1;
568+
bucket16hi[inv >>> 16] += 1;
565569
}
566570
}
571+
572+
//
573+
// ——— Pass #1: bucket by inv(lo 16 bits) ———
574+
//
567575
// exclusive prefix‑sum → starting offsets
568576
let total = 0;
569577
for (let b = 0; b < BASE; ++b) {
570-
const c = bucket16[b];
571-
bucket16[b] = total;
578+
const c = bucket16lo[b];
579+
bucket16lo[b] = total;
572580
total += c;
573581
}
574582
const activeSplats = total;
@@ -578,38 +586,32 @@ function sort32Splats({
578586
const key = readback[i];
579587
if (key < DEPTH_INFINITY_F32) {
580588
const inv = ~key >>> 0;
581-
scratchSplats[bucket16[inv & 0xffff]++] = i;
589+
scratchSplats[bucket16lo[inv & 0xffff]++] = i;
582590
}
583591
}
584592

585593
//
586594
// ——— Pass #2: bucket by inv(hi 16 bits) ———
587595
//
588-
bucket16.fill(0);
589-
for (let k = 0; k < activeSplats; ++k) {
590-
const idx = scratchSplats[k];
591-
const inv = ~readback[idx] >>> 0;
592-
bucket16[inv >>> 16] += 1;
593-
}
594596
// exclusive prefix‑sum again
595597
let sum = 0;
596598
for (let b = 0; b < BASE; ++b) {
597-
const c = bucket16[b];
598-
bucket16[b] = sum;
599+
const c = bucket16hi[b];
600+
bucket16hi[b] = sum;
599601
sum += c;
600602
}
601603

602604
// scatter into final ordering by high bits of inv
603605
for (let k = 0; k < activeSplats; ++k) {
604606
const idx = scratchSplats[k];
605607
const inv = ~readback[idx] >>> 0;
606-
ordering[bucket16[inv >>> 16]++] = idx;
608+
ordering[bucket16hi[inv >>> 16]++] = idx;
607609
}
608610

609611
// sanity‑check: the last bucket should have eaten all entries
610-
if (bucket16[BASE - 1] !== activeSplats) {
612+
if (bucket16hi[BASE - 1] !== activeSplats) {
611613
throw new Error(
612-
`Expected ${activeSplats} active splats but got ${bucket16[BASE - 1]}`,
614+
`Expected ${activeSplats} active splats but got ${bucket16hi[BASE - 1]}`,
613615
);
614616
}
615617

0 commit comments

Comments
 (0)