@@ -337,13 +337,19 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
337
337
"target_attn_1" : mock .MagicMock (),
338
338
"target_attn_2" : mock .MagicMock ()
339
339
}
340
+ target_indx_layers : dict [str , mock .MagicMock ] = {}
340
341
# Draft model has one extra attention layer compared to target model
341
342
all_attn_layers = {
342
343
** target_attn_layers , "draft_extra_attn" : mock .MagicMock ()
343
344
}
344
345
346
+ all_indx_layers : dict [str , mock .MagicMock ] = {}
347
+
345
348
# Make mock_get_layers return different values for each call
346
- mock_get_layers .side_effect = [target_attn_layers , all_attn_layers ]
349
+ mock_get_layers .side_effect = [
350
+ target_attn_layers , target_indx_layers , all_attn_layers ,
351
+ all_indx_layers
352
+ ]
347
353
348
354
# Setup mock for pp group to return the appropriate value for world size
349
355
mock_pp_group = mock .MagicMock ()
@@ -658,6 +664,9 @@ def create_deterministic_logits(token_ids, k: int):
658
664
# Mock runner for attention metadata building.
659
665
proposer .runner = mock .MagicMock ()
660
666
proposer .runner .attn_groups .append ([mock .MagicMock ()])
667
+ proposer .runner .attn_groups [0 ][0 ].metadata_builders = [
668
+ attn_metadata_builder
669
+ ]
661
670
proposer .runner .attn_groups [0 ][0 ].get_metadata_builder .return_value = \
662
671
attn_metadata_builder
663
672
proposer ._get_attention_metadata_builder = mock .MagicMock (
0 commit comments