diff --git a/src/praisonai/praisonai/agents_generator.py b/src/praisonai/praisonai/agents_generator.py index dfb9b020a..c1762c52f 100644 --- a/src/praisonai/praisonai/agents_generator.py +++ b/src/praisonai/praisonai/agents_generator.py @@ -343,6 +343,168 @@ def __init__(self, cli_config): agent_config[field] = value self.logger.debug(f"CLI override for agent {agent_name}: {field} = {value}") + def _prepare_for_run(self, config): + """ + Single source of truth for YAML normalisation, validation, + CLI-backend compatibility, tool resolution, AutoGen version + selection, and adapter resolution. Used by BOTH sync and async. + """ + # Canonical format conversion: 'agents' -> 'roles', 'instructions' -> 'backstory' + if 'agents' in config and 'roles' not in config: + config['roles'] = {} + for agent_name, agent_config in config['agents'].items(): + role_config = dict(agent_config) if agent_config else {} + # Convert 'instructions' to 'backstory' if present + if 'instructions' in role_config and 'backstory' not in role_config: + role_config['backstory'] = role_config['instructions'] + # Ensure required fields have defaults + if 'role' not in role_config: + role_config['role'] = agent_name.replace('_', ' ').title() + if 'goal' not in role_config: + role_config['goal'] = role_config.get('backstory', 'Complete the assigned task') + if 'backstory' not in role_config: + role_config['backstory'] = f'You are a {role_config["role"]}' + config['roles'][agent_name] = role_config + + # Get workflow input: 'input' is canonical, 'topic' is alias for backward compatibility + topic = config.get('input', config.get('topic', '')) + + # Validate agents configuration for typos in field names + self._validate_agents_config(config) + + # Build tools dictionary using shared logic + tools_dict = self._build_tools_dict(config) + + # Select framework with AutoGen version logic + framework = self._select_autogen_version( + self.framework or config.get('framework', 'crewai'), + config, + ) + + # Get and resolve adapter + adapter = self._get_framework_adapter(framework).resolve() + + # Validate framework availability + from .framework_adapters.validators import assert_framework_available + assert_framework_available(adapter.name) + + # Validate cli_backend compatibility + self._validate_cli_backend_compatibility(config, framework) + + # Initialize observability hooks + from .observability.hooks import init_observability + init_observability(adapter.name) + + # Run adapter setup hooks + adapter.setup(framework_tag=adapter.name) + + # Update framework reference if resolution changed it + self.framework = adapter.name + self.framework_adapter = adapter + + return { + 'adapter': adapter, + 'config': config, + 'topic': topic, + 'tools_dict': tools_dict, + } + + def _build_tools_dict(self, config): + """Shared tool resolution logic for sync and async paths.""" + tools_dict = {} + + # Demand-driven tool resolution - only resolve tools actually used in YAML + if is_available("crewai") or is_available("autogen") or is_available("praisonaiagents") or is_available("ag2"): + try: + # Collect all tool names mentioned in the YAML config + needed_tools: set[str] = set() + for role_cfg in config.get('roles', {}).values(): + for t in role_cfg.get('tools') or []: + if isinstance(t, str) and t.strip(): + needed_tools.add(t.strip()) + for task_cfg in (role_cfg.get('tasks') or {}).values(): + if not isinstance(task_cfg, dict): + continue + for t in task_cfg.get('tools') or []: + if isinstance(t, str) and t.strip(): + needed_tools.add(t.strip()) + + # Resolve only the tools actually referenced in YAML + for tool_name in needed_tools: + try: + resolved_tool = self.tool_resolver.resolve(tool_name) + if resolved_tool is None: + self.logger.warning(f"Tool '{tool_name}' not found") + continue + tools_dict[tool_name] = ( + resolved_tool() if inspect.isclass(resolved_tool) else resolved_tool + ) + except Exception as e: + self.logger.warning(f"Failed to initialize tool '{tool_name}': {e}") + continue + + except Exception as e: + self.logger.warning(f"Error collecting YAML tool references: {e}") + + # Add tools from class names - use tool_resolver to check tool validity + for tool_class in self.tools: + if isinstance(tool_class, type): + try: + tool_instance = tool_class() + tool_name = tool_class.__name__ + tools_dict[tool_name] = tool_instance + self.logger.debug(f"Added tool: {tool_name}") + except Exception as e: + self.logger.warning(f"Failed to instantiate tool class {tool_class.__name__}: {e}") + + root_directory = os.getcwd() + tools_py_path = os.path.join(root_directory, 'tools.py') + tools_dir_path = Path(root_directory) / 'tools' + + # Use consolidated ToolResolver for tools.py loading + tools_dict.update(self.tool_resolver.get_local_tool_classes()) + if os.path.isfile(tools_py_path): + self.logger.debug("tools.py exists in the root directory. Loading tools.py and skipping tools folder.") + elif tools_dir_path.is_dir(): + tools_dict.update(self.tool_resolver.get_local_tool_classes_from_dir(tools_dir_path)) + if tools_dict: + self.logger.debug("tools folder exists in the root directory") + + return tools_dict + + def _select_autogen_version(self, framework, config): + """Shared AutoGen version selection logic for sync and async paths.""" + if framework == "autogen": + autogen_v4_adapter = self._get_framework_adapter("autogen_v4") + autogen_v2_adapter = self._get_framework_adapter("autogen") + + autogen_version = str( + config.get('autogen_version', os.environ.get("AUTOGEN_VERSION", "auto")) + ).lower() + use_v4 = False + + if autogen_version == "v0.4" and autogen_v4_adapter.is_available(): + use_v4 = True + elif autogen_version == "v0.2" and autogen_v2_adapter.is_available(): + use_v4 = False + elif autogen_version == "auto": + use_v4 = autogen_v4_adapter.is_available() + else: + use_v4 = autogen_v4_adapter.is_available() and not autogen_v2_adapter.is_available() + + framework = "autogen_v4" if use_v4 else "autogen" + + # Initialize AgentOps if configured + agentops_api_key = os.getenv("AGENTOPS_API_KEY") + if agentops_api_key: + try: + import agentops + agentops.init(agentops_api_key, default_tags=[framework]) + except ImportError: + pass + + return framework + def _validate_cli_backend_compatibility(self, config, framework): """Validate that cli_backend is only used with compatible frameworks.""" # Check if any agent/role defines cli_backend (support both key names) @@ -512,141 +674,19 @@ def generate_crew_and_kickoff(self): # Route to YAMLWorkflowParser for advanced workflow patterns return self._run_yaml_workflow(config) - config, adapter, tools_dict, topic = self._prepare(config) - return adapter.run( - config, + # Use shared preparation logic + prep = self._prepare_for_run(config) + + self.logger.info(f"Using framework: {prep['adapter'].name}") + return prep['adapter'].run( + prep['config'], self.config_list, - topic, - tools_dict=tools_dict, + prep['topic'], + tools_dict=prep['tools_dict'], agent_callback=getattr(self, 'agent_callback', None), task_callback=getattr(self, 'task_callback', None), cli_config=getattr(self, 'cli_config', None), ) - - def _prepare(self, config): - """Shared preparation logic for both sync and async entry points.""" - # Canonical format conversion: 'agents' -> 'roles', 'instructions' -> 'backstory' - if 'agents' in config and 'roles' not in config: - config['roles'] = {} - for agent_name, agent_config in config['agents'].items(): - role_config = dict(agent_config) if agent_config else {} - if 'instructions' in role_config and 'backstory' not in role_config: - role_config['backstory'] = role_config['instructions'] - if 'role' not in role_config: - role_config['role'] = agent_name.replace('_', ' ').title() - if 'goal' not in role_config: - role_config['goal'] = role_config.get('backstory', 'Complete the assigned task') - if 'backstory' not in role_config: - role_config['backstory'] = f'You are a {role_config["role"]}' - config['roles'][agent_name] = role_config - - # Get workflow input: 'input' is canonical, 'topic' is alias for backward compatibility - topic = config.get('input', config.get('topic', '')) - - # Validate agents configuration for typos in field names - self._validate_agents_config(config) - - tools_dict = {} - - # Demand-driven tool resolution - only resolve tools actually used in YAML - if is_available("crewai") or is_available("autogen") or is_available("praisonaiagents") or is_available("ag2"): - try: - # Collect all tool names mentioned in the YAML config - needed_tools: set[str] = set() - for role_cfg in config.get('roles', {}).values(): - for t in role_cfg.get('tools') or []: - if isinstance(t, str) and t.strip(): - needed_tools.add(t.strip()) - for task_cfg in (role_cfg.get('tasks') or {}).values(): - if not isinstance(task_cfg, dict): - continue - for t in task_cfg.get('tools') or []: - if isinstance(t, str) and t.strip(): - needed_tools.add(t.strip()) - - # Resolve only the tools actually referenced in YAML - for tool_name in needed_tools: - try: - resolved_tool = self.tool_resolver.resolve(tool_name, instantiate=True) - if resolved_tool is not None: - tools_dict[tool_name] = resolved_tool - except Exception as e: - self.logger.warning(f"Failed to initialize tool '{tool_name}': {e}") - continue - - except Exception as e: - self.logger.warning(f"Error collecting YAML tool references: {e}") - - # Add tools from class names - use tool_resolver to check tool validity - for tool_class in self.tools: - if isinstance(tool_class, type): - try: - tool_instance = tool_class() - tool_name = tool_class.__name__ - tools_dict[tool_name] = tool_instance - self.logger.debug(f"Added tool: {tool_name}") - except Exception as e: - self.logger.warning(f"Failed to instantiate tool class {tool_class.__name__}: {e}") - - root_directory = os.getcwd() - tools_py_path = os.path.join(root_directory, 'tools.py') - tools_dir_path = Path(root_directory) / 'tools' - - # Use consolidated ToolResolver for tools.py loading - tools_dict.update(self.tool_resolver.get_local_tool_classes()) - if os.path.isfile(tools_py_path): - self.logger.debug("tools.py exists in the root directory. Loading tools.py and skipping tools folder.") - elif tools_dir_path.is_dir(): - tools_dict.update(self.tool_resolver.get_local_tool_classes_from_dir(tools_dir_path)) - if tools_dict: - self.logger.debug("tools folder exists in the root directory") - - framework = self.framework or config.get('framework', 'crewai') - - # AutoGen version selection logic - if framework == "autogen": - autogen_v4_adapter = self._get_framework_adapter("autogen_v4") - autogen_v2_adapter = self._get_framework_adapter("autogen") - - autogen_version = str( - config.get('autogen_version', os.environ.get("AUTOGEN_VERSION", "auto")) - ).lower() - use_v4 = False - - if autogen_version == "v0.4" and autogen_v4_adapter.is_available(): - use_v4 = True - elif autogen_version == "v0.2" and autogen_v2_adapter.is_available(): - use_v4 = False - elif autogen_version == "auto": - use_v4 = autogen_v4_adapter.is_available() - else: - use_v4 = autogen_v4_adapter.is_available() and not autogen_v2_adapter.is_available() - - framework = "autogen_v4" if use_v4 else "autogen" - - # Validate cli_backend compatibility - self._validate_cli_backend_compatibility(config, framework) - - # Get framework adapter and resolve to concrete variant - adapter = self._get_framework_adapter(framework).resolve() - - # Validate framework availability early - from .framework_adapters.validators import assert_framework_available - assert_framework_available(adapter.name) - - # Initialize observability hooks - from .observability.hooks import init_observability - init_observability(adapter.name) - - # Run adapter setup hooks - adapter.setup(framework_tag=adapter.name) - - # Update framework reference if resolution changed it - self.framework = adapter.name - self.framework_adapter = adapter - - self.logger.info(f"Using framework: {adapter.name}") - return config, adapter, tools_dict, topic async def agenerate_crew_and_kickoff(self): """ @@ -682,12 +722,15 @@ async def agenerate_crew_and_kickoff(self): async def _arun_framework(self, config): """Async version of _run_framework with shared preparation logic.""" - config, adapter, tools_dict, topic = self._prepare(config) - return await adapter.arun( - config, + # Use shared preparation logic + prep = self._prepare_for_run(config) + + self.logger.info(f"Using framework: {prep['adapter'].name}") + return await prep['adapter'].arun( + prep['config'], self.config_list, - topic, - tools_dict=tools_dict, + prep['topic'], + tools_dict=prep['tools_dict'], agent_callback=getattr(self, 'agent_callback', None), task_callback=getattr(self, 'task_callback', None), cli_config=getattr(self, 'cli_config', None), diff --git a/src/praisonai/praisonai/db/adapter.py b/src/praisonai/praisonai/db/adapter.py index c74b9aab3..c5c49041e 100644 --- a/src/praisonai/praisonai/db/adapter.py +++ b/src/praisonai/praisonai/db/adapter.py @@ -195,6 +195,7 @@ def on_user_message( metadata: Optional[Dict[str, Any]] = None, ) -> None: """Called when user sends a message.""" + self._init_stores() if not self._conversation_store: return @@ -218,6 +219,7 @@ def on_agent_message( metadata: Optional[Dict[str, Any]] = None, ) -> None: """Called when agent produces a response.""" + self._init_stores() if not self._conversation_store: return @@ -243,6 +245,7 @@ def on_tool_call( metadata: Optional[Dict[str, Any]] = None, ) -> None: """Called when a tool is executed.""" + self._init_stores() if not self._conversation_store: return @@ -273,6 +276,7 @@ def on_agent_end( metadata: Optional[Dict[str, Any]] = None, ) -> None: """Called when agent session ends.""" + self._init_stores() if not self._conversation_store: return @@ -290,6 +294,7 @@ def on_run_start( metadata: Optional[Dict[str, Any]] = None, ) -> None: """Called when a new run (turn) starts.""" + self._init_stores() # Store run start in state if available (even without conversation store) if self._state_store: run_key = f"run:{session_id}:{run_id}" @@ -316,6 +321,7 @@ def on_run_end( metadata: Optional[Dict[str, Any]] = None, ) -> None: """Called when a run (turn) ends.""" + self._init_stores() if self._state_store: run_key = f"run:{session_id}:{run_id}" run_data = self._state_store.get(run_key) or {} @@ -447,6 +453,7 @@ def on_trace_start( metadata: Optional[Dict[str, Any]] = None, ) -> None: """Called when a new trace starts.""" + self._init_stores() if self._state_store: trace_key = f"trace:{trace_id}" self._state_store.set(trace_key, { @@ -468,6 +475,7 @@ def on_trace_end( metadata: Optional[Dict[str, Any]] = None, ) -> None: """Called when a trace ends.""" + self._init_stores() if self._state_store: trace_key = f"trace:{trace_id}" trace_data = self._state_store.get(trace_key) or {} @@ -487,6 +495,7 @@ def on_span_start( attributes: Optional[Dict[str, Any]] = None, ) -> None: """Called when a new span starts.""" + self._init_stores() if self._state_store: span_key = f"span:{span_id}" self._state_store.set(span_key, { @@ -507,6 +516,7 @@ def on_span_end( attributes: Optional[Dict[str, Any]] = None, ) -> None: """Called when a span ends.""" + self._init_stores() if self._state_store: span_key = f"span:{span_id}" span_data = self._state_store.get(span_key) or {} diff --git a/src/praisonai/praisonai/framework_adapters/base.py b/src/praisonai/praisonai/framework_adapters/base.py index c7f3a8773..19d991d57 100644 --- a/src/praisonai/praisonai/framework_adapters/base.py +++ b/src/praisonai/praisonai/framework_adapters/base.py @@ -101,6 +101,8 @@ def cleanup(self) -> None: class BaseFrameworkAdapter: """Base class for framework adapters providing common functionality.""" + DEFAULT_MODEL = "openai/gpt-4o-mini" + def __init__(self): self._tool_registry: Dict[str, Any] = {} @@ -116,6 +118,24 @@ def list_tools(self) -> List[str]: """List all registered tool names.""" return list(self._tool_registry.keys()) + def _resolve_llm(self, spec, llm_config): + """Build a PraisonAIModel from a per-agent llm/function_calling_llm spec. + Accepts str, dict, or None. Single source of truth for all adapters.""" + from ..inc import PraisonAIModel + import os + + base = llm_config[0].get('base_url') if (llm_config and len(llm_config) > 0) else None + key = llm_config[0].get('api_key') if (llm_config and len(llm_config) > 0) else None + + if isinstance(spec, str) and spec.strip(): + model = spec.strip() + elif isinstance(spec, dict) and spec.get('model'): + model = spec['model'] + else: + model = os.environ.get("MODEL_NAME") or self.DEFAULT_MODEL + + return PraisonAIModel(model=model, base_url=base, api_key=key).get_model() + def _format_template(self, template: str, **kwargs) -> str: """Safely format template string with given kwargs, preserving JSON-like braces.""" if not isinstance(template, str): @@ -150,22 +170,8 @@ async def arun( cli_config: Optional[Dict[str, Any]] = None, ) -> str: """ - Async execution. Default implementation offloads sync run() to a worker thread. - - Sync-only adapters (crewai, autogen v0.2) can use this default. - Native-async adapters should override this method. - - Args: - config: Framework configuration - llm_config: LLM configuration list - topic: Topic for the tasks - tools_dict: Available tools dictionary - agent_callback: Callback for agent events - task_callback: Callback for task events - cli_config: CLI configuration - - Returns: - Execution result as string + Safe default for sync-only adapters (crewai, autogen v0.2): + run the sync implementation in a worker thread, freeing the loop. """ import asyncio return await asyncio.to_thread( @@ -175,6 +181,7 @@ async def arun( task_callback=task_callback, cli_config=cli_config ) + def cleanup(self) -> None: """Clean up resources - default implementation does nothing.""" pass diff --git a/src/praisonai/praisonai/framework_adapters/crewai_adapter.py b/src/praisonai/praisonai/framework_adapters/crewai_adapter.py index fd4d90d0d..958ad01a5 100644 --- a/src/praisonai/praisonai/framework_adapters/crewai_adapter.py +++ b/src/praisonai/praisonai/framework_adapters/crewai_adapter.py @@ -53,7 +53,6 @@ def run( import os from crewai import Agent, Task, Crew from crewai.telemetry import Telemetry - from ..inc import PraisonAIModel from .._framework_availability import is_available # Suppress crewai.cli.config logger (scoped to when CrewAI is actually used) @@ -75,33 +74,11 @@ def run( agent_tools = [tools_dict[tool] for tool in details.get('tools', []) if tools_dict and tool in tools_dict] - # Configure LLM - llm_model = details.get('llm') - if llm_model: - llm = PraisonAIModel( - model=llm_model.get("model") or os.environ.get("MODEL_NAME") or "openai/gpt-4o-mini", - base_url=llm_config[0].get('base_url') if llm_config else None, - api_key=llm_config[0].get('api_key') if llm_config else None - ).get_model() - else: - llm = PraisonAIModel( - base_url=llm_config[0].get('base_url') if llm_config else None, - api_key=llm_config[0].get('api_key') if llm_config else None - ).get_model() - - # Configure function calling LLM - function_calling_llm_model = details.get('function_calling_llm') - if function_calling_llm_model: - function_calling_llm = PraisonAIModel( - model=function_calling_llm_model.get("model") or os.environ.get("MODEL_NAME") or "openai/gpt-4o-mini", - base_url=llm_config[0].get('base_url') if llm_config else None, - api_key=llm_config[0].get('api_key') if llm_config else None - ).get_model() - else: - function_calling_llm = PraisonAIModel( - base_url=llm_config[0].get('base_url') if llm_config else None, - api_key=llm_config[0].get('api_key') if llm_config else None - ).get_model() + # Configure LLM using shared resolver + llm = self._resolve_llm(details.get('llm'), llm_config) + + # Configure function calling LLM using shared resolver + function_calling_llm = self._resolve_llm(details.get('function_calling_llm'), llm_config) # Create CrewAI agent with full feature set agent = Agent( diff --git a/src/praisonai/tests/unit/test_db_adapter_tracing_init.py b/src/praisonai/tests/unit/test_db_adapter_tracing_init.py new file mode 100644 index 000000000..7772209ac --- /dev/null +++ b/src/praisonai/tests/unit/test_db_adapter_tracing_init.py @@ -0,0 +1,142 @@ +""" +Unit tests for PraisonAIDB adapter tracing hooks store initialization. + +Tests that all tracing hooks (_init_stores() calls) prevent silent data loss +by ensuring stores are initialized before any write operations. +""" + +import unittest +from unittest.mock import Mock, patch +import time + + +class TestPraisonAIDBTracingInit(unittest.TestCase): + """Test that tracing hooks properly initialize stores before use.""" + + def setUp(self): + """Set up test fixtures.""" + from praisonai.db.adapter import PraisonAIDB + self.db = PraisonAIDB() + + def test_on_trace_start_calls_init_stores(self): + """Test that on_trace_start calls _init_stores() before writing.""" + # Mock _init_stores to track calls + with patch.object(self.db, '_init_stores') as mock_init: + # Mock _state_store to prevent actual writes + self.db._state_store = Mock() + + # Call tracing hook + self.db.on_trace_start( + trace_id="test-trace", + session_id="test-session", + agent_name="test-agent" + ) + + # Verify _init_stores was called before accessing _state_store + mock_init.assert_called_once() + self.db._state_store.set.assert_called_once() + + def test_on_trace_end_calls_init_stores(self): + """Test that on_trace_end calls _init_stores() before writing.""" + with patch.object(self.db, '_init_stores') as mock_init: + self.db._state_store = Mock() + self.db._state_store.get.return_value = {"metadata": {}} + + self.db.on_trace_end( + trace_id="test-trace", + status="completed" + ) + + mock_init.assert_called_once() + self.db._state_store.get.assert_called_once() + self.db._state_store.set.assert_called_once() + + def test_on_span_start_calls_init_stores(self): + """Test that on_span_start calls _init_stores() before writing.""" + with patch.object(self.db, '_init_stores') as mock_init: + self.db._state_store = Mock() + + self.db.on_span_start( + span_id="test-span", + trace_id="test-trace", + name="test-operation" + ) + + mock_init.assert_called_once() + self.db._state_store.set.assert_called_once() + + def test_on_span_end_calls_init_stores(self): + """Test that on_span_end calls _init_stores() before writing.""" + with patch.object(self.db, '_init_stores') as mock_init: + self.db._state_store = Mock() + self.db._state_store.get.return_value = {"attributes": {}} + + self.db.on_span_end( + span_id="test-span", + status="completed" + ) + + mock_init.assert_called_once() + self.db._state_store.get.assert_called_once() + self.db._state_store.set.assert_called_once() + + def test_tracing_hooks_prevent_silent_data_loss(self): + """Test that tracing hooks prevent silent data loss when stores not initialized.""" + # Test scenario where _init_stores() is critical + # Start with uninitialized adapter (no stores configured) + db = self.db + self.assertIsNone(db._state_store) + self.assertFalse(db._initialized) + + # Mock _init_stores to simulate successful initialization + def mock_init(): + db._state_store = Mock() + db._initialized = True + + with patch.object(db, '_init_stores', side_effect=mock_init): + # Call tracing hook - should not fail silently + db.on_trace_start("trace1", session_id="session1") + + # Verify store was initialized and data was written + self.assertIsNotNone(db._state_store) + self.assertTrue(db._initialized) + db._state_store.set.assert_called_once() + + def test_init_stores_idempotent(self): + """Test that _init_stores() can be called multiple times safely.""" + # Configure minimal database URLs + db = self.db + db._state_url = "memory://" + + # Mock the store creation to avoid actual dependencies + with patch('praisonai.persistence.factory.create_state_store') as mock_create: + mock_store = Mock() + mock_create.return_value = mock_store + + # Call _init_stores multiple times + db._init_stores() + db._init_stores() + db._init_stores() + + # Should only initialize once due to idempotent behavior + self.assertTrue(db._initialized) + self.assertEqual(db._state_store, mock_store) + + def test_tracing_without_state_url_graceful_degradation(self): + """Test that tracing hooks degrade gracefully when no state_url configured.""" + # No state_url configured - tracing should not crash + db = self.db + self.assertIsNone(db._state_url) + + # These should not raise exceptions + db.on_trace_start("trace1") + db.on_trace_end("trace1") + db.on_span_start("span1", "trace1", "operation") + db.on_span_end("span1") + + # Store should remain None (graceful degradation) + self.assertIsNone(db._state_store) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/src/praisonai/tests/unit/test_framework_adapter_simple.py b/src/praisonai/tests/unit/test_framework_adapter_simple.py new file mode 100644 index 000000000..25d4c58f3 --- /dev/null +++ b/src/praisonai/tests/unit/test_framework_adapter_simple.py @@ -0,0 +1,126 @@ +""" +Simplified unit tests for BaseFrameworkAdapter._resolve_llm() IndexError guard. + +Focuses on the core regression test: IndexError prevention with empty llm_config. +""" + +import unittest +from unittest.mock import Mock, patch + + +class TestBaseFrameworkAdapterIndexErrorFix(unittest.TestCase): + """Test that BaseFrameworkAdapter._resolve_llm() prevents IndexError.""" + + def setUp(self): + """Set up test fixtures.""" + from praisonai.framework_adapters.base import BaseFrameworkAdapter + self.adapter = BaseFrameworkAdapter() + + def test_resolve_llm_empty_llm_config_no_crash(self): + """Test that empty llm_config lists don't cause IndexError.""" + # This is the core regression test for the IndexError bug + with patch('praisonai.inc.PraisonAIModel') as mock_model_class: + mock_instance = Mock() + mock_model_class.return_value = mock_instance + mock_instance.get_model.return_value = "test_model" + + # Before the fix: this would raise IndexError on llm_config[0] + # After the fix: should handle empty list gracefully + try: + result = self.adapter._resolve_llm("gpt-4o-mini", []) + indexerror_fixed = True + except IndexError: + indexerror_fixed = False + + self.assertTrue(indexerror_fixed, "Empty llm_config should not cause IndexError") + + # Verify the method completed and returned a result + self.assertEqual(result, "test_model") + + def test_resolve_llm_none_llm_config_no_crash(self): + """Test that None llm_config doesn't cause IndexError.""" + with patch('praisonai.inc.PraisonAIModel') as mock_model_class: + mock_instance = Mock() + mock_model_class.return_value = mock_instance + mock_instance.get_model.return_value = "test_model" + + # Before the fix: this would raise IndexError on llm_config[0] + # After the fix: should handle None gracefully + try: + result = self.adapter._resolve_llm("gpt-4o-mini", None) + indexerror_fixed = True + except (IndexError, TypeError): + indexerror_fixed = False + + self.assertTrue(indexerror_fixed, "None llm_config should not cause IndexError") + + # Verify method returns result + self.assertEqual(result, "test_model") + + def test_resolve_llm_guards_check_length_before_access(self): + """Test that _resolve_llm checks llm_config length before accessing [0].""" + # This tests the specific fix: checking (llm_config and len(llm_config) > 0) + with patch('praisonai.inc.PraisonAIModel') as mock_model_class: + mock_instance = Mock() + mock_model_class.return_value = mock_instance + mock_instance.get_model.return_value = "test_model" + + # Test various edge cases that should not crash + test_cases = [ + [], # Empty list + None, # None value + [{}], # List with empty dict + [{"other_key": "value"}] # List with dict without expected keys + ] + + for llm_config in test_cases: + with self.subTest(llm_config=llm_config): + try: + result = self.adapter._resolve_llm("gpt-4o", llm_config) + # Should complete without IndexError + self.assertEqual(result, "test_model") + safe_access = True + except IndexError: + safe_access = False + + self.assertTrue(safe_access, f"llm_config {llm_config} should be handled safely") + + def test_resolve_llm_extracts_base_url_and_api_key_when_present(self): + """Test that base_url and api_key are extracted when llm_config is valid.""" + with patch('praisonai.inc.PraisonAIModel') as mock_model_class: + mock_instance = Mock() + mock_model_class.return_value = mock_instance + mock_instance.get_model.return_value = "test_model" + + # Test with valid llm_config + llm_config = [{"base_url": "https://test.api", "api_key": "test-key"}] + result = self.adapter._resolve_llm("gpt-4o", llm_config) + + # Verify PraisonAIModel was called with extracted values + mock_model_class.assert_called_with( + model="gpt-4o", + base_url="https://test.api", + api_key="test-key" + ) + + def test_crewai_adapter_string_llm_compatibility(self): + """Test that CrewAI adapter can handle string llm specs without crashing.""" + # This is a basic smoke test for string llm support in CrewAI adapter + try: + from praisonai.framework_adapters.crewai_adapter import CrewAIAdapter + adapter = CrewAIAdapter() + + # Test that the adapter has the _resolve_llm method (indicating it uses the fix) + has_resolve_llm = hasattr(adapter, '_resolve_llm') + self.assertTrue(has_resolve_llm, "CrewAI adapter should inherit _resolve_llm method") + + # Basic smoke test - adapter should not crash on instantiation + crewai_works = True + except Exception: + crewai_works = False + + self.assertTrue(crewai_works, "CrewAI adapter should be importable and instantiable") + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file