Skip to content

Commit 60d48cd

Browse files
author
sd109
committed
Handle required Mistral chat format explicitly
1 parent 708dfc5 commit 60d48cd

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

chart/web-app/app.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import requests
22
import warnings
3+
import re
34
import rich
45
import gradio as gr
56
from urllib.parse import urljoin
@@ -17,6 +18,16 @@
1718
backend_health_endpoint = urljoin(backend_url, "/health")
1819
backend_initialised = False
1920

21+
# NOTE(sd109): The Mistral family of models explicitly require a chat
22+
# history of the form user -> ai -> user -> ... and so don't like having
23+
# a SystemPrompt at the beginning. Since these models seem to be the
24+
# best around right now, it makes sense to treat them as special and make
25+
# sure the web app works correctly with them. To do so, we detect when a
26+
# mistral model is specified using this regex and then handle it explicitly
27+
# when contructing the `context` list in the `inference` function below.
28+
MISTRAL_REGEX = re.compile(r"mi(s|x)tral", re.IGNORECASE)
29+
IS_MISTRAL_MODEL = (MISTRAL_REGEX.match(settings.model_name) is not None)
30+
2031
llm = ChatOpenAI(
2132
base_url=urljoin(backend_url, "v1"),
2233
model = settings.model_name,
@@ -57,9 +68,17 @@ def inference(latest_message, history):
5768

5869

5970
try:
60-
context = [SystemMessage(content=settings.model_instruction)]
61-
for human, ai in history:
62-
context.append(HumanMessage(content=human))
71+
# To handle Mistral models we have to add the model instruction to
72+
# the first user message since Mistral requires user -> ai -> user
73+
# chat format and therefore doesn't allow system prompts.
74+
context = []
75+
if not IS_MISTRAL_MODEL:
76+
context.append(SystemMessage(content=settings.model_instruction))
77+
for i, (human, ai) in enumerate(history):
78+
if IS_MISTRAL_MODEL and i == 0:
79+
context.append(HumanMessage(content=f"{settings.model_instruction}\n\n{human}"))
80+
else:
81+
context.append(HumanMessage(content=human))
6382
context.append(AIMessage(content=ai))
6483
context.append(HumanMessage(content=latest_message))
6584

0 commit comments

Comments
 (0)