
Commit eb8384e

llm factory update (#76)
1 parent a65add6 commit eb8384e

File tree

1 file changed: +22 -8 lines


continuous_eval/llm_factory.py

Lines changed: 22 additions & 8 deletions
@@ -1,11 +1,13 @@
+import json
 import os
 import warnings
 from abc import ABC, abstractmethod
 from typing import Dict
 
 from dotenv import load_dotenv
 from openai import OpenAI
-from tenacity import retry, stop_after_attempt, wait_random_exponential # for exponential backoff
+from tenacity import stop_after_attempt # for exponential backoff
+from tenacity import retry, wait_random_exponential
 
 load_dotenv()
 
@@ -39,7 +41,11 @@
 
 class LLMInterface(ABC):
     @abstractmethod
-    def run(self, prompt: Dict[str, str], temperature: float = 0, max_tokens: int = 1024) -> str:
+    def run(self, prompt: Dict[str, str], temperature: float = 0) -> str:
+        pass
+
+    @abstractmethod
+    def json(self, prompt: Dict[str, str], temperature: float = 0) -> str:
         pass
 
 
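For orientation (not part of this commit): LLMInterface now obliges every provider wrapper to expose both run and json. A minimal sketch of a conforming subclass, using a hypothetical EchoLLM that returns a canned response, might look like this:

import json
from typing import Dict

class EchoLLM(LLMInterface):
    """Hypothetical subclass, shown only to illustrate the new interface."""

    def run(self, prompt: Dict[str, str], temperature: float = 0) -> str:
        # A real implementation would call a provider API here.
        return '{"answer": "stub"}'

    def json(self, prompt: Dict[str, str], temperature: float = 0):
        # Parse the raw output into a Python object, as LLMFactory.json does.
        return json.loads(self.run(prompt, temperature))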
@@ -113,11 +119,11 @@ def __init__(self, model):
             raise ValueError(
                 f"Model {model} is not supported. "
                 "Please choose from one of the following LLM providers: "
-                "OpenAI gpt models, Anthropic claude models, Google Gemini models, Azure OpenAI deployment, Cohere models, AWS Bedrock, and VLLM model endpoints."
+                "OpenAI gpt models (e.g. gpt-4o-mini, gpt-4o), Anthropic claude models (e.g. claude-3.5-sonnet), Google Gemini models (e.g. gemini-pro), Azure OpenAI deployment (azure)"
             )
 
-    @retry(wait=wait_random_exponential(min=1, max=90), stop=stop_after_attempt(15))
-    def _llm_response(self, prompt, temperature, max_tokens):
+    @retry(wait=wait_random_exponential(min=1, max=90), stop=stop_after_attempt(50))
+    def _llm_response(self, prompt, temperature, max_tokens=1024):
         """
         Send a prompt to the LLM and return the response.
         """
@@ -163,7 +169,7 @@ def _llm_response(self, prompt, temperature, max_tokens):
             content = response.completion
         elif COHERE_AVAILABLE and isinstance(self.client, CohereClient):
             prompt = f"{prompt['system_prompt']}\n{prompt['user_prompt']}"
-            response = self.client.generate(model="command", prompt=prompt, temperature=temperature, max_tokens=max_tokens) # type: ignore
+            response = self.client.generate(model="command", prompt=prompt, temperature=temperature, max_tokens=1024) # type: ignore
             try:
                 content = response.generations[0].text
             except:
@@ -222,10 +228,18 @@ def run(self, prompt, temperature=0, max_tokens=1024):
         """
         Run the LLM and return the response.
         Default temperature: 0
-        Default max_tokens: 1024
         """
         content = self._llm_response(prompt=prompt, temperature=temperature, max_tokens=max_tokens)
         return content
 
+    def json(self, prompt, temperature=0, max_tokens=1024, **kwargs):
+        llm_output = self.run(prompt, temperature, max_tokens=max_tokens)
+        if "{" in llm_output:
+            first_bracket = llm_output.index("{")
+            json_output = llm_output[first_bracket:].strip("```").strip(" ")
+        else:
+            json_output = llm_output.strip("```").strip(" ").replace("json", "")
+        return json.loads(json_output)
+
 
-DefaultLLM = lambda: LLMFactory(model=os.getenv("EVAL_LLM", "gpt-3.5-turbo-0125"))
+DefaultLLM = lambda: LLMFactory(model=os.getenv("EVAL_LLM", "gpt-4o-mini"))
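A rough usage sketch of the factory after this change, assuming provider credentials (e.g. OPENAI_API_KEY) are available in the environment and using the system_prompt/user_prompt keys visible in _llm_response above:

from continuous_eval.llm_factory import DefaultLLM

llm = DefaultLLM()  # EVAL_LLM if set, otherwise gpt-4o-mini
prompt = {
    "system_prompt": "You are a strict grader. Reply with JSON only.",
    "user_prompt": 'Score the answer "Paris" to the question "What is the capital of France?"',
}
text = llm.run(prompt, temperature=0)     # raw string response
scores = llm.json(prompt, temperature=0)  # parsed object via the new json() helper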
