@@ -138,113 +138,68 @@ TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
138138TEST_F (
139139 TextPrefillerTest,
140140 PrefillCallsPrefillChunkMultipleTimesWhenPromptExceedsMaxLen) {
141- // Create a spy TextPrefiller with max_seq_len = 3
141+ // Create a real TextPrefiller with max_seq_len = 3 and parallel prefill
142142 const int64_t max_seq_len = 3 ;
143- auto prefiller = createMockTextPrefiller (max_seq_len);
143+ auto prefiller = createTextPrefiller (max_seq_len, true , true );
144144
145145 // Create prompt tokens with size > max_seq_len
146146 std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
147147 int64_t start_pos = 0 ;
148148
149- // Set up expectations for prefill_chunk calls
150- {
151- InSequence seq; // Ensure calls happen in the expected order
152-
153- // First chunk: tokens [1, 2, 3]
154- EXPECT_CALL (*prefiller, prefill_chunk (_, _))
155- .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
156- EXPECT_EQ (tokens.size (), 3 );
157- EXPECT_EQ (tokens[0 ], 1 );
158- EXPECT_EQ (tokens[1 ], 2 );
159- EXPECT_EQ (tokens[2 ], 3 );
160- EXPECT_EQ (pos, 0 );
161- return Result<uint64_t >(10 );
162- });
163-
164- // Second chunk: tokens [4, 5, 6]
165- EXPECT_CALL (*prefiller, prefill_chunk (_, _))
166- .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
167- EXPECT_EQ (tokens.size (), 3 );
168- EXPECT_EQ (tokens[0 ], 4 );
169- EXPECT_EQ (tokens[1 ], 5 );
170- EXPECT_EQ (tokens[2 ], 6 );
171- EXPECT_EQ (pos, 3 );
172- return Result<uint64_t >(20 );
173- });
174-
175- // Third chunk: tokens [7, 8]
176- EXPECT_CALL (*prefiller, prefill_chunk (_, _))
177- .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
178- EXPECT_EQ (tokens.size (), 2 );
179- EXPECT_EQ (tokens[0 ], 7 );
180- EXPECT_EQ (tokens[1 ], 8 );
181- EXPECT_EQ (pos, 6 );
182- return Result<uint64_t >(30 );
183- });
184- }
149+ // Track all tokens and positions passed to text_decoder_runner step
150+ struct StepCall {
151+ std::vector<uint64_t > tokens;
152+ int64_t pos;
153+ };
154+ std::vector<StepCall> step_calls;
155+
156+ // Set up expectations for text_decoder_runner step calls
157+ EXPECT_CALL (text_decoder_runner_, step (_, _))
158+ .Times (3 ) // Should be called 3 times for 3 chunks
159+ .WillRepeatedly (
160+ [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
161+ // Extract token values from tensor
162+ std::vector<uint64_t > token_values;
163+ int64_t num_tokens = tokens->size (1 );
164+ auto * token_data = tokens->const_data_ptr <int64_t >();
165+ for (int64_t i = 0 ; i < num_tokens; i++) {
166+ token_values.push_back (static_cast <uint64_t >(token_data[i]));
167+ }
168+ step_calls.push_back ({token_values, pos});
169+ return Result<executorch::aten::Tensor>(tensor);
170+ });
185171
186172 // Call prefill
187173 auto result = prefiller->prefill (prompt_tokens, start_pos);
188174
189175 // Verify the result
190176 EXPECT_EQ (result.error (), Error::Ok);
191- EXPECT_EQ (result.get (), 30 ); // Should return the token from the last chunk
192-
193- // Verify that start_pos has been updated correctly
194- EXPECT_EQ (start_pos, prompt_tokens.size ());
195- }
196-
197- // Test that prefill() handles edge cases correctly
198- TEST_F (TextPrefillerTest, PrefillHandlesEdgeCasesCorrectly) {
199- // Create a spy TextPrefiller with max_seq_len = 1
200- const int64_t max_seq_len = 1 ;
201- auto prefiller = createMockTextPrefiller (max_seq_len);
202-
203- // Create prompt tokens with size > max_seq_len
204- std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 };
205- int64_t start_pos = 5 ; // Non-zero starting position
206-
207- // Set up expectations for prefill_chunk calls
208- {
209- InSequence seq;
210-
211- // First chunk: token [1]
212- EXPECT_CALL (*prefiller, prefill_chunk (_, _))
213- .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
214- EXPECT_EQ (tokens.size (), 1 );
215- EXPECT_EQ (tokens[0 ], 1 );
216- EXPECT_EQ (pos, 5 );
217- return Result<uint64_t >(10 );
218- });
219-
220- // Second chunk: token [2]
221- EXPECT_CALL (*prefiller, prefill_chunk (_, _))
222- .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
223- EXPECT_EQ (tokens.size (), 1 );
224- EXPECT_EQ (tokens[0 ], 2 );
225- EXPECT_EQ (pos, 6 );
226- return Result<uint64_t >(20 );
227- });
228-
229- // Third chunk: token [3]
230- EXPECT_CALL (*prefiller, prefill_chunk (_, _))
231- .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
232- EXPECT_EQ (tokens.size (), 1 );
233- EXPECT_EQ (tokens[0 ], 3 );
234- EXPECT_EQ (pos, 7 );
235- return Result<uint64_t >(30 );
236- });
237- }
238-
239- // Call prefill
240- auto result = prefiller->prefill (prompt_tokens, start_pos);
241177
242- // Verify the result
243- EXPECT_EQ (result.error (), Error::Ok);
244- EXPECT_EQ (result.get (), 30 );
178+ // Verify that step was called 3 times with correct tokens and positions
179+ ASSERT_EQ (step_calls.size (), 3 );
180+
181+ // First chunk: tokens [1, 2, 3] at position 0
182+ EXPECT_EQ (step_calls[0 ].tokens .size (), 3 );
183+ EXPECT_EQ (step_calls[0 ].tokens [0 ], 1 );
184+ EXPECT_EQ (step_calls[0 ].tokens [1 ], 2 );
185+ EXPECT_EQ (step_calls[0 ].tokens [2 ], 3 );
186+ EXPECT_EQ (step_calls[0 ].pos , 0 );
187+
188+ // Second chunk: tokens [4, 5, 6] at position 3
189+ EXPECT_EQ (step_calls[1 ].tokens .size (), 3 );
190+ EXPECT_EQ (step_calls[1 ].tokens [0 ], 4 );
191+ EXPECT_EQ (step_calls[1 ].tokens [1 ], 5 );
192+ EXPECT_EQ (step_calls[1 ].tokens [2 ], 6 );
193+ EXPECT_EQ (step_calls[1 ].pos , 3 );
194+
195+ // Third chunk: tokens [7, 8] at position 6
196+ EXPECT_EQ (step_calls[2 ].tokens .size (), 2 );
197+ EXPECT_EQ (step_calls[2 ].tokens [0 ], 7 );
198+ EXPECT_EQ (step_calls[2 ].tokens [1 ], 8 );
199+ EXPECT_EQ (step_calls[2 ].pos , 6 );
245200
246201 // Verify that start_pos has been updated correctly
247- EXPECT_EQ (start_pos, 8 ); // 5 (initial) + 3 (tokens)
202+ EXPECT_EQ (start_pos, prompt_tokens. size ());
248203}
249204
250205// Test that prefill() handles errors from prefill_chunk correctly
@@ -305,4 +260,119 @@ TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) {
305260 // Verify that start_pos has been updated correctly
306261 EXPECT_EQ (start_pos, prompt_tokens.size ());
307262}
263+ // Test that prefill_chunk updates start_pos correctly with parallel prefill
264+ TEST_F (TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlyParallel) {
265+ // Create a TextPrefiller with parallel prefill enabled
266+ auto prefiller = createTextPrefiller (10 , true , true );
267+
268+ // Set up expectations for the text decoder runner
269+ int64_t captured_pos = -1 ;
270+ EXPECT_CALL (text_decoder_runner_, step (_, _))
271+ .WillOnce ([&](executorch::extension::TensorPtr& tokens, int64_t pos) {
272+ captured_pos = pos;
273+ // Verify tokens shape is [1, num_tokens]
274+ EXPECT_EQ (tokens->dim (), 2 );
275+ EXPECT_EQ (tokens->size (0 ), 1 );
276+ EXPECT_EQ (tokens->size (1 ), 3 );
277+ return Result<executorch::aten::Tensor>(tensor);
278+ });
279+
280+ // Create prompt tokens
281+ std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 };
282+ int64_t start_pos = 5 ; // Non-zero starting position
283+
284+ // Call prefill_chunk directly
285+ auto result = prefiller->prefill_chunk (prompt_tokens, start_pos);
286+
287+ // Verify the result
288+ EXPECT_EQ (result.error (), Error::Ok);
289+
290+ // Verify that step was called with the original start_pos
291+ EXPECT_EQ (captured_pos, 5 );
292+
293+ // Verify that start_pos has been updated by the number of tokens
294+ // This is the key test: start_pos should be updated exactly once
295+ EXPECT_EQ (start_pos, 8 ); // 5 + 3 tokens
296+ }
297+
298+ // Test that prefill_chunk updates start_pos correctly with sequential prefill
299+ TEST_F (TextPrefillerTest, PrefillChunkUpdatesStartPosCorrectlySequential) {
300+ // Create a TextPrefiller with sequential prefill (parallel disabled)
301+ auto prefiller = createTextPrefiller (10 , true , false );
302+
303+ // Track all positions passed to step
304+ std::vector<int64_t > captured_positions;
305+ EXPECT_CALL (text_decoder_runner_, step (_, _))
306+ .Times (3 )
307+ .WillRepeatedly (
308+ [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
309+ captured_positions.push_back (pos);
310+ // Verify tokens shape is [1, 1] for sequential prefill
311+ EXPECT_EQ (tokens->dim (), 2 );
312+ EXPECT_EQ (tokens->size (0 ), 1 );
313+ EXPECT_EQ (tokens->size (1 ), 1 );
314+ return Result<executorch::aten::Tensor>(tensor);
315+ });
316+
317+ // Create prompt tokens
318+ std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 };
319+ int64_t start_pos = 10 ; // Non-zero starting position
320+
321+ // Call prefill_chunk directly
322+ auto result = prefiller->prefill_chunk (prompt_tokens, start_pos);
323+
324+ // Verify the result
325+ EXPECT_EQ (result.error (), Error::Ok);
326+
327+ // Verify that step was called with incrementing positions
328+ ASSERT_EQ (captured_positions.size (), 3 );
329+ EXPECT_EQ (captured_positions[0 ], 10 ); // First token at initial start_pos
330+ EXPECT_EQ (captured_positions[1 ], 11 ); // Second token at start_pos + 1
331+ EXPECT_EQ (captured_positions[2 ], 12 ); // Third token at start_pos + 2
332+
333+ // Verify that start_pos has been updated by the number of tokens
334+ // This is the key test: start_pos should be updated exactly once per token
335+ EXPECT_EQ (start_pos, 13 ); // 10 + 3 tokens
336+ }
337+
338+ // Test that prefill with chunking updates start_pos correctly across chunks.
339+ // This test would have caught the bug where start_pos was being updated twice.
340+ TEST_F (
341+ TextPrefillerTest,
342+ PrefillWithChunkingUpdatesStartPosCorrectlyAcrossChunks) {
343+ // Create a TextPrefiller with max_seq_len = 3 and parallel prefill
344+ auto prefiller = createTextPrefiller (3 , true , true );
345+
346+ // Track all positions passed to step
347+ std::vector<int64_t > captured_positions;
348+ EXPECT_CALL (text_decoder_runner_, step (_, _))
349+ .Times (3 ) // Should be called 3 times: [1,2,3], [4,5,6], [7,8]
350+ .WillRepeatedly (
351+ [&](executorch::extension::TensorPtr& tokens, int64_t pos) {
352+ captured_positions.push_back (pos);
353+ return Result<executorch::aten::Tensor>(tensor);
354+ });
355+
356+ // Create prompt tokens that exceed max_seq_len
357+ std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
358+ int64_t start_pos = 100 ; // Non-zero starting position
359+
360+ // Call prefill (which will chunk internally)
361+ auto result = prefiller->prefill (prompt_tokens, start_pos);
362+
363+ // Verify the result
364+ EXPECT_EQ (result.error (), Error::Ok);
365+
366+ // Verify that step was called with correct positions for each chunk
367+ // If start_pos were updated twice (the bug), these would be wrong
368+ ASSERT_EQ (captured_positions.size (), 3 );
369+ EXPECT_EQ (captured_positions[0 ], 100 ); // Chunk 1: tokens [1,2,3]
370+ EXPECT_EQ (captured_positions[1 ], 103 ); // Chunk 2: tokens [4,5,6]
371+ EXPECT_EQ (captured_positions[2 ], 106 ); // Chunk 3: tokens [7,8]
372+
373+ // Verify that final start_pos is correct
374+ // This is the key test for the bug: start_pos should be exactly
375+ // initial_pos + num_tokens, not double-incremented
376+ EXPECT_EQ (start_pos, 108 ); // 100 + 8 tokens
377+ }
308378} // namespace
0 commit comments