Skip to content

Commit 2eac98c

Browse files
authored
Remove tokenization from MistralAIClient (#4084)
1 parent 8e675d8 commit 2eac98c

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

src/helm/clients/mistral_client.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
from helm.common.media_object import IMAGE_TYPE, TEXT_TYPE
88
from helm.common.optional_dependencies import handle_module_not_found_error
99
from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput
10-
from helm.tokenizers.tokenizer import Tokenizer
11-
from helm.clients.client import CachingClient, truncate_and_tokenize_response_text
10+
from helm.clients.client import CachingClient
1211

1312
try:
1413
from mistralai import Mistral
@@ -38,16 +37,12 @@ class MistralAIClient(CachingClient):
3837

3938
def __init__(
4039
self,
41-
tokenizer: Tokenizer,
42-
tokenizer_name: str,
4340
cache_config: CacheConfig,
4441
api_key: str,
4542
mistral_model: Optional[str] = None,
4643
):
4744
super().__init__(cache_config=cache_config)
4845
self.api_key: str = api_key
49-
self.tokenizer = tokenizer
50-
self.tokenizer_name = tokenizer_name
5146
self._client = Mistral(api_key=self.api_key)
5247
self.mistral_model = mistral_model
5348

@@ -161,20 +156,35 @@ def do_it() -> Dict[str, Any]:
161156
error: str = f"MistralClient error: {e}"
162157
return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
163158

164-
response_message: Dict[str, Any] = response["choices"][0]["message"]
165-
assert response_message["role"] == "assistant"
166-
response_text: str = response_message["content"]
159+
response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
160+
request_time = response["request_time"]
161+
del response["request_time"]
162+
request_datetime = response["request_datetime"]
163+
del response["request_datetime"]
164+
chat_completion_response = ChatCompletionResponse.model_validate(response)
165+
assert len(chat_completion_response.choices) == 1
166+
choice = chat_completion_response.choices[0]
167+
response_message = choice.message
168+
assert response_message.role == "assistant"
169+
response_text = str(response_message.content)
170+
finish_reason = choice.finish_reason
167171

168172
# The Mistral API doesn't support echo. If `echo_prompt` is true, combine the prompt and completion.
169-
text: str = request.prompt + response_text if request.echo_prompt else response_text
170-
sequence = truncate_and_tokenize_response_text(text, request, self.tokenizer, self.tokenizer_name)
171-
completions.append(sequence)
173+
output_text = request.prompt + response_text if request.echo_prompt else response_text
174+
completions.append(
175+
GeneratedOutput(
176+
text=output_text,
177+
logprob=0.0, # API doesn't support logprobs or tokens
178+
tokens=[], # API doesn't support logprobs or tokens
179+
finish_reason={"reason": finish_reason},
180+
)
181+
)
172182

173183
return RequestResult(
174184
success=True,
175185
cached=cached,
176-
request_time=response["request_time"],
177-
request_datetime=response["request_datetime"],
186+
request_time=request_time,
187+
request_datetime=request_datetime,
178188
completions=completions,
179189
embedding=[],
180190
)

0 commit comments

Comments (0)