Skip to content

Commit 58d8544

Browse files
committed
use the openai package instead of plain requests
1 parent bdd0ecf commit 58d8544

File tree

2 files changed

+41
-53
lines changed

2 files changed

+41
-53
lines changed

gpt_code_ui/webapp/main.py

Lines changed: 40 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import re
77
import logging
88
import sys
9+
import openai
910
import pandas as pd
1011

1112
from collections import deque
@@ -18,13 +19,13 @@
1819

1920
load_dotenv('.env')
2021

21-
OPENAI_API_TYPE = os.environ.get("OPENAI_API_TYPE", "openai")
22-
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
23-
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com")
24-
OPENAI_API_VERSION = os.environ.get("OPENAI_API_VERSION", "2023-03-15-preview")
22+
openai.api_base = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com")
23+
openai.api_type = os.environ.get("OPENAI_API_TYPE", "openai")
24+
openai.api_version = os.environ.get("OPENAI_API_VERSION", "2023-03-15-preview")
25+
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
26+
openai.log = os.getenv("OPENAI_API_LOGLEVEL", "")
2527
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT", "")
2628

27-
2829
UPLOAD_FOLDER = 'workspace/'
2930
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
3031

@@ -114,52 +115,43 @@ async def get_code(user_prompt, user_openai_key=None, model="gpt-3.5-turbo"):
114115
If the user has just uploaded a file, focus on the file that was most recently uploaded (and optionally all previously uploaded files)
115116
116117
Teacher mode: if the code modifies or produces a file, at the end of the code block insert a print statement that prints a link to it as HTML string: <a href='/download?file=INSERT_FILENAME_HERE'>Download file</a>. Replace INSERT_FILENAME_HERE with the actual filename."""
117-
temperature = 0.7
118-
message_array = [
119-
{
120-
"role": "user",
121-
"content": prompt,
122-
},
123-
]
124-
125-
final_openai_key = OPENAI_API_KEY
118+
126119
if user_openai_key:
127-
final_openai_key = user_openai_key
128-
129-
if OPENAI_API_TYPE == "openai":
130-
data = {
131-
"model": model,
132-
"messages": message_array,
133-
"temperature": temperature,
134-
}
135-
headers = {
136-
"Content-Type": "application/json",
137-
"Authorization": f"Bearer {final_openai_key}",
138-
}
139-
140-
response = requests.post(
141-
f"{OPENAI_BASE_URL}/v1/chat/completions",
142-
data=json.dumps(data),
143-
headers=headers,
144-
)
145-
elif OPENAI_API_TYPE == "azure":
146-
data = {
147-
"messages": message_array,
148-
"temperature": temperature,
149-
}
150-
headers = {
151-
"Content-Type": "application/json",
152-
"api-key": f"{final_openai_key}",
153-
}
154-
155-
response = requests.post(
156-
f"{OPENAI_BASE_URL}/openai/deployments/{AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version={OPENAI_API_VERSION}",
157-
data=json.dumps(data),
158-
headers=headers,
159-
)
120+
openai.api_key = user_openai_key
121+
122+
arguments = dict(
123+
temperature=0.7,
124+
headers={'x-api-key': openai.api_key},
125+
messages=[
126+
# {"role": "system", "content": system},
127+
{"role": "user", "content": prompt},
128+
]
129+
)
130+
131+
if openai.api_type == 'openai':
132+
arguments["model"] = model
133+
elif openai.api_type == 'azure':
134+
arguments["deployment_id"] = AZURE_OPENAI_DEPLOYMENT
160135
else:
161-
return None, "Error: Invalid OPENAI_PROVIDER", 500
136+
return None, f"Error: Invalid OPENAI_PROVIDER: {openai.api_type}", 500
137+
138+
try:
139+
result_GPT = openai.ChatCompletion.create(**arguments)
140+
141+
if 'error' in result_GPT:
142+
raise openai.APIError(code=result_GPT.error.code, message=result_GPT.error.message)
143+
144+
if result_GPT.choices[0].finish_reason == 'content_filter':
145+
raise openai.APIError('Content Filter')
162146

147+
except openai.OpenAIError as e:
148+
return None, f"Error from API: {e}", 500
149+
150+
try:
151+
content = result_GPT.choices[0].message.content
152+
153+
except AttributeError:
154+
return None, f"Malformed answer from API: {content}", 500
163155

164156
def extract_code(text):
165157
# Match triple backtick blocks first
@@ -179,11 +171,6 @@ def extract_non_code(text):
179171
text = re.sub(r'`(.+?)`', '', text, flags=re.DOTALL)
180172
return text.strip()
181173

182-
183-
if response.status_code != 200:
184-
return None, "Error: " + response.text, 500
185-
186-
content = response.json()["choices"][0]["message"]["content"]
187174
return extract_code(content), extract_non_code(content), 200
188175

189176
# We know this Flask app is for local use. So we can disable the verbose Werkzeug logger

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
'flask-cors>=3,<4',
2222
'python-dotenv>=0.18,<2',
2323
'pandas>=1.3,<2',
24+
'openai>=0.25,<1',
2425
],
2526
entry_points={
2627
'console_scripts': [

0 commit comments

Comments
 (0)