-import sys
 import logging
+import openai
+
 import gradio as gr
-from urllib.parse import urljoin
-from config import AppSettings
 
+from urllib.parse import urljoin
 from langchain.schema import HumanMessage, AIMessage, SystemMessage
 from langchain_openai import ChatOpenAI
-import openai
+from typing import Dict, List
+from pydantic import BaseModel, ConfigDict
+from utils import LLMParams, load_settings
 
 logging.basicConfig()
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-logger.info("Starting app")
 
-settings = AppSettings.load()
-if len(sys.argv) > 1:
-    settings.hf_model_name = sys.argv[1]
-logger.info("App settings: %s", settings)
+class AppSettings(BaseModel):
+    # Basic config
+    host_address: str
+    backend_url: str
+    model_name: str
+    model_instruction: str
+    page_title: str
+    llm_params: LLMParams
+    # Theme customisation
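+    # e.g. theme_params={"primary_hue": "red"}; theme_params_extended maps
+    # directly onto gradio theme attributes such as "body_background_fill"
+    # (example values are illustrative only)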
+    theme_params: Dict[str, str | list]
+    theme_params_extended: Dict[str, str]
+    css_overrides: str | None
+    custom_javascript: str | None
+    # Error on typos and suppress warnings for fields with 'model_' prefix
+    model_config = ConfigDict(protected_namespaces=(), extra="forbid")
+
+
+# class AppSettings(BaseModel):
+#     hf_model_name: str = Field(
+#         description="The model to use when constructing the LLM Chat client. This should match the model name running on the vLLM backend",
+#     )
+#     backend_url: HttpUrl = Field(
+#         description="The address of the OpenAI compatible API server (either in-cluster or externally hosted)"
+#     )
+#     page_title: str = Field(default="Large Language Model")
+#     page_description: Optional[str] = Field(default=None)
+#     hf_model_instruction: str = Field(
+#         default="You are a helpful and cheerful AI assistant. Please respond appropriately."
+#     )
+
+#     # Model settings
+
+#     # For available parameters, see https://docs.vllm.ai/en/latest/dev/sampling_params.html
+#     # which is based on https://platform.openai.com/docs/api-reference/completions/create
+#     llm_max_tokens: int = Field(default=500)
+#     llm_temperature: float = Field(default=0)
+#     llm_top_p: float = Field(default=1)
+#     llm_top_k: float = Field(default=-1)
+#     llm_presence_penalty: float = Field(default=0, ge=-2, le=2)
+#     llm_frequency_penalty: float = Field(default=0, ge=-2, le=2)
+
+#     # UI theming
+
+#     # Variables explicitly passed to gradio.theme.Default()
+#     # For example:
+#     # {"primary_hue": "red"}
+#     theme_params: dict[str, Union[str, List[str]]] = Field(default_factory=dict)
+#     # Overrides for theme.body_background_fill property
+#     theme_background_colour: Optional[str] = Field(default=None)
+#     # Provides arbitrary CSS and JS overrides to the UI,
+#     # see https://www.gradio.app/guides/custom-CSS-and-JS
+#     css_overrides: Optional[str] = Field(default=None)
+#     custom_javascript: Optional[str] = Field(default=None)
+
+
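+# NOTE: load_settings() (from utils) is assumed to return a dict whose keys
+# match the AppSettings fields above (e.g. parsed from a mounted config file).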
+settings = AppSettings(**load_settings())
+logger.info(settings)
 
 backend_url = str(settings.backend_url)
 backend_health_endpoint = urljoin(backend_url, "/health")
@@ -36,29 +90,19 @@ class PossibleSystemPromptException(Exception):
 
 llm = ChatOpenAI(
     base_url=urljoin(backend_url, "v1"),
-    model=settings.hf_model_name,
+    model=settings.model_name,
     openai_api_key="required-but-not-used",
-    temperature=settings.llm_temperature,
-    max_tokens=settings.llm_max_tokens,
-    # model_kwargs={
-    #     "top_p": settings.llm_top_p,
-    #     "frequency_penalty": settings.llm_frequency_penalty,
-    #     "presence_penalty": settings.llm_presence_penalty,
-    #     # Additional parameters supported by vLLM but not OpenAI API
-    #     # https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters
-    #     "extra_body": {
-    #         "top_k": settings.llm_top_k,
-    #     }
-    top_p=settings.llm_top_p,
-    frequency_penalty=settings.llm_frequency_penalty,
-    presence_penalty=settings.llm_presence_penalty,
-    # Additional parameters supported by vLLM but not OpenAI API
-    # https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters
+    temperature=settings.llm_params.temperature,
+    max_tokens=settings.llm_params.max_tokens,
+    top_p=settings.llm_params.top_p,
+    frequency_penalty=settings.llm_params.frequency_penalty,
+    presence_penalty=settings.llm_params.presence_penalty,
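+    # "top_k" below is an extra parameter supported by vLLM but not the standard OpenAI API:
+    # https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters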
     extra_body={
-        "top_k": settings.llm_top_k,
+        "top_k": settings.llm_params.top_k,
     },
     streaming=True,
 )
+logger.info(llm)
 
 
 def inference(latest_message, history):
@@ -67,13 +111,13 @@ def inference(latest_message, history):
 
     try:
         if INCLUDE_SYSTEM_PROMPT:
-            context = [SystemMessage(content=settings.hf_model_instruction)]
+            context = [SystemMessage(content=settings.model_instruction)]
         else:
             context = []
         for i, (human, ai) in enumerate(history):
             if not INCLUDE_SYSTEM_PROMPT and i == 0:
                 # Mimic system prompt by prepending it to first human message
-                human = f"{settings.hf_model_instruction}\n\n{human}"
+                human = f"{settings.model_instruction}\n\n{human}"
             context.append(HumanMessage(content=human))
             context.append(AIMessage(content=(ai or "")))
         context.append(HumanMessage(content=latest_message))
@@ -131,8 +175,8 @@ def inference(latest_message, history):
 
 # UI theming
 theme = gr.themes.Default(**settings.theme_params)
-if settings.theme_background_colour:
-    theme.body_background_fill = settings.theme_background_colour
+theme.set(**settings.theme_params_extended)
+# theme.set(text)
 
 
 def inference_wrapper(*args):
@@ -153,7 +197,7 @@ def inference_wrapper(*args):
 
 
 # Build main chat interface
-with gr.ChatInterface(
+app = gr.ChatInterface(
     inference_wrapper,
     chatbot=gr.Chatbot(
         # Height of conversation window in CSS units (string) or pixels (int)
@@ -167,24 +211,54 @@ def inference_wrapper(*args):
         scale=7,
     ),
     title=settings.page_title,
-    description=settings.page_description,
     retry_btn="Retry",
     undo_btn="Undo",
     clear_btn="Clear",
     analytics_enabled=False,
     theme=theme,
     css=settings.css_overrides,
     js=settings.custom_javascript,
-) as app:
-    logger.debug("Gradio chat interface config: %s", app.config)
-    # For running locally in tilt dev setup
-    if len(sys.argv) > 2 and sys.argv[2] == "localhost":
-        app.launch()
-    # For running on cluster
-    else:
-        app.queue(
-            # Allow 10 concurrent requests to backend
-            # vLLM backend should be clever enough to
-            # batch these requests appropriately.
-            default_concurrency_limit=10,
-        ).launch(server_name="0.0.0.0")
+)
+logger.debug("Gradio chat interface config: %s", app.config)
+app.queue(
+    # Allow 10 concurrent requests to backend
+    # vLLM backend should be clever enough to
+    # batch these requests appropriately.
+    default_concurrency_limit=10,
+).launch(server_name=settings.host_address)
+
+# with gr.ChatInterface(
+#     inference_wrapper,
+#     chatbot=gr.Chatbot(
+#         # Height of conversation window in CSS units (string) or pixels (int)
+#         height="68vh",
+#         show_copy_button=True,
+#     ),
+#     textbox=gr.Textbox(
+#         placeholder="Ask me anything...",
+#         container=False,
+#         # Ratio of text box to submit button width
+#         scale=7,
+#     ),
+#     title=settings.page_title,
+#     description=settings.page_description,
+#     retry_btn="Retry",
+#     undo_btn="Undo",
+#     clear_btn="Clear",
+#     analytics_enabled=False,
+#     theme=theme,
+#     css=settings.css_overrides,
+#     js=settings.custom_javascript,
+# ) as app:
+#     logger.debug("Gradio chat interface config: %s", app.config)
+#     # For running locally in tilt dev setup
+#     if len(sys.argv) > 2 and sys.argv[2] == "localhost":
+#         app.launch()
+#     # For running on cluster
+#     else:
+#         app.queue(
+#             # Allow 10 concurrent requests to backend
+#             # vLLM backend should be clever enough to
+#             # batch these requests appropriately.
+#             default_concurrency_limit=10,
+#         ).launch(server_name=settings.host_address)