|
1 | 1 | import requests
|
2 | 2 | import warnings
|
| 3 | +import re |
3 | 4 | import rich
|
4 | 5 | import gradio as gr
|
5 | 6 | from urllib.parse import urljoin
|
|
17 | 18 | backend_health_endpoint = urljoin(backend_url, "/health")
|
18 | 19 | backend_initialised = False
|
19 | 20 |
|
# NOTE(sd109): Mistral-family models require a strictly alternating
# user -> ai -> user -> ... chat history and therefore reject a leading
# SystemPrompt. Since these models are currently among the best
# available, the web app treats them as a special case: the regex below
# detects when a Mistral model is configured, and the `inference`
# function further down handles that flag explicitly when constructing
# its `context` list.
MISTRAL_REGEX = re.compile(r".*mi(s|x)tral.*", re.IGNORECASE)
IS_MISTRAL_MODEL = MISTRAL_REGEX.match(settings.model_name) is not None
if IS_MISTRAL_MODEL:
    print("Detected Mistral model - will alter LangChain conversation format appropriately.")
20 | 33 | llm = ChatOpenAI(
|
21 | 34 | base_url=urljoin(backend_url, "v1"),
|
22 | 35 | model = settings.model_name,
|
@@ -57,9 +70,17 @@ def inference(latest_message, history):
|
57 | 70 |
|
58 | 71 |
|
59 | 72 | try:
|
60 |
| - context = [SystemMessage(content=settings.model_instruction)] |
61 |
| - for human, ai in history: |
62 |
| - context.append(HumanMessage(content=human)) |
| 73 | + # To handle Mistral models we have to add the model instruction to |
| 74 | + # the first user message since Mistral requires user -> ai -> user |
| 75 | + # chat format and therefore doesn't allow system prompts. |
| 76 | + context = [] |
| 77 | + if not IS_MISTRAL_MODEL: |
| 78 | + context.append(SystemMessage(content=settings.model_instruction)) |
| 79 | + for i, (human, ai) in enumerate(history): |
| 80 | + if IS_MISTRAL_MODEL and i == 0: |
| 81 | + context.append(HumanMessage(content=f"{settings.model_instruction}\n\n{human}")) |
| 82 | + else: |
| 83 | + context.append(HumanMessage(content=human)) |
63 | 84 | context.append(AIMessage(content=ai))
|
64 | 85 | context.append(HumanMessage(content=latest_message))
|
65 | 86 |
|
|
0 commit comments