@@ -1,11 +1,13 @@
+import json
 import os
 import warnings
 from abc import ABC, abstractmethod
 from typing import Dict
 
 from dotenv import load_dotenv
 from openai import OpenAI
-from tenacity import retry, stop_after_attempt, wait_random_exponential # for exponential backoff
+from tenacity import stop_after_attempt # for exponential backoff
+from tenacity import retry, wait_random_exponential
 
 load_dotenv()
 
@@ -39,7 +41,11 @@
 
 class LLMInterface(ABC):
     @abstractmethod
-    def run(self, prompt: Dict[str, str], temperature: float = 0, max_tokens: int = 1024) -> str:
+    def run(self, prompt: Dict[str, str], temperature: float = 0) -> str:
+        pass
+
+    @abstractmethod
+    def json(self, prompt: Dict[str, str], temperature: float = 0) -> str:
         pass
 
 
@@ -113,11 +119,11 @@ def __init__(self, model):
             raise ValueError(
                 f"Model {model} is not supported. "
                 "Please choose from one of the following LLM providers: "
-                "OpenAI gpt models, Anthropic claude models, Google Gemini models, Azure OpenAI deployment, Cohere models, AWS Bedrock, and VLLM model endpoints."
+                "OpenAI gpt models (e.g. gpt-4o-mini, gpt-4o), Anthropic claude models (e.g. claude-3.5-sonnet), Google Gemini models (e.g. gemini-pro), Azure OpenAI deployment (azure)"
             )
 
-    @retry(wait=wait_random_exponential(min=1, max=90), stop=stop_after_attempt(15))
-    def _llm_response(self, prompt, temperature, max_tokens):
+    @retry(wait=wait_random_exponential(min=1, max=90), stop=stop_after_attempt(50))
+    def _llm_response(self, prompt, temperature, max_tokens=1024):
         """
         Send a prompt to the LLM and return the response.
         """
@@ -163,7 +169,7 @@ def _llm_response(self, prompt, temperature, max_tokens):
             content = response.completion
         elif COHERE_AVAILABLE and isinstance(self.client, CohereClient):
             prompt = f"{prompt['system_prompt']}\n{prompt['user_prompt']}"
-            response = self.client.generate(model="command", prompt=prompt, temperature=temperature, max_tokens=max_tokens) # type: ignore
+            response = self.client.generate(model="command", prompt=prompt, temperature=temperature, max_tokens=1024) # type: ignore
             try:
                 content = response.generations[0].text
             except:
@@ -222,10 +228,18 @@ def run(self, prompt, temperature=0, max_tokens=1024):
         """
         Run the LLM and return the response.
         Default temperature: 0
-        Default max_tokens: 1024
         """
        content = self._llm_response(prompt=prompt, temperature=temperature, max_tokens=max_tokens)
         return content
 
+    def json(self, prompt, temperature=0, max_tokens=1024, **kwargs):
+        llm_output = self.run(prompt, temperature, max_tokens=max_tokens)
+        if "{" in llm_output:
+            first_bracket = llm_output.index("{")
+            json_output = llm_output[first_bracket:].strip("```").strip(" ")
+        else:
+            json_output = llm_output.strip("```").strip(" ").replace("json", "")
+        return json.loads(json_output)
+
 
-DefaultLLM = lambda: LLMFactory(model=os.getenv("EVAL_LLM", "gpt-3.5-turbo-0125"))
+DefaultLLM = lambda: LLMFactory(model=os.getenv("EVAL_LLM", "gpt-4o-mini"))
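
For reference, a minimal usage sketch of the interface added in this diff, assuming the prompt dict uses the `system_prompt` / `user_prompt` keys that `_llm_response` reads; the prompt contents below are made up for illustration.

```python
# Hypothetical usage sketch; class and method names come from this diff,
# the prompt text is illustrative only.
llm = DefaultLLM()  # resolves to EVAL_LLM if set, otherwise "gpt-4o-mini"

prompt = {
    "system_prompt": "You are a strict grader. Respond with JSON only.",
    "user_prompt": 'Grade the answer "Paris" to the question "What is the capital of France?"',
}

raw = llm.run(prompt, temperature=0)      # raw string completion
parsed = llm.json(prompt, temperature=0)  # Python object parsed from the model's JSON output
```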
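The new `json()` method trims Markdown code fences from the completion before parsing it. A standalone sketch of that parsing step, assuming the model wraps its JSON in a fenced ```json block; `extract_json` is a hypothetical name used only for this illustration.

```python
import json

def extract_json(llm_output: str):
    # Mirrors the parsing in LLMFactory.json(): cut everything before the first
    # "{", strip surrounding backticks, then hand the remainder to json.loads.
    if "{" in llm_output:
        first_bracket = llm_output.index("{")
        json_output = llm_output[first_bracket:].strip("```").strip(" ")
    else:
        json_output = llm_output.strip("```").strip(" ").replace("json", "")
    return json.loads(json_output)

print(extract_json('```json\n{"score": 5, "reason": "exact match"}\n```'))
# {'score': 5, 'reason': 'exact match'}
```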