 
 class LLMInterface(ABC):
     @abstractmethod
-    def run(self, prompt: Dict[str, str], temperature: float = 0) -> str:
+    def run(self, prompt: Dict[str, str], temperature: float = 0, max_tokens: int = 1024) -> str:
         pass
 
 
@@ -113,11 +113,11 @@ def __init__(self, model):
             raise ValueError(
                 f"Model {model} is not supported. "
                 "Please choose from one of the following LLM providers: "
-                "OpenAI gpt models (e.g. gpt-4-turbo-preview, gpt-3.5-turbo-0125), Anthropic claude models (e.g. claude-2.1, claude-instant-1.2), Google Gemini models (e.g. gemini-pro), Azure OpenAI deployment (azure)"
+                "OpenAI gpt models, Anthropic claude models, Google Gemini models, Azure OpenAI deployment, Cohere models, AWS Bedrock, and VLLM model endpoints."
             )
 
     @retry(wait=wait_random_exponential(min=1, max=90), stop=stop_after_attempt(15))
-    def _llm_response(self, prompt, temperature):
+    def _llm_response(self, prompt, temperature, max_tokens):
         """
         Send a prompt to the LLM and return the response.
         """
@@ -133,7 +133,7 @@ def _llm_response(self, prompt, temperature):
                 ],
                 seed=0,
                 temperature=temperature,
-                max_tokens=1024,
+                max_tokens=max_tokens,
                 top_p=1,
                 frequency_penalty=0,
                 presence_penalty=0,
@@ -147,7 +147,7 @@ def _llm_response(self, prompt, temperature):
                 ],
                 seed=0,
                 temperature=temperature,
-                max_tokens=1024,
+                max_tokens=max_tokens,
                 top_p=1,
                 frequency_penalty=0,
                 presence_penalty=0,
@@ -156,14 +156,14 @@ def _llm_response(self, prompt, temperature):
         elif ANTHROPIC_AVAILABLE and isinstance(self.client, Anthropic):
             response = self.client.completions.create(  # type: ignore
                 model="claude-2.1",
-                max_tokens_to_sample=1024,
+                max_tokens_to_sample=max_tokens,
                 temperature=temperature,
                 prompt=f"{prompt['system_prompt']}{HUMAN_PROMPT} {prompt['user_prompt']}{AI_PROMPT}",
             )
             content = response.completion
         elif COHERE_AVAILABLE and isinstance(self.client, CohereClient):
             prompt = f"{prompt['system_prompt']}\n{prompt['user_prompt']}"
-            response = self.client.generate(model="command", prompt=prompt, temperature=temperature, max_tokens=1024)  # type: ignore
+            response = self.client.generate(model="command", prompt=prompt, temperature=temperature, max_tokens=max_tokens)  # type: ignore
             try:
                 content = response.generations[0].text
             except:
@@ -174,7 +174,7 @@ def _llm_response(self, prompt, temperature):
                 "temperature": temperature,
                 "top_p": 1,
                 "top_k": 1,
-                "max_output_tokens": 1024,
+                "max_output_tokens": max_tokens,
             }
             safety_settings = [
                 {
@@ -207,7 +207,7 @@ def _llm_response(self, prompt, temperature):
                     HumanMessage(content=prompt["user_prompt"]),
                 ],
                 temperature=temperature,
-                max_tokens=1024,
+                max_tokens=max_tokens,
                 top_p=1,
             )
             content = response.dict()["content"]
@@ -218,12 +218,13 @@ def _llm_response(self, prompt, temperature):
 
         return content
 
-    def run(self, prompt, temperature=0):
+    def run(self, prompt, temperature=0, max_tokens=1024):
         """
         Run the LLM and return the response.
         Default temperature: 0
+        Default max_tokens: 1024
         """
-        content = self._llm_response(prompt=prompt, temperature=temperature)
+        content = self._llm_response(prompt=prompt, temperature=temperature, max_tokens=max_tokens)
         return content
 
 
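
Below is a minimal usage sketch of the threaded-through max_tokens parameter. The concrete class name (LLM here) and the example prompt values are assumptions for illustration only; the diff shows just its __init__(self, model), _llm_response, and run methods.

# Hypothetical usage of the updated run() signature. The class name `LLM`
# is an assumption; only __init__(model), _llm_response, and run appear above.
llm = LLM(model="gpt-3.5-turbo-0125")

prompt = {
    "system_prompt": "You are a concise assistant.",
    "user_prompt": "Summarize the change in this commit in one sentence.",
}

# max_tokens now flows through _llm_response into each provider branch
# (max_tokens, max_tokens_to_sample, or max_output_tokens) and still
# defaults to 1024 when omitted.
answer = llm.run(prompt, temperature=0, max_tokens=256)
print(answer)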