@@ -1,23 +1,77 @@
- import sys
import logging
+ import openai
+
import gradio as gr
- from urllib.parse import urljoin
- from config import AppSettings

+ from urllib.parse import urljoin
from langchain.schema import HumanMessage, AIMessage, SystemMessage
from langchain_openai import ChatOpenAI
- import openai
+ from typing import Dict, List
+ from pydantic import BaseModel, ConfigDict
+ from utils import LLMParams, load_settings

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

- logger.info("Starting app")

- settings = AppSettings.load()
- if len(sys.argv) > 1:
-     settings.hf_model_name = sys.argv[1]
- logger.info("App settings: %s", settings)
+ class AppSettings(BaseModel):
+     # Basic config
+     host_address: str
+     backend_url: str
+     model_name: str
+     model_instruction: str
+     page_title: str
+     llm_params: LLMParams
+     # Theme customisation
+     theme_params: Dict[str, str | list]
+     theme_params_extended: Dict[str, str]
+     css_overrides: str | None
+     custom_javascript: str | None
+     # Error on typos and suppress warnings for fields with 'model_' prefix
+     model_config = ConfigDict(protected_namespaces=(), extra="forbid")
+
+
+ # class AppSettings(BaseModel):
+ #     hf_model_name: str = Field(
+ #         description="The model to use when constructing the LLM Chat client. This should match the model name running on the vLLM backend",
+ #     )
+ #     backend_url: HttpUrl = Field(
+ #         description="The address of the OpenAI compatible API server (either in-cluster or externally hosted)"
+ #     )
+ #     page_title: str = Field(default="Large Language Model")
+ #     page_description: Optional[str] = Field(default=None)
+ #     hf_model_instruction: str = Field(
+ #         default="You are a helpful and cheerful AI assistant. Please respond appropriately."
+ #     )
+
+ #     # Model settings
+
+ #     # For available parameters, see https://docs.vllm.ai/en/latest/dev/sampling_params.html
+ #     # which is based on https://platform.openai.com/docs/api-reference/completions/create
+ #     llm_max_tokens: int = Field(default=500)
+ #     llm_temperature: float = Field(default=0)
+ #     llm_top_p: float = Field(default=1)
+ #     llm_top_k: float = Field(default=-1)
+ #     llm_presence_penalty: float = Field(default=0, ge=-2, le=2)
+ #     llm_frequency_penalty: float = Field(default=0, ge=-2, le=2)
+
+ #     # UI theming
+
+ #     # Variables explicitly passed to gradio.theme.Default()
+ #     # For example:
+ #     # {"primary_hue": "red"}
+ #     theme_params: dict[str, Union[str, List[str]]] = Field(default_factory=dict)
+ #     # Overrides for theme.body_background_fill property
+ #     theme_background_colour: Optional[str] = Field(default=None)
+ #     # Provides arbitrary CSS and JS overrides to the UI,
+ #     # see https://www.gradio.app/guides/custom-CSS-and-JS
+ #     css_overrides: Optional[str] = Field(default=None)
+ #     custom_javascript: Optional[str] = Field(default=None)
+
+
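+ # load_settings() (from utils) returns the raw config values as a mapping; AppSettings validates them here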
+ settings = AppSettings(**load_settings())
+ logger.info(settings)

backend_url = str(settings.backend_url)
backend_health_endpoint = urljoin(backend_url, "/health")
@@ -36,29 +90,19 @@ class PossibleSystemPromptException(Exception):

llm = ChatOpenAI(
    base_url=urljoin(backend_url, "v1"),
-     model=settings.hf_model_name,
+     model=settings.model_name,
    openai_api_key="required-but-not-used",
-     temperature=settings.llm_temperature,
-     max_tokens=settings.llm_max_tokens,
-     # model_kwargs={
-     #     "top_p": settings.llm_top_p,
-     #     "frequency_penalty": settings.llm_frequency_penalty,
-     #     "presence_penalty": settings.llm_presence_penalty,
-     #     # Additional parameters supported by vLLM but not OpenAI API
-     #     # https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters
-     #     "extra_body": {
-     #         "top_k": settings.llm_top_k,
-     #     }
-     top_p=settings.llm_top_p,
-     frequency_penalty=settings.llm_frequency_penalty,
-     presence_penalty=settings.llm_presence_penalty,
-     # Additional parameters supported by vLLM but not OpenAI API
-     # https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters
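+     # Sampling parameters, grouped under settings.llm_params (an LLMParams instance)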
+     temperature=settings.llm_params.temperature,
+     max_tokens=settings.llm_params.max_tokens,
+     top_p=settings.llm_params.top_p,
+     frequency_penalty=settings.llm_params.frequency_penalty,
+     presence_penalty=settings.llm_params.presence_penalty,
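+     # Additional parameters supported by vLLM but not the OpenAI API go in extra_body, see
+     # https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters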
    extra_body={
-         "top_k": settings.llm_top_k,
+         "top_k": settings.llm_params.top_k,
    },
    streaming=True,
)
+ logger.info(llm)


def inference(latest_message, history):
@@ -67,13 +111,13 @@ def inference(latest_message, history):

    try:
        if INCLUDE_SYSTEM_PROMPT:
-             context = [SystemMessage(content=settings.hf_model_instruction)]
+             context = [SystemMessage(content=settings.model_instruction)]
        else:
            context = []
        for i, (human, ai) in enumerate(history):
            if not INCLUDE_SYSTEM_PROMPT and i == 0:
                # Mimic system prompt by prepending it to first human message
-                 human = f"{settings.hf_model_instruction}\n\n{human}"
+                 human = f"{settings.model_instruction}\n\n{human}"
            context.append(HumanMessage(content=human))
            context.append(AIMessage(content=(ai or "")))
        context.append(HumanMessage(content=latest_message))
@@ -131,8 +175,8 @@ def inference(latest_message, history):

# UI theming
theme = gr.themes.Default(**settings.theme_params)
- if settings.theme_background_colour:
-     theme.body_background_fill = settings.theme_background_colour
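+ # Apply arbitrary extra theme attribute overrides (replaces the single background-colour override)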
+ theme.set(**settings.theme_params_extended)
+ # theme.set(text)


def inference_wrapper(*args):
@@ -153,7 +197,7 @@ def inference_wrapper(*args):


# Build main chat interface
- with gr.ChatInterface(
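+ # Build the interface directly; it is queued and launched explicitly below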
+ app = gr.ChatInterface(
    inference_wrapper,
    chatbot=gr.Chatbot(
        # Height of conversation window in CSS units (string) or pixels (int)
@@ -167,24 +211,54 @@ def inference_wrapper(*args):
        scale=7,
    ),
    title=settings.page_title,
-     description=settings.page_description,
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    analytics_enabled=False,
    theme=theme,
    css=settings.css_overrides,
    js=settings.custom_javascript,
- ) as app:
-     logger.debug("Gradio chat interface config: %s", app.config)
-     # For running locally in tilt dev setup
-     if len(sys.argv) > 2 and sys.argv[2] == "localhost":
-         app.launch()
-     # For running on cluster
-     else:
-         app.queue(
-             # Allow 10 concurrent requests to backend
-             # vLLM backend should be clever enough to
-             # batch these requests appropriately.
-             default_concurrency_limit=10,
-         ).launch(server_name="0.0.0.0")
+ )
+ logger.debug("Gradio chat interface config: %s", app.config)
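+ # Serve on the host address from settings rather than a hard-coded "0.0.0.0"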
+ app.queue(
+     # Allow 10 concurrent requests to backend
+     # vLLM backend should be clever enough to
+     # batch these requests appropriately.
+     default_concurrency_limit=10,
+ ).launch(server_name=settings.host_address)
+
+ # with gr.ChatInterface(
+ #     inference_wrapper,
+ #     chatbot=gr.Chatbot(
+ #         # Height of conversation window in CSS units (string) or pixels (int)
+ #         height="68vh",
+ #         show_copy_button=True,
+ #     ),
+ #     textbox=gr.Textbox(
+ #         placeholder="Ask me anything...",
+ #         container=False,
+ #         # Ratio of text box to submit button width
+ #         scale=7,
+ #     ),
+ #     title=settings.page_title,
+ #     description=settings.page_description,
+ #     retry_btn="Retry",
+ #     undo_btn="Undo",
+ #     clear_btn="Clear",
+ #     analytics_enabled=False,
+ #     theme=theme,
+ #     css=settings.css_overrides,
+ #     js=settings.custom_javascript,
+ # ) as app:
+ #     logger.debug("Gradio chat interface config: %s", app.config)
+ #     # For running locally in tilt dev setup
+ #     if len(sys.argv) > 2 and sys.argv[2] == "localhost":
+ #         app.launch()
+ #     # For running on cluster
+ #     else:
+ #         app.queue(
+ #             # Allow 10 concurrent requests to backend
+ #             # vLLM backend should be clever enough to
+ #             # batch these requests appropriately.
+ #             default_concurrency_limit=10,
+ #         ).launch(server_name=settings.host_address)