@@ -74,7 +74,7 @@ class Agent(LlmAgent):
7474 name (str): The name of the agent.
7575 description (str): A description of the agent, useful in A2A scenarios.
7676 instruction (Union[str, InstructionProvider]): The instruction or instruction provider.
77- model_name (str): Name of the model used by the agent.
77+ model_name (Union[ str, List[str]] ): Name of the model used by the agent.
7878 model_provider (str): Provider of the model (e.g., openai).
7979 model_api_base (str): The base URL of the model API.
8080 model_api_key (str): The API key for accessing the model.
@@ -93,7 +93,9 @@ class Agent(LlmAgent):
9393 description : str = DEFAULT_DESCRIPTION
9494 instruction : Union [str , InstructionProvider ] = DEFAULT_INSTRUCTION
9595
96- model_name : str = Field (default_factory = lambda : settings .model .name )
96+ model_name : Union [str , list [str ]] = Field (
97+ default_factory = lambda : settings .model .name
98+ )
9799 model_provider : str = Field (default_factory = lambda : settings .model .provider )
98100 model_api_base : str = Field (default_factory = lambda : settings .model .api_base )
99101 model_api_key : str = Field (default_factory = lambda : settings .model .api_key )
@@ -183,10 +185,29 @@ def model_post_init(self, __context: Any) -> None:
183185 min_tokens = 0 ,
184186 )
185187 else :
188+ fallbacks = None
189+ if isinstance (self .model_name , list ):
190+ if self .model_name :
191+ model_name = self .model_name [0 ]
192+ fallbacks = [
193+ f"{ self .model_provider } /{ m } " for m in self .model_name [1 :]
194+ ]
195+ logger .info (
196+ f"Using primary model: { model_name } , with fallbacks: { self .model_name [1 :]} "
197+ )
198+ else :
199+ model_name = settings .model .name
200+ logger .warning (
201+ f"Empty model_name list provided, using default model from settings: { model_name } "
202+ )
203+ else :
204+ model_name = self .model_name
205+
186206 self .model = LiteLlm (
187- model = f"{ self .model_provider } /{ self . model_name } " ,
207+ model = f"{ self .model_provider } /{ model_name } " ,
188208 api_key = self .model_api_key ,
189209 api_base = self .model_api_base ,
210+ fallbacks = fallbacks ,
190211 ** self .model_extra_config ,
191212 )
192213 logger .debug (
0 commit comments