Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 81 additions & 17 deletions tokenizers/src/tokenizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::{
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};

use crate::utils::batch::{BatchWorkQueue, ResultVec, TakeVec};
use crate::utils::iter::ResultShunt;
use crate::utils::parallelism::*;
use crate::utils::progress::{ProgressBar, ProgressStyle};
Expand Down Expand Up @@ -1300,7 +1301,11 @@ where
PP: PostProcessor + Send + Sync,
D: Decoder + Send + Sync,
{
/// Encode all the sentences in parallel, using multiple threads
/// Encode all the sentences in parallel, using multiple threads.
///
/// Uses a lock-free work queue with cache-line-sized windows instead of
/// rayon's `bridge_producer_consumer`, eliminating its synchronization
/// overhead at higher thread counts.
pub fn encode_batch<'s, E>(
&self,
inputs: Vec<E>,
Expand All @@ -1309,13 +1314,10 @@ where
where
E: Into<EncodeInput<'s>> + Send,
{
let mut encodings = inputs
.into_maybe_par_iter()
.map(|input| self.encode(input, add_special_tokens))
.collect::<Result<Vec<Encoding>>>()?;
let mut encodings =
self.run_batch(inputs, |this, input| this.encode(input, add_special_tokens))?;

if let Some(params) = &self.padding {
// We do the padding here to make sure we handle the batch padding
pad_encodings(&mut encodings, params)?;
}

Expand All @@ -1332,20 +1334,22 @@ where
where
E: Into<EncodeInput<'s>> + Send,
{
let mut encodings = inputs
.into_maybe_par_iter()
.map(|input| self.encode_char_offsets(input, add_special_tokens))
.collect::<Result<Vec<Encoding>>>()?;
let mut encodings = self.run_batch(inputs, |this, input| {
this.encode_char_offsets(input, add_special_tokens)
})?;

if let Some(params) = &self.padding {
// We do the padding here to make sure we handle the batch padding
pad_encodings(&mut encodings, params)?;
}

Ok(encodings)
}

/// Encode all the sentences in parallel, using multiple threads
/// Encode all the sentences in parallel, using multiple threads.
///
/// Uses a lock-free work queue with cache-line-sized windows instead of
/// rayon's `bridge_producer_consumer`, eliminating its synchronization
/// overhead at higher thread counts.
pub fn encode_batch_fast<'s, E>(
&self,
inputs: Vec<E>,
Expand All @@ -1354,19 +1358,79 @@ where
where
E: Into<EncodeInput<'s>> + Send,
{
let mut encodings = inputs
.into_maybe_par_iter()
.map(|input| self.encode_fast(input, add_special_tokens))
.collect::<Result<Vec<Encoding>>>()?;
let mut encodings = self.run_batch(inputs, |this, input| {
this.encode_fast(input, add_special_tokens)
})?;

if let Some(params) = &self.padding {
// We do the padding here to make sure we handle the batch padding
pad_encodings(&mut encodings, params)?;
}

Ok(encodings)
}

/// Shared implementation for all batch encode variants
/// (`encode_batch`, `encode_batch_char_offsets`, `encode_batch_fast`).
///
/// Distributes work items across threads using a lock-free atomic counter:
/// each worker repeatedly claims a window of item indices from
/// `BatchWorkQueue`, encodes those items, and writes each result directly
/// into a pre-allocated slot of `ResultVec`.
///
/// Uses `rayon::scope` to run on the existing rayon thread pool, avoiding
/// the cost of creating/destroying OS threads on every call.
///
/// # Errors
///
/// Returns the first `Err` produced by `encode_fn` (in index order, via the
/// final `collect`). NOTE(review): unlike the previous
/// `into_maybe_par_iter().collect::<Result<_>>()` implementation, an error
/// does not short-circuit the batch — every item is still encoded before the
/// error is surfaced.
fn run_batch<'s, E, F>(&self, inputs: Vec<E>, encode_fn: F) -> Result<Vec<Encoding>>
where
    E: Into<EncodeInput<'s>> + Send,
    F: Fn(&Self, EncodeInput<'s>) -> Result<Encoding> + Sync,
{
    let n = inputs.len();
    if n == 0 {
        // Nothing to do; also keeps the later `new(n)` calls away from n == 0.
        return Ok(vec![]);
    }

    // Respect the global TOKENIZERS_PARALLELISM-style switch; never spawn
    // more workers than there are items.
    let parallelism = get_parallelism();
    let num_threads = if parallelism {
        current_num_threads().min(n)
    } else {
        1
    };

    // Sequential fast path: no queue/slot machinery needed, and errors
    // short-circuit immediately here.
    if num_threads <= 1 {
        return inputs
            .into_iter()
            .map(|input| encode_fn(self, input.into()))
            .collect::<Result<Vec<Encoding>>>();
    }

    // Lock-free batch distribution: atomic counter hands out
    // dynamically-sized windows of item indices to worker threads.
    //
    // NOTE(review): soundness relies on contracts of these helper types that
    // are not visible here — `claim_window` must hand out disjoint, in-bounds
    // windows so each index is `take`n and `set` exactly once; confirm in
    // `crate::utils::batch`.
    let inputs = TakeVec::new(
        inputs
            .into_iter()
            .map(|e| e.into())
            .collect::<Vec<EncodeInput<'s>>>(),
    );
    let results: ResultVec<Result<Encoding>> = ResultVec::new(n);
    let queue = BatchWorkQueue::new(n, num_threads);

    // `rayon::scope` lets the spawned closures borrow `inputs`, `results`,
    // `queue`, `encode_fn`, and `self` from this stack frame, and joins all
    // workers before returning. A panic in any worker propagates out of the
    // scope, so `results` is only consumed below once every slot was written.
    rayon::scope(|s| {
        for _ in 0..num_threads {
            s.spawn(|_| {
                // Workers pull windows until the queue is exhausted; faster
                // threads naturally claim more windows (dynamic load balance).
                while let Some((start, end)) = queue.claim_window() {
                    for i in start..end {
                        // Move the input out of its slot (each index is
                        // claimed by exactly one worker) and store the result
                        // at the same index, preserving input order.
                        let input = inputs.take(i);
                        results.set(i, encode_fn(self, input));
                    }
                }
            });
        }
    });

    // Reassemble in index order; `collect` over `Result` yields the first
    // error, if any, otherwise the full ordered Vec<Encoding>.
    results
        .into_vec()
        .into_iter()
        .collect::<Result<Vec<Encoding>>>()
}

/// Decode all sentences in parallel
pub fn decode_batch(
&self,
Expand Down
Loading
Loading