- import requests
- import re
+ import sys
import logging
import gradio as gr
from urllib.parse import urljoin
@@ -16,35 +15,25 @@
logger.info("Starting app")

settings = AppSettings.load("./settings.yml")
+ if len(sys.argv) > 1:
+     settings.hf_model_name = sys.argv[1]
logger.info("App settings: %s", settings)

backend_url = str(settings.backend_url)
backend_health_endpoint = urljoin(backend_url, "/health")
BACKEND_INITIALISED = False

- # # NOTE(sd109): The Mistral family of models explicitly require a chat
- # # history of the form user -> ai -> user -> ... and so don't like having
- # # a SystemPrompt at the beginning. Since these models seem to be the
- # # best around right now, it makes sense to treat them as special and make
- # # sure the web app works correctly with them. To do so, we detect when a
- # # mistral model is specified using this regex and then handle it explicitly
- # # when contructing the `context` list in the `inference` function below.
- # MISTRAL_REGEX = re.compile(r".*mi(s|x)tral.*", re.IGNORECASE)
- # IS_MISTRAL_MODEL = MISTRAL_REGEX.match(settings.model_name) is not None
- # if IS_MISTRAL_MODEL:
- #     print(
- #         "Detected Mistral model - will alter LangChain conversation format appropriately."
- #     )
-
# Some models disallow the 'system' role in their conversation history by raising errors in their chat prompt template, e.g. see
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/cf47bb3e18fe41a5351bc36eef76e9c900847c89/tokenizer_config.json#L42
# Detecting this ahead of time is difficult, so for now we use a global variable which records whether the API has
# responded with an HTTP 400 error, in which case we retry the request with the system prompt prepended to the first user message instead
INCLUDE_SYSTEM_PROMPT = True
+ class PossibleSystemPromptException(Exception):
+     pass

llm = ChatOpenAI(
    base_url=urljoin(backend_url, "v1"),
-     model=settings.model_name,
+     model=settings.hf_model_name,
    openai_api_key="required-but-not-used",
    temperature=settings.llm_temperature,
    max_tokens=settings.llm_max_tokens,
@@ -58,66 +47,6 @@


def inference(latest_message, history):
-     # Check backend health and warn the user on error
-     # try:
-     #     response = requests.get(backend_health_endpoint, timeout=5)
-     #     response_code = response.status_code
-     #     if response_code == 200:
-     #         global backend_initialised
-     #         if not backend_initialised:
-     #             # Record the fact that backend was up at one point so we know that
-     #             # any future errors are not related to slow model initialisation
-     #             backend_initialised = True
-     #     elif response_code >= 400 and response_code < 500:
-     #         logging.warn(f"Received HTTP {response_code} response from backend. Full response: {response.text}")
-     #     else:
-     #         # If the server's running (i.e. we get a response) but it's not an HTTP 200
-     #         # we just hope Kubernetes reconciles things for us eventually..
-     #         raise gr.Error("Backend unhealthy - please try again later")
-     # except Exception as err:
-     #     warnings.warn(f"Error while checking backend health: {err}")
-     #     if backend_initialised:
-     #         # If backend was previously reachable then something unexpected has gone wrong
-     #         raise gr.Error("Backend unreachable")
-     #     else:
-     #         # In this case backend is probably still busy downloading model weights
-     #         raise gr.Error("Backend not ready yet - please try again later")
-
-     # try:
-     #     # To handle Mistral models we have to add the model instruction to
-     #     # the first user message since Mistral requires user -> ai -> user
-     #     # chat format and therefore doesn't allow system prompts.
-     #     context = []
-     #     if not IS_MISTRAL_MODEL:
-     #         context.append(SystemMessage(content=settings.model_instruction))
-     #     for i, (human, ai) in enumerate(history):
-     #         if IS_MISTRAL_MODEL and i == 0:
-     #             context.append(
-     #                 HumanMessage(content=f"{settings.model_instruction}\n\n{human}")
-     #             )
-     #         else:
-     #             context.append(HumanMessage(content=human))
-     #         context.append(AIMessage(content=ai))
-     #     context.append(HumanMessage(content=latest_message))
-
-     #     response = ""
-     #     for chunk in llm.stream(context):
-     #         # print(chunk)
-     #         # NOTE(sd109): For some reason the '>' character breaks the UI
-     #         # so we need to escape it here.
-     #         # response += chunk.content.replace('>', '\>')
-     #         # UPDATE(sd109): Above bug seems to have been fixed as of gradio 4.15.0
-     #         # but keeping this note here incase we enounter it again
-     #         response += chunk.content
-     #         yield response
-
-     #     # For all other errors notify user and log a more detailed warning
-     # except Exception as err:
-     #     warnings.warn(f"Exception encountered while generating response: {err}")
-     #     raise gr.Error(
-     #         "Unknown error encountered - see application logs for more information."
-     #     )
-

    # Allow mutating global variables
    global BACKEND_INITIALISED, INCLUDE_SYSTEM_PROMPT
@@ -126,22 +55,24 @@ def inference(latest_message, history):
    # Attempt to handle models which disallow system prompts
    # Construct conversation history for model prompt
    if INCLUDE_SYSTEM_PROMPT:
-         context = [SystemMessage(content=settings.model_instruction)]
+         context = [SystemMessage(content=settings.hf_model_instruction)]
    else:
        context = []
    for i, (human, ai) in enumerate(history):
        if not INCLUDE_SYSTEM_PROMPT and i == 0:
            # Mimic system prompt by prepending it to first human message
-             human = f"{settings.model_instruction}\n\n{human}"
+             human = f"{settings.hf_model_instruction}\n\n{human}"
        context.append(HumanMessage(content=human))
        context.append(AIMessage(content=(ai or "")))
    context.append(HumanMessage(content=latest_message))
+     logger.debug("Chat context: %s", context)

    response = ""
    for chunk in llm.stream(context):

        # If this is our first successful response from the backend
-         # then update the status variable
+         # then update the status variable to allow future error messages
+         # to be more informative
        if not BACKEND_INITIALISED and len(response) > 0:
            BACKEND_INITIALISED = True

@@ -157,11 +88,11 @@ def inference(latest_message, history):
        logger.error("Received BadRequestError from backend API: %s", err)
        message = err.response.json()['message']
        if INCLUDE_SYSTEM_PROMPT:
-             INCLUDE_SYSTEM_PROMPT = False
-             # TODO: Somehow retry same inference step without system prompt
-             pass
-             ui_message = f"API Error received. This usually means the chosen LLM uses an incompatible prompt format. Error message was: {message}"
-             raise gr.Error(ui_message)
+             raise PossibleSystemPromptException()
+         else:
+             # In this case we've already tried without the system prompt and still hit a bad request error, so something else must be wrong
+             ui_message = f"API Error received. This usually means the chosen LLM uses an incompatible prompt format. Error message was: {message}"
+             raise gr.Error(ui_message)

    except openai.APIConnectionError as err:
        if not BACKEND_INITIALISED:
@@ -175,7 +106,7 @@ def inference(latest_message, history):
        gr.Warning("Internal server error encountered in backend API - see API logs for details.")

    # Catch-all for unexpected exceptions
    except Exception as err:
        logger.error("Unexpected error during inference: %s", err)
        raise gr.Error("Unexpected error encountered - see logs for details.")

@@ -197,9 +128,24 @@ def inference(latest_message, history):
)


+ def inference_wrapper(*args):
+     """
+     Simple wrapper around the `inference` function which catches certain predictable errors,
+     such as invalid prompt formats, and attempts to mitigate them automatically.
+     """
+     try:
+         for chunk in inference(*args):
+             yield chunk
+     except PossibleSystemPromptException:
+         logger.warning("Disabling system prompt and retrying previous request")
+         global INCLUDE_SYSTEM_PROMPT
+         INCLUDE_SYSTEM_PROMPT = False
+         for chunk in inference(*args):
+             yield chunk
+
# Build main chat interface
with gr.ChatInterface(
-     inference,
+     inference_wrapper,
    chatbot=gr.Chatbot(
        # Height of conversation window in CSS units (string) or pixels (int)
        height="68vh",
@@ -219,5 +165,6 @@ def inference(latest_message, history):
    theme=theme,
    css=css_overrides,
) as app:
+     logger.debug("Gradio chat interface config: %s", app.config)
    # app.launch(server_name="0.0.0.0")  # Do we need this for k8s service?
    app.launch()
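
Usage note: the `sys.argv` override added above lets the Hugging Face model name be swapped at launch time without editing `settings.yml`. A hypothetical invocation, assuming the script is saved as `app.py` (the filename does not appear in this diff):

python app.py mistralai/Mistral-7B-Instruct-v0.2

If no argument is given, `settings.hf_model_name` from `settings.yml` is used unchanged.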