|
7 | 7 | from helm.common.media_object import IMAGE_TYPE, TEXT_TYPE |
8 | 8 | from helm.common.optional_dependencies import handle_module_not_found_error |
9 | 9 | from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput |
10 | | -from helm.tokenizers.tokenizer import Tokenizer |
11 | | -from helm.clients.client import CachingClient, truncate_and_tokenize_response_text |
| 10 | +from helm.clients.client import CachingClient |
12 | 11 |
|
13 | 12 | try: |
14 | 13 | from mistralai import Mistral |
@@ -38,16 +37,12 @@ class MistralAIClient(CachingClient): |
38 | 37 |
|
39 | 38 | def __init__( |
40 | 39 | self, |
41 | | - tokenizer: Tokenizer, |
42 | | - tokenizer_name: str, |
43 | 40 | cache_config: CacheConfig, |
44 | 41 | api_key: str, |
45 | 42 | mistral_model: Optional[str] = None, |
46 | 43 | ): |
47 | 44 | super().__init__(cache_config=cache_config) |
48 | 45 | self.api_key: str = api_key |
49 | | - self.tokenizer = tokenizer |
50 | | - self.tokenizer_name = tokenizer_name |
51 | 46 | self._client = Mistral(api_key=self.api_key) |
52 | 47 | self.mistral_model = mistral_model |
53 | 48 |
|
@@ -161,20 +156,35 @@ def do_it() -> Dict[str, Any]: |
161 | 156 | error: str = f"MistralClient error: {e}" |
162 | 157 | return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[]) |
163 | 158 |
|
164 | | - response_message: Dict[str, Any] = response["choices"][0]["message"] |
165 | | - assert response_message["role"] == "assistant" |
166 | | - response_text: str = response_message["content"] |
| 159 | + response, cached = self.cache.get(cache_key, wrap_request_time(do_it)) |
| 160 | + request_time = response["request_time"] |
| 161 | + del response["request_time"] |
| 162 | + request_datetime = response["request_datetime"] |
| 163 | + del response["request_datetime"] |
| 164 | + chat_completion_response = ChatCompletionResponse.model_validate(response) |
| 165 | + assert len(chat_completion_response.choices) == 1 |
| 166 | + choice = chat_completion_response.choices[0] |
| 167 | + response_message = choice.message |
| 168 | + assert response_message.role == "assistant" |
| 169 | + response_text = str(response_message.content) |
| 170 | + finish_reason = choice.finish_reason |
167 | 171 |
|
168 | 172 | # The Mistral API doesn't support echo. If `echo_prompt` is true, combine the prompt and completion. |
169 | | - text: str = request.prompt + response_text if request.echo_prompt else response_text |
170 | | - sequence = truncate_and_tokenize_response_text(text, request, self.tokenizer, self.tokenizer_name) |
171 | | - completions.append(sequence) |
| 173 | + output_text = request.prompt + response_text if request.echo_prompt else response_text |
| 174 | + completions.append( |
| 175 | + GeneratedOutput( |
| 176 | + text=output_text, |
| 177 | + logprob=0.0, # API doesn't support logprobs or tokens |
| 178 | + tokens=[], # API doesn't support logprobs or tokens |
| 179 | + finish_reason={"reason": finish_reason}, |
| 180 | + ) |
| 181 | + ) |
172 | 182 |
|
173 | 183 | return RequestResult( |
174 | 184 | success=True, |
175 | 185 | cached=cached, |
176 | | - request_time=response["request_time"], |
177 | | - request_datetime=response["request_datetime"], |
| 186 | + request_time=request_time, |
| 187 | + request_datetime=request_datetime, |
178 | 188 | completions=completions, |
179 | 189 | embedding=[], |
180 | 190 | ) |