Commit bee138e: Retry without system prompt on HTTP 400 error
Author: sd109
1 parent: 1216494

1 file changed: +34 −87 lines

chart/web-app/app.py: 34 additions & 87 deletions
@@ -1,5 +1,4 @@
-import requests
-import re
+import sys
 import logging
 import gradio as gr
 from urllib.parse import urljoin
@@ -16,35 +15,25 @@
 logger.info("Starting app")

 settings = AppSettings.load("./settings.yml")
+if len(sys.argv) > 1:
+    settings.hf_model_name = sys.argv[1]
 logger.info("App settings: %s", settings)

 backend_url = str(settings.backend_url)
 backend_health_endpoint = urljoin(backend_url, "/health")
 BACKEND_INITIALISED = False

-# # NOTE(sd109): The Mistral family of models explicitly require a chat
-# # history of the form user -> ai -> user -> ... and so don't like having
-# # a SystemPrompt at the beginning. Since these models seem to be the
-# # best around right now, it makes sense to treat them as special and make
-# # sure the web app works correctly with them. To do so, we detect when a
-# # mistral model is specified using this regex and then handle it explicitly
-# # when constructing the `context` list in the `inference` function below.
-# MISTRAL_REGEX = re.compile(r".*mi(s|x)tral.*", re.IGNORECASE)
-# IS_MISTRAL_MODEL = MISTRAL_REGEX.match(settings.model_name) is not None
-# if IS_MISTRAL_MODEL:
-#     print(
-#         "Detected Mistral model - will alter LangChain conversation format appropriately."
-#     )
-
 # Some models disallow 'system' roles in their conversation history by raising errors in their chat prompt template, e.g. see
 # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/cf47bb3e18fe41a5351bc36eef76e9c900847c89/tokenizer_config.json#L42
 # Detecting this ahead of time is difficult, so for now we use a global variable which stores whether the API has
 # responded with an HTTP 400 error; if so, we retry the request with the system role replaced by a user message.
 INCLUDE_SYSTEM_PROMPT = True
+class PossibleSystemPromptException(Exception):
+    pass

 llm = ChatOpenAI(
     base_url=urljoin(backend_url, "v1"),
-    model=settings.model_name,
+    model=settings.hf_model_name,
     openai_api_key="required-but-not-used",
     temperature=settings.llm_temperature,
     max_tokens=settings.llm_max_tokens,
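The comment block in this hunk points at the Mistral-7B-Instruct chat template, which raises an error whenever the conversation does not strictly alternate user/assistant turns, so a request containing a 'system' message comes back from the backend as an HTTP 400. A minimal sketch of how that rejection can be reproduced outside the web app, assuming the transformers library is available (none of this code is part of the commit):

from transformers import AutoTokenizer

# Hedged sketch: trigger the template-level rejection that the serving backend
# ultimately surfaces to the web app as an HTTP 400 response.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
try:
    tokenizer.apply_chat_template(messages, tokenize=False)
except Exception as err:
    # The template's raise_exception hook rejects the non-alternating roles.
    print(f"Chat template rejected the conversation: {err}")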
@@ -58,66 +47,6 @@


 def inference(latest_message, history):
-    # Check backend health and warn the user on error
-    # try:
-    #     response = requests.get(backend_health_endpoint, timeout=5)
-    #     response_code = response.status_code
-    #     if response_code == 200:
-    #         global backend_initialised
-    #         if not backend_initialised:
-    #             # Record the fact that backend was up at one point so we know that
-    #             # any future errors are not related to slow model initialisation
-    #             backend_initialised = True
-    #     elif response_code >= 400 and response_code < 500:
-    #         logging.warn(f"Received HTTP {response_code} response from backend. Full response: {response.text}")
-    #     else:
-    #         # If the server's running (i.e. we get a response) but it's not an HTTP 200
-    #         # we just hope Kubernetes reconciles things for us eventually..
-    #         raise gr.Error("Backend unhealthy - please try again later")
-    # except Exception as err:
-    #     warnings.warn(f"Error while checking backend health: {err}")
-    #     if backend_initialised:
-    #         # If backend was previously reachable then something unexpected has gone wrong
-    #         raise gr.Error("Backend unreachable")
-    #     else:
-    #         # In this case backend is probably still busy downloading model weights
-    #         raise gr.Error("Backend not ready yet - please try again later")
-
-    # try:
-    #     # To handle Mistral models we have to add the model instruction to
-    #     # the first user message since Mistral requires user -> ai -> user
-    #     # chat format and therefore doesn't allow system prompts.
-    #     context = []
-    #     if not IS_MISTRAL_MODEL:
-    #         context.append(SystemMessage(content=settings.model_instruction))
-    #     for i, (human, ai) in enumerate(history):
-    #         if IS_MISTRAL_MODEL and i == 0:
-    #             context.append(
-    #                 HumanMessage(content=f"{settings.model_instruction}\n\n{human}")
-    #             )
-    #         else:
-    #             context.append(HumanMessage(content=human))
-    #         context.append(AIMessage(content=ai))
-    #     context.append(HumanMessage(content=latest_message))
-
-    #     response = ""
-    #     for chunk in llm.stream(context):
-    #         # print(chunk)
-    #         # NOTE(sd109): For some reason the '>' character breaks the UI
-    #         # so we need to escape it here.
-    #         # response += chunk.content.replace('>', '\>')
-    #         # UPDATE(sd109): Above bug seems to have been fixed as of gradio 4.15.0
-    #         # but keeping this note here in case we encounter it again
-    #         response += chunk.content
-    #         yield response
-
-    # # For all other errors notify user and log a more detailed warning
-    # except Exception as err:
-    #     warnings.warn(f"Exception encountered while generating response: {err}")
-    #     raise gr.Error(
-    #         "Unknown error encountered - see application logs for more information."
-    #     )
-

     # Allow mutating global variables
     global BACKEND_INITIALISED, INCLUDE_SYSTEM_PROMPT
@@ -126,22 +55,24 @@ def inference(latest_message, history):
     # Attempt to handle models which disallow system prompts
     # Construct conversation history for model prompt
     if INCLUDE_SYSTEM_PROMPT:
-        context = [SystemMessage(content=settings.model_instruction)]
+        context = [SystemMessage(content=settings.hf_model_instruction)]
     else:
         context = []
     for i, (human, ai) in enumerate(history):
         if not INCLUDE_SYSTEM_PROMPT and i == 0:
             # Mimic system prompt by prepending it to first human message
-            human = f"{settings.model_instruction}\n\n{human}"
+            human = f"{settings.hf_model_instruction}\n\n{human}"
         context.append(HumanMessage(content=human))
         context.append(AIMessage(content=(ai or "")))
     context.append(HumanMessage(content=latest_message))
+    logger.debug("Chat context: %s", context)

     response = ""
     for chunk in llm.stream(context):

         # If this is our first successful response from the backend
-        # then update the status variable
+        # then update the status variable to allow future error messages
+        # to be more informative
         if not BACKEND_INITIALISED and len(response) > 0:
             BACKEND_INITIALISED = True
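This hunk is the heart of the workaround: once INCLUDE_SYSTEM_PROMPT has been switched off, the model instruction is folded into the first user turn instead of being sent as a SystemMessage. The same logic as a standalone helper, for reference only (the build_context name and signature are illustrative, not part of the commit; the message classes come from langchain_core):

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

def build_context(latest_message, history, instruction, include_system_prompt=True):
    # Assemble the LangChain message list the same way the diff does.
    context = [SystemMessage(content=instruction)] if include_system_prompt else []
    for i, (human, ai) in enumerate(history):
        if not include_system_prompt and i == 0:
            # Mimic a system prompt by prepending the instruction to the first user turn
            human = f"{instruction}\n\n{human}"
        context.append(HumanMessage(content=human))
        context.append(AIMessage(content=(ai or "")))
    context.append(HumanMessage(content=latest_message))
    return context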

@@ -157,11 +88,11 @@ def inference(latest_message, history):
         logger.error("Received BadRequestError from backend API: %s", err)
         message = err.response.json()['message']
         if INCLUDE_SYSTEM_PROMPT:
-            INCLUDE_SYSTEM_PROMPT = False
-            # TODO: Somehow retry same inference step without system prompt
-            pass
-        ui_message = f"API Error received. This usually means the chosen LLM uses an incompatible prompt format. Error message was: {message}"
-        raise gr.Error(ui_message)
+            raise PossibleSystemPromptException()
+        else:
+            # In this case we've already tried without system prompt and still hit a bad request error so something else must be wrong
+            ui_message = f"API Error received. This usually means the chosen LLM uses an incompatible prompt format. Error message was: {message}"
+            raise gr.Error(ui_message)

     except openai.APIConnectionError as err:
         if not BACKEND_INITIALISED:
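Note that the handler treats any first HTTP 400 as a hint that the system prompt may be to blame and simply retries. If that ever proves too aggressive, a stricter check could inspect the backend's error message before retrying; a hypothetical helper (not in the commit), using only the field the handler already reads:

def looks_like_role_error(err) -> bool:
    # Heuristic sketch: only retry without the system prompt if the backend's
    # error message mentions roles or alternation, rather than on any 400.
    try:
        message = err.response.json()["message"].lower()
    except Exception:
        return False
    return any(hint in message for hint in ("role", "system", "alternate"))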
@@ -175,7 +106,7 @@ def inference(latest_message, history):
         gr.Warning("Internal server error encountered in backend API - see API logs for details.")

     # Catch-all for unexpected exceptions
-    except Exception as err:
+    except err:
         logger.error("Unexpected error during inference: %s", err)
         raise gr.Error("Unexpected error encountered - see logs for details.")

@@ -197,9 +128,24 @@ def inference(latest_message, history):
 )


+def inference_wrapper(*args):
+    """
+    Simple wrapper round the `inference` function which catches certain predictable errors
+    such as invalid prompt formats and attempts to mitigate them automatically.
+    """
+    try:
+        for chunk in inference(*args):
+            yield chunk
+    except PossibleSystemPromptException:
+        logger.warning("Disabling system prompt and retrying previous request")
+        global INCLUDE_SYSTEM_PROMPT
+        INCLUDE_SYSTEM_PROMPT = False
+        for chunk in inference(*args):
+            yield chunk
+
 # Build main chat interface
 with gr.ChatInterface(
-    inference,
+    inference_wrapper,
     chatbot=gr.Chatbot(
         # Height of conversation window in CSS units (string) or pixels (int)
         height="68vh",
@@ -219,5 +165,6 @@ def inference(latest_message, history):
     theme=theme,
     css=css_overrides,
 ) as app:
+    logger.debug("Gradio chat interface config: %s", app.config)
     # app.launch(server_name="0.0.0.0") # Do we need this for k8s service?
     app.launch()
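A quick way to confirm the backend behaviour this commit works around is to send a system message directly with the openai client the app already depends on. The base URL and model name below are placeholders rather than values from the commit, and the 'message' field mirrors what the handler above reads:

import openai

client = openai.OpenAI(
    base_url="http://localhost:8000/v1",  # placeholder: same backend the web app talks to
    api_key="required-but-not-used",
)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
try:
    client.chat.completions.create(model="mistralai/Mistral-7B-Instruct-v0.2", messages=messages)
except openai.BadRequestError as err:
    # This is the HTTP 400 that now triggers the retry-without-system-prompt path.
    print("System role rejected:", err.response.json().get("message"))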
