 
 Run `pytest tests/models/test_mistral.py`.
 """
-import pickle
+import json
 import uuid
-from typing import Any, Dict, List
+from dataclasses import asdict
+from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
 from mistral_common.protocol.instruct.messages import ImageURLChunk
 
 from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
 from vllm.multimodal import MultiModalDataBuiltins
+from vllm.sequence import Logprob, SampleLogprobs
 
 from .utils import check_logprobs_close
 
@@ -81,13 +83,33 @@ def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
 LIMIT_MM_PER_PROMPT = dict(image=4)
 
 MAX_MODEL_LEN = [8192, 65536]
-FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
-FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
+FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json"
+FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json"
 
+OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]
 
-def load_logprobs(filename: str) -> Any:
-    with open(filename, 'rb') as f:
-        return pickle.load(f)
+
+# For the test author to store golden output in JSON
+def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None:
+    json_data = [(tokens, text,
+                  [{k: asdict(v)
+                    for k, v in token_logprobs.items()}
+                   for token_logprobs in (logprobs or [])])
+                 for tokens, text, logprobs in outputs]
+
+    with open(filename, "w") as f:
+        json.dump(json_data, f)
+
+
+def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs:
+    with open(filename, "rb") as f:
+        json_data = json.load(f)
+
+    return [(tokens, text,
+             [{int(k): Logprob(**v)
+               for k, v in token_logprobs.items()}
+              for token_logprobs in logprobs])
+            for tokens, text, logprobs in json_data]
 
 
 @pytest.mark.skip(
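
The two helpers above hinge on a pair of details: `dataclasses.asdict` flattens each `Logprob` value into a plain dict that `json.dump` can write, and JSON object keys are always strings, which is why `load_outputs_w_logprobs` restores the token ids with `int(k)` and rebuilds the values via `Logprob(**v)`. A minimal round-trip sketch of that idea, using a stand-in dataclass rather than `vllm.sequence.Logprob` (whose exact fields are not shown in this diff):

import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class FakeLogprob:
    # Stand-in for vllm.sequence.Logprob; the field names here are illustrative.
    logprob: float
    decoded_token: Optional[str] = None


token_logprobs = {1234: FakeLogprob(-0.25, "hello")}

# Serialize: dataclass -> dict; the int key gets coerced to a JSON string.
blob = json.dumps({k: asdict(v) for k, v in token_logprobs.items()})

# Deserialize: keys come back as strings, hence the int(k) cast.
restored = {int(k): FakeLogprob(**v) for k, v in json.loads(blob).items()}
assert restored == token_logprobs
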
@@ -103,7 +125,7 @@ def test_chat(
     model: str,
     dtype: str,
 ) -> None:
-    EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
+    EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
     with vllm_runner(
             model,
             dtype=dtype,
@@ -120,10 +142,10 @@ def test_chat(
             outputs.extend(output)
 
     logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
-    check_logprobs_close(outputs_0_lst=logprobs,
-                         outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
-                         name_0="output",
-                         name_1="h100_ref")
+    check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
+                         outputs_1_lst=logprobs,
+                         name_0="h100_ref",
+                         name_1="output")
 
 
 @pytest.mark.skip(
@@ -133,7 +155,7 @@ def test_chat(
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
-    EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
+    EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
     args = EngineArgs(
         model=model,
         tokenizer_mode="mistral",
@@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
             break
 
     logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
-    check_logprobs_close(outputs_0_lst=logprobs,
-                         outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
-                         name_0="output",
-                         name_1="h100_ref")
+    check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
+                         outputs_1_lst=logprobs,
+                         name_0="h100_ref",
+                         name_1="output")
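
Per the "For the test author to store golden output in JSON" comment, `_dump_outputs_w_logprobs` is not exercised by the tests themselves; it exists so the golden fixtures can be regenerated by hand. A hypothetical one-off snippet (not part of this diff), placed right after the logprobs are computed in `test_chat`, might look like:

# Run once on the reference hardware (the fixtures are named "h100_ref"),
# then commit the refreshed JSON file.
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
_dump_outputs_w_logprobs(logprobs, FIXTURE_LOGPROBS_CHAT)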