Batch encode: lock-free work queue with dynamic window sizing #2029
sebpop wants to merge 1 commit into huggingface:main
Conversation
/benchmark
Replace `inputs.into_maybe_par_iter().map(...).collect()` in `encode_batch`, `encode_batch_char_offsets` and `encode_batch_fast` with a small helper `TokenizerImpl::run_batch` that: - Dispatches to a plain `inputs.into_iter().map(...).collect()` serial loop when parallelism is disabled or only one thread is available, avoiding all rayon involvement for single-threaded callers. - At higher thread counts, uses a lock-free atomic counter (`BatchWorkQueue`) inside one `rayon::scope` with one `s.spawn` per worker. Each worker claims windows of item indices via a single `AtomicUsize::fetch_add`, takes inputs from per-slot `UnsafeCell<Option<EncodeInput>>`, and writes results into per-slot `UnsafeCell<Option<Result<Encoding>>>`. No shared mutable state outside the counter; no final `collect()` on a parallel iterator. The lock-free design is motivated by aarch64 LSE atomic cost: every mutex / condvar the previous parallel-iterator path took hit was a CAS / LDADD emitted by libpthread, and those dominate small-work parallel loops at high thread counts on arm64. Replacing that with a single `fetch_add` per window removes the mutex-backed per-item signaling entirely. ## Cache-line / loop-tiling rationale Shared-memory parallel loops are bottlenecked by the cache coherence protocol when two cores alternate writes to the same cache line: the line "ping-pongs" between their private L1d caches, each transfer costing dozens of cycles. To avoid that, every line should be filled by one producer core, drained (or no longer needed), and only then touched by a different core. This is the cache-aware equivalent of loop tiling / blocking: group the iteration space into chunks whose data footprint is a whole number of cache lines, and give each chunk to a single core. The work queue enforces this three ways: 1. The counter itself lives on its own 64-byte cache line (`#[repr(C, align(64))]` on `AlignedCounter`). 
   A worker's `fetch_add` does not evict any neighbouring data, and reads of the counter do not pull input or result payloads into the core's L1d.

2. Each window is a contiguous run of `window_size` indices, so every worker owns a run of adjacent slots for the duration of one window. With `MAX_WINDOW_SIZE = 8`, a window covers roughly `8 * sizeof(slot)` bytes -- for `Option<EncodeInput>` (~48 B) that is ~6 cache lines; for `Option<Result<Encoding>>` (which spans multiple cache lines per slot) it is even more. Within one window, a worker writes several whole cache lines before any other worker comes near them.

3. Each slot has its own `UnsafeCell` (`Vec<UnsafeCell<Option<T>>>`). `UnsafeCell<T>` is `#[repr(transparent)]`, so the heap layout is byte-identical to a plain `Vec<Option<T>>` (no padding, same alignment, same contiguous packing -- zero runtime overhead vs. the "unsafe fast" version that reborrows the whole `Vec`). What the per-slot cell buys is that `self.0[i].get()` returns `*mut Option<T>` pointing straight at slot `i`, without ever materialising a `&mut Vec<Option<T>>` that would alias the enclosing container (which is UB when two threads touch any distinct indices concurrently).

At window boundaries a single cache line can be shared between two successive windows when the slot size does not divide 64 bytes. That is a sequential handoff (window N finishes its writes; window N+1 then reads/writes), not a concurrent ping-pong, so the cost is at most one coherence transfer per window-pair.

## Window sizing

`window_size = ceil(total / (num_threads * WINDOWS_PER_THREAD))`, clamped to `[1, MAX_WINDOW_SIZE]`.

- `WINDOWS_PER_THREAD = 4` keeps several windows per thread so a slow worker on its last item does not stall the whole batch.
- `MAX_WINDOW_SIZE = 8` caps per-claim atomic latency and keeps the per-window memory footprint small enough to fit in L1d.
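The sizing rule above can be sketched as follows (a minimal standalone sketch; the constants match the description, the free-function name `window_size` is assumed):

```rust
const WINDOWS_PER_THREAD: usize = 4;
const MAX_WINDOW_SIZE: usize = 8;

/// ceil(total / (num_threads * WINDOWS_PER_THREAD)), clamped to [1, MAX_WINDOW_SIZE].
fn window_size(total: usize, num_threads: usize) -> usize {
    let target_windows = num_threads * WINDOWS_PER_THREAD;
    total.div_ceil(target_windows).clamp(1, MAX_WINDOW_SIZE)
}
```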
Examples: 100 items / 16 threads yields `window_size = 2` (50 windows); 10 000 items / 16 threads yields `window_size = 8` (1250 windows).

## Tests

7 new unit tests in `utils::batch::tests` cover window sizing, the `TakeVec` and `ResultVec` round-trip, and `test_parallel_distribution` (4 threads concurrently claiming and writing 100 slots, exercising the `Sync` bounds under real contention). `cargo test --lib --features http`: 208 passed, 0 failed.

## Perf evidence

On Vera (88-core Olympus, 176 logical), `bpe_benchmark`/`bpe-encode/BPE GPT2 encode batch` at 88T, `perf record -g --call-graph fp -F 4999`.

LSE atomic instructions (the direct motivation for the lock-free counter):

| instruction | before | after | change |
|---|---|---|---|
| `__aarch64_cas4_acq` | 3.57% | 0.61% | -5.9x |
| `__aarch64_ldadd8_acq_rel` | 1.05% | 0.08% | -13x |
| `__aarch64_swp4_rel` | 0.21% | 0.05% | |
| `__aarch64_ldadd8_relax` | 0.17% | 0.24% | |
| `__aarch64_swp4_acq` | 0.12% | 0.00% | |
| `__aarch64_swp8_acq_rel` | 0.06% | 0.00% | |
| `__aarch64_cas8_acq_rel` | 0.01% | 0.01% | |
| total LSE | ~5.2% | ~1.0% | -4.2x |

Rayon / crossbeam-epoch:

| symbol | before | after | change |
|---|---|---|---|
| `rayon_core::sleep::Sleep::wake_specific_thread` | 0.57% | 0.06% | -10x |
| `crossbeam_epoch::internal::Global::try_advance` | 25.93% | 28.38% | |
| `crossbeam_epoch::default::with_handle` | 21.41% | 23.12% | |
| `rayon_core::registry::WorkerThread::wait_until_cold` | 8.40% | 10.72% | |
| `rayon::iter::plumbing::bridge_producer_consumer::helper` | 0.20% | 0.24% | |

`bridge_producer_consumer::helper` was not a hotspot on this workload before the change (0.20%) and does not move; the observable rayon-side change is `Sleep::wake_specific_thread` dropping ~10x because `rayon::scope` issues one wake per worker per batch call rather than streaming wakes per parallel-iterator split. The three remaining rayon/crossbeam ceiling symbols (`try_advance` + `with_handle` + `wait_until_cold` = ~62% of cycles) stay similar in percentage because total cycles decrease; absolute wall-clock per benchmark iteration drops 35 ms (295 ms -> 260 ms at 88T). Removing that rayon ceiling is a separate change.
Throughput on Vera, `bpe-encode/BPE GPT2 encode batch` (data/big.txt, `encode_batch` through the full post-processor):

| threads | before | after | change |
|---|---|---|---|
| 1T | 3.98 MiB/s | 4.46 MiB/s | +12% |
| 88T | 20.97 MiB/s | 23.76 MiB/s | +13% |
| 176T | 18.83 MiB/s | 21.58 MiB/s | +15% |
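For concreteness, the claiming scheme described above can be sketched in a self-contained way (using `std::thread::scope` in place of `rayon::scope`, and a hypothetical `SlotVec` with a doubling "work" step standing in for the real `TakeVec`/`ResultVec` and the actual encoding):

```rust
use std::cell::UnsafeCell;
use std::sync::atomic::{AtomicUsize, Ordering};

// The shared counter sits alone on a 64-byte line so the fetch_add
// never contends with input or result payloads.
#[repr(C, align(64))]
struct AlignedCounter(AtomicUsize);

// One UnsafeCell per slot: workers get a *mut to exactly one slot,
// never a &mut to the whole Vec (name hypothetical).
struct SlotVec<T>(Vec<UnsafeCell<Option<T>>>);

// Safety: callers guarantee distinct threads only touch disjoint indices.
unsafe impl<T: Send> Sync for SlotVec<T> {}

impl<T> SlotVec<T> {
    fn new(n: usize) -> Self {
        SlotVec((0..n).map(|_| UnsafeCell::new(None)).collect())
    }
    /// Safety: no other thread may access slot `i` concurrently.
    unsafe fn write(&self, i: usize, v: T) {
        *self.0[i].get() = Some(v);
    }
    fn into_inner(self) -> Vec<Option<T>> {
        self.0.into_iter().map(UnsafeCell::into_inner).collect()
    }
}

fn run_batch(total: usize, num_threads: usize, window_size: usize, out: &SlotVec<usize>) {
    let counter = AlignedCounter(AtomicUsize::new(0));
    std::thread::scope(|s| {
        for _ in 0..num_threads {
            s.spawn(|| loop {
                // One atomic op claims a whole window of adjacent indices.
                let start = counter.0.fetch_add(window_size, Ordering::Relaxed);
                if start >= total {
                    break;
                }
                for i in start..(start + window_size).min(total) {
                    // Stand-in for encode(); this window belongs to this thread.
                    unsafe { out.write(i, i * 2) };
                }
            });
        }
    });
}
```

The thread join at the end of the scope is what publishes the slot writes back to the caller, so `Relaxed` suffices for the claim itself: window ownership is derived purely from the value `fetch_add` returns.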
Small semantic note, since @codex flagged it: I chose not to add a shared stop flag because it would add hot-path polling to every batch to improve a rare cold path, and because deterministic-error-at-lowest-index is arguably the better user contract. This tradeoff follows Brendan Gregg's Utilization Saturation and Errors (USE) Method: optimize the common Utilization/Saturation path while keeping the rare Error path correct and bounded.
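The deterministic-error-at-lowest-index contract falls out of draining the per-slot results in index order; a minimal sketch (helper name and error type hypothetical):

```rust
/// Drain per-slot results into one Result, surfacing the error at the
/// lowest index (no stop flag; workers finish their windows regardless).
fn collect_results<T>(slots: Vec<Option<Result<T, String>>>) -> Result<Vec<T>, String> {
    slots
        .into_iter()
        .map(|s| s.expect("every slot is filled by exactly one worker"))
        .collect() // Result's FromIterator short-circuits at the first Err in index order
}
```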
@sebpop Great timing. I was also looking at #1900 myself. If there is a preferred benchmark command etc. for this test, I am happy to use that, @sebpop. Otherwise I'll drop a compact summary table and the raw data here once I have the machines spun up and the runs finished.
Thanks @vyalamar, very welcome. See the recipe below; tested on Vera and Grace (Nvidia arm64), and it should reproduce on x86_64 (Intel and AMD) and on arm64 Graviton with no fiddling.

Build the benchmark:

Run:

You can also run the benchmark under https://github.com/aws/aperf and check flamegraphs and other PMU metrics. I expect AMD and Intel CPUs to behave similarly to arm64. Let me know if you hit any roadblocks.