Skip to content

Commit 31893e0

Browse files
authored
Merge pull request #189 from ahmedmustahid/main
refactor: Isolate Azure OpenAI configuration
2 parents bf5b0c5 + 4a9fcf8 commit 31893e0

File tree

1 file changed

+111
-73
lines changed

textgrad/engine/openai.py

Lines changed: 111 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,86 +1,79 @@
11
# Import the OpenAI SDK clients up front; fail early with an actionable
# message when the optional `openai` dependency is not installed.
try:
    from openai import AzureOpenAI, OpenAI
except ImportError:
    raise ImportError(
        "If you'd like to use OpenAI models, please install the openai package by running `pip install openai`, and add 'OPENAI_API_KEY' to your environment variables."
    )
57

6-
import os
7-
import json
88
import base64
9+
import json
10+
import os
11+
from typing import List, Union
12+
913
import platformdirs
1014
from tenacity import (
1115
retry,
1216
stop_after_attempt,
1317
wait_random_exponential,
1418
)
15-
from typing import List, Union
1619

17-
from .base import EngineLM, CachedEngine
20+
from .base import CachedEngine, EngineLM
1821
from .engine_utils import get_image_type_from_bytes
1922

2023
# Base URL for a local OLLAMA server. The OLLAMA_BASE_URL environment
# variable, when set (and non-empty), takes precedence over the default.
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL") or "http://localhost:11434/v1"
2629

27-
class BaseOpenAIEngine(EngineLM, CachedEngine):
    """Shared plumbing for OpenAI-compatible chat engines (OpenAI, Azure, Ollama)."""

    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(
        self,
        cache_path: str,
        system_prompt: str,
        model_string: str,
        is_multimodal: bool = False,
    ):
        """
        :param cache_path: Path of the on-disk response-cache database.
        :param system_prompt: Default system prompt used when a call supplies none.
        :param model_string: Model identifier sent to the API.
        :param is_multimodal: Whether image (bytes) inputs are accepted.
        """
        # CachedEngine is responsible for opening the cache at cache_path.
        super().__init__(cache_path=cache_path)
        self.model_string = model_string
        self.is_multimodal = is_multimodal
        self.system_prompt = system_prompt
6745

6846
@retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
69-
def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str=None, **kwargs):
47+
def generate(
48+
self,
49+
content: Union[str, List[Union[str, bytes]]],
50+
system_prompt: str = None,
51+
**kwargs,
52+
):
7053
if isinstance(content, str):
71-
return self._generate_from_single_prompt(content, system_prompt=system_prompt, **kwargs)
72-
54+
return self._generate_from_single_prompt(
55+
content, system_prompt=system_prompt, **kwargs
56+
)
57+
7358
elif isinstance(content, list):
7459
has_multimodal_input = any(isinstance(item, bytes) for item in content)
7560
if (has_multimodal_input) and (not self.is_multimodal):
76-
raise NotImplementedError("Multimodal generation is only supported for Claude-3 and beyond.")
77-
78-
return self._generate_from_multiple_input(content, system_prompt=system_prompt, **kwargs)
61+
raise NotImplementedError(
62+
"Multimodal generation is only supported for Claude-3 and beyond."
63+
)
64+
65+
return self._generate_from_multiple_input(
66+
content, system_prompt=system_prompt, **kwargs
67+
)
7968

8069
def _generate_from_single_prompt(
81-
self, prompt: str, system_prompt: str=None, temperature=0, max_tokens=2000, top_p=0.99
70+
self,
71+
prompt: str,
72+
system_prompt: str = None,
73+
temperature=0,
74+
max_tokens=2000,
75+
top_p=0.99,
8276
):
83-
8477
sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
8578

8679
cache_or_none = self._check_cache(sys_prompt_arg + prompt)
    def __call__(self, prompt, **kwargs):
        """Make the engine directly callable; equivalent to calling generate()."""
        return self.generate(prompt, **kwargs)
110103

111104
def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
112-
"""Helper function to format a list of strings and bytes into a list of dictionaries to pass as messages to the API.
113-
"""
105+
"""Helper function to format a list of strings and bytes into a list of dictionaries to pass as messages to the API."""
114106
formatted_content = []
115107
for item in content:
116108
if isinstance(item, bytes):
117109
# For now, bytes are assumed to be images
118110
image_type = get_image_type_from_bytes(item)
119-
base64_image = base64.b64encode(item).decode('utf-8')
120-
formatted_content.append({
121-
"type": "image_url",
122-
"image_url": {
123-
"url": f"data:image/{image_type};base64,{base64_image}"
111+
base64_image = base64.b64encode(item).decode("utf-8")
112+
formatted_content.append(
113+
{
114+
"type": "image_url",
115+
"image_url": {
116+
"url": f"data:image/{image_type};base64,{base64_image}"
117+
},
124118
}
125-
})
119+
)
126120
elif isinstance(item, str):
127-
formatted_content.append({
128-
"type": "text",
129-
"text": item
130-
})
121+
formatted_content.append({"type": "text", "text": item})
131122
else:
132123
raise ValueError(f"Unsupported input type: {type(item)}")
133124
return formatted_content
134125

135126
def _generate_from_multiple_input(
136-
self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
127+
self,
128+
content: List[Union[str, bytes]],
129+
system_prompt=None,
130+
temperature=0,
131+
max_tokens=2000,
132+
top_p=0.99,
137133
):
138134
sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
139135
formatted_content = self._format_content(content)
@@ -158,20 +154,60 @@ def _generate_from_multiple_input(
158154
self._save_cache(cache_key, response_text)
159155
return response_text
160156

161-
class ChatOpenAI(BaseOpenAIEngine):
    def __init__(
        self,
        model_string: str = "gpt-3.5-turbo-0613",
        system_prompt: str = BaseOpenAIEngine.DEFAULT_SYSTEM_PROMPT,
        is_multimodal: bool = False,
        base_url: Union[str, None] = None,  # was annotated plain `str`; None is the documented default
        **kwargs,
    ):
        """
        Engine backed by the public OpenAI API, or a local Ollama server.

        :param model_string: Model identifier passed to the OpenAI API.
        :param system_prompt: Default system prompt used when a call supplies none.
        :param is_multimodal: Whether the model accepts image (bytes) input.
        :param base_url: Used to support Ollama. Only the OLLAMA base URL or
            None (api.openai.com) is accepted.
        :raises ValueError: If OPENAI_API_KEY is unset while base_url is None,
            or if an unsupported base_url is supplied.
        """
        root = platformdirs.user_cache_dir("textgrad")
        cache_path = os.path.join(root, f"cache_openai_{model_string}.db")

        super().__init__(cache_path, system_prompt, model_string, is_multimodal)

        self.base_url = base_url

        if not base_url:
            if os.getenv("OPENAI_API_KEY") is None:
                raise ValueError(
                    "Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models."
                )

            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        elif base_url == OLLAMA_BASE_URL:
            # (The `base_url and` guard was redundant: this branch is only
            # reached when base_url is truthy.)
            self.client = OpenAI(base_url=base_url, api_key="ollama")
        else:
            raise ValueError(
                "Invalid base URL provided. Please use the default OLLAMA base URL or None."
            )
192+
193+
194+
class AzureChatOpenAI(BaseOpenAIEngine):
    def __init__(
        self,
        model_string="gpt-35-turbo",
        system_prompt=BaseOpenAIEngine.DEFAULT_SYSTEM_PROMPT,
        is_multimodal: bool = False,
        **kwargs,
    ):
        """
        Initializes an interface for interacting with Azure's OpenAI models.

        Sets up an AzureOpenAI client with the API version, API key, and
        endpoint taken from environment variables.

        :param model_string: The model identifier for Azure OpenAI. Defaults to 'gpt-35-turbo'.
        :param system_prompt: The default system prompt to use when generating responses.
        :param is_multimodal: Whether this is a multimodal model. Defaults to False.
        :param kwargs: Additional keyword arguments (currently unused).

        Environment variables:
            - AZURE_OPENAI_API_KEY: The API key for authenticating with Azure OpenAI.
            - AZURE_OPENAI_API_BASE: The endpoint URL of the Azure OpenAI resource.
            - AZURE_OPENAI_API_VERSION: Optional; defaults to '2023-07-01-preview'.

        :raises ValueError: If AZURE_OPENAI_API_KEY or AZURE_OPENAI_API_BASE is not set.
        """
        root = platformdirs.user_cache_dir("textgrad")
        cache_path = os.path.join(
            root, f"cache_azure_{model_string}.db"
        )  # Changed cache path to differentiate from OpenAI cache

        super().__init__(cache_path, system_prompt, model_string, is_multimodal)

        api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2023-07-01-preview")
        # Bug fix: the error message promises AZURE_OPENAI_API_BASE is required,
        # but only the API key used to be validated; an unset endpoint would
        # surface later as a confusing client-construction error. Validate both.
        if (
            os.getenv("AZURE_OPENAI_API_KEY") is None
            or os.getenv("AZURE_OPENAI_API_BASE") is None
        ):
            raise ValueError(
                "Please set the AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_BASE, and AZURE_OPENAI_API_VERSION environment variables if you'd like to use Azure OpenAI models."
            )

        self.client = AzureOpenAI(
            api_version=api_version,
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            azure_endpoint=os.getenv("AZURE_OPENAI_API_BASE"),
            azure_deployment=model_string,
        )

0 commit comments

Comments (0)