Commit 9607fc0

Add max_tokens field for LLM interface (#66)
1 parent 8124a22 commit 9607fc0

File tree

.env.example
continuous_eval/llm_factory.py

2 files changed: +15 −11 lines changed


.env.example

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,9 @@ ANTHROPIC_API_KEY="sk-ant-xxxx"
 # For Gemini
 GEMINI_API_KEY="xxxx"
 
+# For Cohere
+COHERE_API_KEY="xxxx"
+
 # For Azure OpenAI
 AZURE_OPENAI_API_KEY="sk-xxxx"
 AZURE_OPENAI_API_VERSION="2023-03-15-preview"
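
As a minimal usage sketch (not part of this commit) of how the new COHERE_API_KEY entry might be consumed, assuming python-dotenv and the cohere SDK are installed:

import os

import cohere
from dotenv import load_dotenv

load_dotenv()  # read a local .env (copied from .env.example) into the environment

# "command" matches the model name used by the Cohere branch in llm_factory.py
co = cohere.Client(os.environ["COHERE_API_KEY"])
response = co.generate(model="command", prompt="Say hello.", temperature=0, max_tokens=64)
print(response.generations[0].text)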

continuous_eval/llm_factory.py

Lines changed: 12 additions & 11 deletions
@@ -39,7 +39,7 @@
 
 class LLMInterface(ABC):
     @abstractmethod
-    def run(self, prompt: Dict[str, str], temperature: float = 0) -> str:
+    def run(self, prompt: Dict[str, str], temperature: float = 0, max_tokens: int = 1024) -> str:
         pass
 
 
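
As an illustration only (not from this diff), a concrete implementation of the updated interface now accepts and forwards max_tokens; the class below is a hypothetical toy implementer:

from typing import Dict

from continuous_eval.llm_factory import LLMInterface


class EchoLLM(LLMInterface):
    """Toy implementer of the updated interface; truncation stands in for a real provider call."""

    def run(self, prompt: Dict[str, str], temperature: float = 0, max_tokens: int = 1024) -> str:
        # A real provider would forward max_tokens to its completion API;
        # here we crudely cap the echoed prompt at max_tokens whitespace-separated "tokens".
        words = prompt["user_prompt"].split()
        return " ".join(words[:max_tokens])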
@@ -113,11 +113,11 @@ def __init__(self, model):
             raise ValueError(
                 f"Model {model} is not supported. "
                 "Please choose from one of the following LLM providers: "
-                "OpenAI gpt models (e.g. gpt-4-turbo-preview, gpt-3.5-turbo-0125), Anthropic claude models (e.g. claude-2.1, claude-instant-1.2), Google Gemini models (e.g. gemini-pro), Azure OpenAI deployment (azure)"
+                "OpenAI gpt models, Anthropic claude models, Google Gemini models, Azure OpenAI deployment, Cohere models, AWS Bedrock, and VLLM model endpoints."
             )
 
     @retry(wait=wait_random_exponential(min=1, max=90), stop=stop_after_attempt(15))
-    def _llm_response(self, prompt, temperature):
+    def _llm_response(self, prompt, temperature, max_tokens):
         """
         Send a prompt to the LLM and return the response.
         """
@@ -133,7 +133,7 @@ def _llm_response(self, prompt, temperature):
                 ],
                 seed=0,
                 temperature=temperature,
-                max_tokens=1024,
+                max_tokens=max_tokens,
                 top_p=1,
                 frequency_penalty=0,
                 presence_penalty=0,
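
For reference, the @retry decorator applied to _llm_response (shown two hunks above) comes from tenacity; a standalone sketch of the same retry policy, illustrative and not part of this diff:

import random

from tenacity import retry, stop_after_attempt, wait_random_exponential


@retry(wait=wait_random_exponential(min=1, max=90), stop=stop_after_attempt(15))
def flaky_call() -> str:
    # Retried up to 15 times with random exponential backoff between 1 and 90 seconds,
    # mirroring the policy applied to _llm_response.
    if random.random() < 0.5:
        raise RuntimeError("transient provider error")
    return "ok"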
@@ -147,7 +147,7 @@ def _llm_response(self, prompt, temperature):
                 ],
                 seed=0,
                 temperature=temperature,
-                max_tokens=1024,
+                max_tokens=max_tokens,
                 top_p=1,
                 frequency_penalty=0,
                 presence_penalty=0,
@@ -156,14 +156,14 @@ def _llm_response(self, prompt, temperature):
         elif ANTHROPIC_AVAILABLE and isinstance(self.client, Anthropic):
             response = self.client.completions.create(  # type: ignore
                 model="claude-2.1",
-                max_tokens_to_sample=1024,
+                max_tokens_to_sample=max_tokens,
                 temperature=temperature,
                 prompt=f"{prompt['system_prompt']}{HUMAN_PROMPT}{prompt['user_prompt']}{AI_PROMPT}",
             )
             content = response.completion
         elif COHERE_AVAILABLE and isinstance(self.client, CohereClient):
             prompt = f"{prompt['system_prompt']}\n{prompt['user_prompt']}"
-            response = self.client.generate(model="command", prompt=prompt, temperature=temperature, max_tokens=1024)  # type: ignore
+            response = self.client.generate(model="command", prompt=prompt, temperature=temperature, max_tokens=max_tokens)  # type: ignore
             try:
                 content = response.generations[0].text
             except:
@@ -174,7 +174,7 @@ def _llm_response(self, prompt, temperature):
                 "temperature": temperature,
                 "top_p": 1,
                 "top_k": 1,
-                "max_output_tokens": 1024,
+                "max_output_tokens": max_tokens,
             }
             safety_settings = [
                 {
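
Across the hunks above, the same token budget travels under provider-specific parameter names; a partial summary derived from this diff (the dictionary keys below are informal labels, not identifiers from the code):

# Informal summary of where max_tokens ends up per provider branch in _llm_response.
MAX_TOKENS_PARAM = {
    "openai / azure_openai": "max_tokens",                  # chat completion kwargs
    "anthropic (legacy completions)": "max_tokens_to_sample",
    "cohere": "max_tokens",                                  # client.generate(...)
    "gemini": "max_output_tokens",                           # generation_config entry
}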
@@ -207,7 +207,7 @@ def _llm_response(self, prompt, temperature):
                     HumanMessage(content=prompt["user_prompt"]),
                 ],
                 temperature=temperature,
-                max_tokens=1024,
+                max_tokens=max_tokens,
                 top_p=1,
             )
             content = response.dict()["content"]
@@ -218,12 +218,13 @@ def _llm_response(self, prompt, temperature):
 
         return content
 
-    def run(self, prompt, temperature=0):
+    def run(self, prompt, temperature=0, max_tokens=1024):
         """
         Run the LLM and return the response.
         Default temperature: 0
+        Default max_tokens: 1024
         """
-        content = self._llm_response(prompt=prompt, temperature=temperature)
+        content = self._llm_response(prompt=prompt, temperature=temperature, max_tokens=max_tokens)
         return content
 
 
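
Finally, a hypothetical caller-side sketch of the updated run signature; the wrapper class name LLMFactory and the chosen model string are assumptions, not taken from this diff:

from continuous_eval.llm_factory import LLMFactory  # class name assumed for illustration

llm = LLMFactory(model="gpt-3.5-turbo-0125")  # model string borrowed from the old error message above
answer = llm.run(
    prompt={"system_prompt": "You are a concise assistant.", "user_prompt": "Summarize RAG in one sentence."},
    temperature=0,
    max_tokens=256,  # new knob from this commit; defaults to 1024 when omitted
)
print(answer)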