Skip to content

Commit 68c5cb9

Browse files
authored
fix: improve token assignment in late chunking (#144)
1 parent 49ddf53 commit 68c5cb9

File tree

1 file changed

+9 -5 lines changed

1 file changed

+9 -5 lines changed

src/raglite/_embed.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from raglite._typing import FloatMatrix, IntVector
1414

1515

16-
def embed_strings_with_late_chunking( # noqa: PLR0915
16+
def embed_strings_with_late_chunking( # noqa: C901,PLR0915
1717
sentences: list[str], *, config: RAGLiteConfig | None = None
1818
) -> FloatMatrix:
1919
"""Embed a document's sentences with late chunking."""
@@ -117,11 +117,15 @@ def _create_segment(
117117
segment_start_index, content_start_index, segment_end_index = segment
118118
segment_sentences = sentences[segment_start_index:segment_end_index]
119119
segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
120-
# Split the segment embeddings into embedding matrices per sentence.
120+
# Split the segment embeddings into embedding matrices per sentence using the largest
121+
# remainder method.
121122
segment_tokens = num_tokens[segment_start_index:segment_end_index]
122-
sentence_size = np.round(
123-
len(segment_embedding) * (segment_tokens / np.sum(segment_tokens))
124-
).astype(np.intp)
123+
sentence_size_frac = len(segment_embedding) * (segment_tokens / np.sum(segment_tokens))
124+
sentence_size = np.floor(sentence_size_frac).astype(np.intp)
125+
remainder = len(segment_embedding) - np.sum(sentence_size)
126+
if remainder > 0: # Assign the remaining tokens to sentences with largest fractional parts.
127+
top_remainders = np.argsort(sentence_size_frac - sentence_size)[-remainder:]
128+
sentence_size[top_remainders] += 1
125129
sentence_matrices = np.split(segment_embedding, np.cumsum(sentence_size)[:-1])
126130
# Compute the segment sentence embeddings by averaging the token embeddings.
127131
content_sentence_embeddings = [

0 commit comments

Comments
 (0)