@@ -324,6 +324,7 @@ def test_prepare_inputs_padded():
 @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
 @pytest.mark.parametrize("pp_size", [1, 2])
 @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
+@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
 @mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
 @mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
 @mock.patch("vllm.v1.spec_decode.eagle.get_model")
@@ -335,6 +336,7 @@ def test_load_model(
     attn_backend,
     pp_size,
     use_distinct_embed_tokens,
+    use_distinct_lm_head,
     monkeypatch,
 ):
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
@@ -350,12 +352,13 @@ def test_load_model(

     # Setup draft model mock
     mock_model = mock.MagicMock()
+    mock_model.model = mock.MagicMock()
+    mock_model.has_own_embed_tokens = use_distinct_embed_tokens
     if use_distinct_embed_tokens:
-        # Some models can have a different hidden size than the target model,
-        # so we test that their embed_tokens doesn't get overwritten
-        mock_model.model.embed_tokens.weight.shape = (131072, 2048)
-    else:
-        mock_model.model.embed_tokens.weight.shape = (131072, 4096)
+        mock_model.model.embed_tokens = mock.MagicMock()
+    mock_model.has_own_lm_head = use_distinct_lm_head
+    if use_distinct_lm_head:
+        mock_model.lm_head = mock.MagicMock()

     mock_get_model.return_value = mock_model

@@ -391,15 +394,13 @@ class _TargetModelStub(LlamaForCausalLM):

     target_model = mock.create_autospec(_TargetModelStub, instance=True)
     target_model.model = mock.MagicMock()
-    target_model.model.embed_tokens.weight.shape = (131072, 4096)
+    target_model.lm_head = mock.MagicMock()
+    target_model.model.embed_tokens = mock.MagicMock()

     from vllm.model_executor.models import SupportsMultiModal

     assert not isinstance(target_model, SupportsMultiModal)

-    if method == "eagle":
-        target_model.lm_head = mock.MagicMock()
-
     # Create proposer using the helper function
     proposer = _create_proposer(method, num_speculative_tokens=8)

@@ -409,18 +410,18 @@ class _TargetModelStub(LlamaForCausalLM):
     # Verify common interactions
     mock_get_model.assert_called_once()

-    # Verify that EAGLE models gain the lm head from the target model
-    if method == "eagle":
-        assert proposer.model.lm_head == target_model.lm_head
+    # Verify that the lm head is set correctly
+    if use_distinct_lm_head:
+        assert proposer.model.lm_head is not target_model.lm_head
+    else:
+        assert proposer.model.lm_head is target_model.lm_head

     # Verify that the embed tokens are set correctly
     # If pp_size is > 1, the embed tokens should be distinct
     if pp_size > 1 or use_distinct_embed_tokens:
-        assert proposer.model.model.embed_tokens != target_model.model.embed_tokens
+        assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens
     else:
-        # When pp_size is 1 and the draft and target models have
-        # embed_tokens of the same shape, they should be shared.
-        assert proposer.model.model.embed_tokens is target_model.model.embed_tokens
+        assert proposer.model.model.embed_tokens is target_model.model.embed_tokens


 @pytest.mark.parametrize("method", ["eagle", "eagle3"])
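For context, the assertions in this hunk check weight-sharing behavior along these lines: the draft (speculative) model reuses the target's `lm_head` and `embed_tokens` only when it does not bring its own, and embeddings are additionally only shared when pipeline parallelism is not splitting the model. The sketch below is an assumption-laden illustration of that rule using the `has_own_lm_head`/`has_own_embed_tokens` flags the mocks set; it is not the actual `EagleProposer.load_model` implementation in `vllm.v1.spec_decode.eagle`.

```python
# Minimal sketch (assumption): the sharing rules the test asserts, written
# against hypothetical draft/target objects. Not vLLM's real code.
def share_target_weights(draft, target, pp_world_size: int) -> None:
    # lm_head is shared unless the draft checkpoint provides its own.
    if not getattr(draft, "has_own_lm_head", False):
        draft.lm_head = target.lm_head

    # embed_tokens is shared only when the draft has no embeddings of its
    # own and pipeline parallelism is disabled (pp_world_size == 1).
    if not getattr(draft, "has_own_embed_tokens", False) and pp_world_size == 1:
        draft.model.embed_tokens = target.model.embed_tokens
```

With `pp_world_size == 1` and both flags False, the draft ends up holding the very same submodule objects as the target, which is what the `is` / `is not` identity assertions above verify.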