|
13 | 13 | from raglite._typing import FloatMatrix, IntVector |
14 | 14 |
|
15 | 15 |
|
16 | | -def embed_strings_with_late_chunking( # noqa: PLR0915 |
| 16 | +def embed_strings_with_late_chunking( # noqa: C901,PLR0915 |
17 | 17 | sentences: list[str], *, config: RAGLiteConfig | None = None |
18 | 18 | ) -> FloatMatrix: |
19 | 19 | """Embed a document's sentences with late chunking.""" |
@@ -117,11 +117,15 @@ def _create_segment( |
117 | 117 | segment_start_index, content_start_index, segment_end_index = segment |
118 | 118 | segment_sentences = sentences[segment_start_index:segment_end_index] |
119 | 119 | segment_embedding = np.asarray(embedder.embed("".join(segment_sentences))) |
120 | | - # Split the segment embeddings into embedding matrices per sentence. |
| 120 | + # Split the segment embeddings into embedding matrices per sentence using the largest |
| 121 | + # remainder method. |
121 | 122 | segment_tokens = num_tokens[segment_start_index:segment_end_index] |
122 | | - sentence_size = np.round( |
123 | | - len(segment_embedding) * (segment_tokens / np.sum(segment_tokens)) |
124 | | - ).astype(np.intp) |
| 123 | + sentence_size_frac = len(segment_embedding) * (segment_tokens / np.sum(segment_tokens)) |
| 124 | + sentence_size = np.floor(sentence_size_frac).astype(np.intp) |
| 125 | + remainder = len(segment_embedding) - np.sum(sentence_size) |
| 126 | + if remainder > 0: # Assign the remaining tokens to sentences with largest fractional parts. |
| 127 | + top_remainders = np.argsort(sentence_size_frac - sentence_size)[-remainder:] |
| 128 | + sentence_size[top_remainders] += 1 |
125 | 129 | sentence_matrices = np.split(segment_embedding, np.cumsum(sentence_size)[:-1]) |
126 | 130 | # Compute the segment sentence embeddings by averaging the token embeddings. |
127 | 131 | content_sentence_embeddings = [ |
|
0 commit comments