diff --git a/src/praisonai/praisonai/endpoints/registry.py b/src/praisonai/praisonai/endpoints/registry.py index 9202310e0..cc1ce09e2 100644 --- a/src/praisonai/praisonai/endpoints/registry.py +++ b/src/praisonai/praisonai/endpoints/registry.py @@ -1,131 +1,128 @@ """ Provider Registry -Central registry for provider adapters. +Central registry for provider adapters — unified with the rest of the wrapper. """ -from typing import Dict, List, Optional, Type +from __future__ import annotations + +import threading +from typing import Any, List, Optional, Type from .providers.base import BaseProvider +from .._registry import PluginRegistry -# Global provider registry -_providers: Dict[str, Type[BaseProvider]] = {} +def _recipe_loader(): + from .providers.recipe import RecipeProvider + return RecipeProvider -def register_provider(provider_type: str, provider_class: Type[BaseProvider]) -> None: - """ - Register a provider class. - - Args: - provider_type: Provider type identifier - provider_class: Provider class to register - """ - _providers[provider_type] = provider_class - - -def get_provider( - provider_type: str, - base_url: str = "http://localhost:8765", - api_key: Optional[str] = None, - **kwargs, -) -> Optional[BaseProvider]: - """ - Get a provider instance by type. - - Args: - provider_type: Provider type identifier - base_url: Base URL for the provider - api_key: Optional API key - **kwargs: Additional provider-specific arguments - - Returns: - Provider instance or None if not found - """ - # Lazy register built-in providers - _ensure_providers_registered() - - if provider_type not in _providers: - return None - - return _providers[provider_type](base_url=base_url, api_key=api_key, **kwargs) +def _agents_api_loader(): + from .providers.agents_api import AgentsAPIProvider + return AgentsAPIProvider -def list_provider_types() -> List[str]: - """ - List all registered provider types. - - Returns: - List of provider type identifiers - """ - _ensure_providers_registered() - return list(_providers.keys()) +def _mcp_loader(): + from .providers.mcp import MCPProvider + return MCPProvider -def get_provider_class(provider_type: str) -> Optional[Type[BaseProvider]]: - """ - Get a provider class by type. - - Args: - provider_type: Provider type identifier - - Returns: - Provider class or None if not found - """ - _ensure_providers_registered() - return _providers.get(provider_type) - - -def _ensure_providers_registered() -> None: - """Ensure built-in providers are registered.""" - if _providers: - return - - # Lazy import and register built-in providers - from .providers.recipe import RecipeProvider - from .providers.agents_api import AgentsAPIProvider - from .providers.mcp import MCPProvider +def _tools_mcp_loader(): from .providers.tools_mcp import ToolsMCPProvider + return ToolsMCPProvider + + +def _a2a_loader(): from .providers.a2a import A2AProvider + return A2AProvider + + +def _a2u_loader(): from .providers.a2u import A2UProvider - - register_provider("recipe", RecipeProvider) - register_provider("agents-api", AgentsAPIProvider) - register_provider("mcp", MCPProvider) - register_provider("tools-mcp", ToolsMCPProvider) - register_provider("a2a", A2AProvider) - register_provider("a2u", A2UProvider) - - -class ProviderRegistry: - """ - Provider registry class for managing provider instances. - - This class provides a more object-oriented interface to the provider registry. - """ - - def __init__(self): - """Initialize the registry.""" - _ensure_providers_registered() - + return A2UProvider + + +_BUILTIN_PROVIDERS = { + "recipe": _recipe_loader, + "agents-api": _agents_api_loader, + "mcp": _mcp_loader, + "tools-mcp": _tools_mcp_loader, + "a2a": _a2a_loader, + "a2u": _a2u_loader, +} + + +class ProviderRegistry(PluginRegistry[Type[BaseProvider]]): + """Endpoint provider registry — unified with the rest of the wrapper.""" + + def __init__(self) -> None: + super().__init__( + entry_point_group="praisonai.endpoint_providers", + builtins=_BUILTIN_PROVIDERS, + ) + def get( self, provider_type: str, base_url: str = "http://localhost:8765", api_key: Optional[str] = None, - **kwargs, + **kwargs: Any, ) -> Optional[BaseProvider]: - """Get a provider instance.""" - return get_provider(provider_type, base_url, api_key, **kwargs) - - def register(self, provider_type: str, provider_class: Type[BaseProvider]) -> None: - """Register a provider class.""" - register_provider(provider_type, provider_class) - + try: + cls = self.resolve(provider_type) + except ValueError: + # Distinguish between missing provider vs import failure + if provider_type.lower() not in self.list_all_names(): + return None + raise + return cls(base_url=base_url, api_key=api_key, **kwargs) + + # Backward compatibility methods - forward to parent class methods def list_types(self) -> List[str]: - """List all provider types.""" - return list_provider_types() - + """List available provider types. Backward compatibility alias for list_names().""" + return self.list_names() + def get_class(self, provider_type: str) -> Optional[Type[BaseProvider]]: - """Get a provider class.""" - return get_provider_class(provider_type) + """Get provider class by type. Backward compatibility alias for resolve().""" + try: + return self.resolve(provider_type) + except ValueError: + return None + + +_default_registry: Optional[ProviderRegistry] = None +_default_lock = threading.Lock() + + +def get_default_registry() -> ProviderRegistry: + global _default_registry + if _default_registry is None: + with _default_lock: + if _default_registry is None: + _default_registry = ProviderRegistry() + return _default_registry + + +# Module-level functions kept for backwards compat — now delegate to the registry +def register_provider(provider_type: str, provider_class: Type[BaseProvider]) -> None: + get_default_registry().register(provider_type, provider_class) + + +def get_provider(provider_type, base_url="http://localhost:8765", api_key=None, **kwargs): + return get_default_registry().get(provider_type, base_url=base_url, api_key=api_key, **kwargs) + + +def list_provider_types() -> List[str]: + return get_default_registry().list_names() + + +def get_provider_class(provider_type: str) -> Optional[Type[BaseProvider]]: + registry = get_default_registry() + try: + return registry.resolve(provider_type) + except ValueError: + # Distinguish between missing provider vs import failure + if provider_type.lower() not in registry.list_all_names(): + return None + raise diff --git a/src/praisonai/praisonai/scheduler/daemon_manager.py b/src/praisonai/praisonai/scheduler/daemon_manager.py index d08be3c04..330af5983 100644 --- a/src/praisonai/praisonai/scheduler/daemon_manager.py +++ b/src/praisonai/praisonai/scheduler/daemon_manager.py @@ -5,6 +5,7 @@ import sys import signal import subprocess +import asyncio from pathlib import Path from typing import Optional, Dict, List from datetime import datetime @@ -153,6 +154,41 @@ def stop_daemon(self, pid: int, timeout: int = 10) -> bool: except (OSError, ProcessLookupError): return False + + async def astop_daemon(self, pid: int, timeout: int = 10) -> bool: + """ + Async variant of stop_daemon — never blocks the event loop. + + Args: + pid: Process ID + timeout: Timeout in seconds + + Returns: + True if stopped successfully + """ + try: + # Try graceful shutdown first (SIGTERM) + os.kill(pid, signal.SIGTERM) + + # Wait for process to terminate + for _ in range(timeout * 10): + try: + os.kill(pid, 0) # Check if still alive + await asyncio.sleep(0.1) # cooperative wait + except (OSError, ProcessLookupError): + return True # Process terminated + + # Force kill if still alive + try: + os.kill(pid, signal.SIGKILL) + await asyncio.sleep(0.2) # Give it time to die + except (OSError, ProcessLookupError): + pass + + return True + + except (OSError, ProcessLookupError): + return False def get_status(self, pid: int) -> Optional[Dict]: """ diff --git a/src/praisonai/praisonai/scheduler/deployment.py b/src/praisonai/praisonai/scheduler/deployment.py index badd51df8..75c6b69ea 100644 --- a/src/praisonai/praisonai/scheduler/deployment.py +++ b/src/praisonai/praisonai/scheduler/deployment.py @@ -8,6 +8,7 @@ import logging import threading import time +import asyncio from typing import Optional, Dict, Any from abc import ABC, abstractmethod @@ -175,6 +176,37 @@ def deploy_once(self) -> bool: logger.error(f"One-time deployment failed: {e}") return False + async def adeploy_with_retry(self, max_retries: int = 3) -> bool: + """ + Async variant of deployment retry logic — never blocks the event loop. + + Args: + max_retries: Maximum number of retry attempts + + Returns: + True if deployment succeeded, False otherwise + """ + deployer = self._get_deployer() + + for attempt in range(max_retries): + try: + # Run blocking deploy() call in thread pool to avoid blocking event loop + if await asyncio.to_thread(deployer.deploy): + logger.info(f"Deployment successful on attempt {attempt + 1}") + return True + else: + logger.warning(f"Deployment failed on attempt {attempt + 1}") + except (OSError, RuntimeError, ConnectionError) as e: + logger.exception(f"Deployment error on attempt {attempt + 1}: {e}") + except Exception as e: + logger.exception(f"Unexpected deployment error on attempt {attempt + 1}: {e}") + + if attempt < max_retries - 1: + await asyncio.sleep(30) # Wait before retry (cooperative) + + logger.error(f"Deployment failed after {max_retries} attempts") + return False + def create_deployment_scheduler(provider: str = "gcp", config: Optional[Dict[str, Any]] = None) -> DeploymentScheduler: """ diff --git a/src/praisonai/praisonai/train/llm/trainer.py b/src/praisonai/praisonai/train/llm/trainer.py index 748b13743..de0f720b4 100644 --- a/src/praisonai/praisonai/train/llm/trainer.py +++ b/src/praisonai/praisonai/train/llm/trainer.py @@ -10,18 +10,41 @@ import os import sys import yaml -import torch import shutil import subprocess -from transformers import TextStreamer -from unsloth import FastLanguageModel, is_bfloat16_supported -from trl import SFTTrainer -from transformers import TrainingArguments -from datasets import load_dataset, concatenate_datasets -from psutil import virtual_memory -from unsloth.chat_templates import standardize_sharegpt, get_chat_template from functools import partial + +def _lazy_import_training_deps(): + """Import heavy training dependencies only when needed.""" + try: + import torch + from transformers import TextStreamer, TrainingArguments + from unsloth import FastLanguageModel, is_bfloat16_supported + from unsloth.chat_templates import standardize_sharegpt, get_chat_template + from trl import SFTTrainer + from datasets import load_dataset, concatenate_datasets + from psutil import virtual_memory + # Make available in global scope for the rest of the module + globals().update({ + 'torch': torch, + 'TextStreamer': TextStreamer, + 'FastLanguageModel': FastLanguageModel, + 'is_bfloat16_supported': is_bfloat16_supported, + 'SFTTrainer': SFTTrainer, + 'TrainingArguments': TrainingArguments, + 'load_dataset': load_dataset, + 'concatenate_datasets': concatenate_datasets, + 'virtual_memory': virtual_memory, + 'standardize_sharegpt': standardize_sharegpt, + 'get_chat_template': get_chat_template, + }) + except ImportError as e: + raise ImportError( + f"Training dependencies not available. Install with: " + f"pip install torch transformers unsloth datasets trl psutil. Error: {e}" + ) from e + ##################################### # Step 1: Formatting Raw Conversations ##################################### @@ -107,6 +130,7 @@ def tokenize_function(examples, hf_tokenizer, max_length): ##################################### class TrainModel: def __init__(self, config_path="config.yaml"): + _lazy_import_training_deps() self.load_config(config_path) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = None diff --git a/src/praisonai/praisonai/upload_vision.py b/src/praisonai/praisonai/upload_vision.py index c1cb16fb9..579c6974e 100644 --- a/src/praisonai/praisonai/upload_vision.py +++ b/src/praisonai/praisonai/upload_vision.py @@ -7,13 +7,25 @@ import os import yaml -import torch import shutil import subprocess -from unsloth import FastVisionModel + + +def _lazy_import_vision_upload_deps(): + """Import heavy vision deps only when needed (mirrors train_vision.py).""" + try: + import torch + from unsloth import FastVisionModel + globals().update({"torch": torch, "FastVisionModel": FastVisionModel}) + except ImportError as e: + raise ImportError( + "Vision upload dependencies missing. " + "Install with: pip install torch unsloth" + ) from e class UploadVisionModel: def __init__(self, config_path="config.yaml"): + _lazy_import_vision_upload_deps() self.load_config(config_path) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = None diff --git a/src/praisonai/tests/unit/test_async_daemon_deployment.py b/src/praisonai/tests/unit/test_async_daemon_deployment.py new file mode 100644 index 000000000..16f401968 --- /dev/null +++ b/src/praisonai/tests/unit/test_async_daemon_deployment.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python +""" +Unit tests for async daemon and deployment methods. + +Tests that the new async methods (astop_daemon, adeploy_with_retry) properly +use asyncio.to_thread and never block the event loop. + +Per AGENTS.md §9: Both smoke tests and real agentic tests required. +""" + +import asyncio +import os +import signal +import time +import unittest +from unittest import mock +from unittest.mock import Mock, patch, AsyncMock + + +class TestAsyncDaemonDeployment(unittest.TestCase): + """Test async daemon and deployment functionality.""" + + def test_daemon_manager_astop_daemon_smoke(self): + """Smoke test: astop_daemon method exists and has correct signature.""" + from praisonai.praisonai.scheduler.daemon_manager import DaemonManager + + manager = DaemonManager() + + # Method should exist + self.assertTrue(hasattr(manager, 'astop_daemon')) + + # Should be a coroutine function + self.assertTrue(asyncio.iscoroutinefunction(manager.astop_daemon)) + + def test_deployment_scheduler_adeploy_with_retry_smoke(self): + """Smoke test: adeploy_with_retry method exists and has correct signature.""" + from praisonai.praisonai.scheduler.deployment import DeploymentScheduler + + scheduler = DeploymentScheduler() + + # Method should exist + self.assertTrue(hasattr(scheduler, 'adeploy_with_retry')) + + # Should be a coroutine function + self.assertTrue(asyncio.iscoroutinefunction(scheduler.adeploy_with_retry)) + + def test_astop_daemon_uses_async_sleep(self): + """Test that astop_daemon uses asyncio.sleep instead of time.sleep.""" + from praisonai.praisonai.scheduler.daemon_manager import DaemonManager + + manager = DaemonManager() + + # Mock os.kill to simulate process behavior + with patch('os.kill') as mock_kill: + # First call succeeds (SIGTERM), second call raises ProcessLookupError (process dead) + mock_kill.side_effect = [None, ProcessLookupError("Process not found")] + + # Mock asyncio.sleep to track calls + with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep: + async def run_test(): + result = await manager.astop_daemon(12345, timeout=1) + return result + + # Run the test + result = asyncio.run(run_test()) + + # Should succeed + self.assertTrue(result) + + # Should have used asyncio.sleep, not time.sleep + mock_sleep.assert_called() + + # Verify os.kill was called with SIGTERM + mock_kill.assert_called_with(12345, signal.SIGTERM) + + def test_astop_daemon_timeout_behavior(self): + """Test astop_daemon timeout and escalation to SIGKILL.""" + from praisonai.praisonai.scheduler.daemon_manager import DaemonManager + + manager = DaemonManager() + + # Mock os.kill to simulate stubborn process + with patch('os.kill') as mock_kill: + # Process stays alive for timeout duration, then we kill it + call_count = 0 + def kill_side_effect(pid, sig): + nonlocal call_count + call_count += 1 + if call_count <= 10: # Process stays alive for first 10 checks + return None + elif sig == signal.SIGKILL: # Dies after SIGKILL + raise ProcessLookupError("Process killed") + return None + + mock_kill.side_effect = kill_side_effect + + with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep: + async def run_test(): + result = await manager.astop_daemon(12345, timeout=1) + return result + + result = asyncio.run(run_test()) + + # Should succeed after escalation + self.assertTrue(result) + + # Should have called SIGTERM first, then SIGKILL + kill_calls = mock_kill.call_args_list + self.assertTrue(any(call[0][1] == signal.SIGTERM for call in kill_calls)) + self.assertTrue(any(call[0][1] == signal.SIGKILL for call in kill_calls)) + + def test_adeploy_with_retry_uses_asyncio_to_thread(self): + """Test that adeploy_with_retry uses asyncio.to_thread for blocking calls.""" + from praisonai.praisonai.scheduler.deployment import DeploymentScheduler + + scheduler = DeploymentScheduler() + + # Mock the deployer + mock_deployer = Mock() + mock_deployer.deploy.return_value = True + scheduler._deployer = mock_deployer + + # Mock asyncio.to_thread to track calls + with patch('asyncio.to_thread', new_callable=AsyncMock) as mock_to_thread: + mock_to_thread.return_value = True + + with patch('asyncio.sleep', new_callable=AsyncMock): + async def run_test(): + result = await scheduler.adeploy_with_retry(max_retries=1) + return result + + result = asyncio.run(run_test()) + + # Should succeed + self.assertTrue(result) + + # Should have used asyncio.to_thread + mock_to_thread.assert_called_once() + + # Verify it was called with the deploy method + call_args = mock_to_thread.call_args[0] + self.assertEqual(call_args[0], mock_deployer.deploy) + + def test_adeploy_with_retry_retry_logic(self): + """Test adeploy_with_retry retry logic and asyncio.sleep usage.""" + from praisonai.praisonai.scheduler.deployment import DeploymentScheduler + + scheduler = DeploymentScheduler() + + # Mock deployer that fails twice then succeeds + mock_deployer = Mock() + call_count = 0 + def deploy_side_effect(): + nonlocal call_count + call_count += 1 + return call_count >= 3 # Fail first 2 times, succeed on 3rd + + mock_deployer.deploy.side_effect = deploy_side_effect + scheduler._deployer = mock_deployer + + with patch('asyncio.to_thread', new_callable=AsyncMock) as mock_to_thread: + # Make asyncio.to_thread call the actual method + async def to_thread_side_effect(func): + return func() + mock_to_thread.side_effect = to_thread_side_effect + + with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep: + async def run_test(): + result = await scheduler.adeploy_with_retry(max_retries=3) + return result + + result = asyncio.run(run_test()) + + # Should eventually succeed + self.assertTrue(result) + + # Should have called asyncio.sleep between retries (2 times for 3 attempts) + self.assertEqual(mock_sleep.call_count, 2) + + # Sleep should be called with 30 seconds + mock_sleep.assert_called_with(30) + + def test_adeploy_with_retry_max_retries_exhausted(self): + """Test adeploy_with_retry when all retries are exhausted.""" + from praisonai.praisonai.scheduler.deployment import DeploymentScheduler + + scheduler = DeploymentScheduler() + + # Mock deployer that always fails + mock_deployer = Mock() + mock_deployer.deploy.return_value = False + scheduler._deployer = mock_deployer + + with patch('asyncio.to_thread', new_callable=AsyncMock) as mock_to_thread: + mock_to_thread.return_value = False + + with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep: + async def run_test(): + result = await scheduler.adeploy_with_retry(max_retries=2) + return result + + result = asyncio.run(run_test()) + + # Should fail after exhausting retries + self.assertFalse(result) + + # Should have made max_retries attempts + self.assertEqual(mock_to_thread.call_count, 2) + + # Should sleep between retries (retries - 1 times) + self.assertEqual(mock_sleep.call_count, 1) + + def test_async_methods_never_block_event_loop(self): + """Integration test: verify async methods don't block event loop.""" + from praisonai.praisonai.scheduler.daemon_manager import DaemonManager + from praisonai.praisonai.scheduler.deployment import DeploymentScheduler + + # This test runs multiple async operations concurrently + # If any method blocks the event loop, this will hang or timeout + + manager = DaemonManager() + scheduler = DeploymentScheduler() + + # Mock dependencies + mock_deployer = Mock() + mock_deployer.deploy.return_value = True + scheduler._deployer = mock_deployer + + async def concurrent_test(): + # Start multiple async operations that should run concurrently + with patch('os.kill', side_effect=ProcessLookupError("Process not found")): + with patch('asyncio.to_thread', new_callable=AsyncMock) as mock_to_thread: + mock_to_thread.return_value = True + + with patch('asyncio.sleep', new_callable=AsyncMock): + # Run operations concurrently + tasks = [ + manager.astop_daemon(123), + manager.astop_daemon(124), + scheduler.adeploy_with_retry(max_retries=1), + scheduler.adeploy_with_retry(max_retries=1) + ] + + results = await asyncio.gather(*tasks) + + # All should succeed + self.assertTrue(all(results)) + + # Run with a reasonable timeout - if event loop blocks, this will timeout + start_time = time.time() + asyncio.run(asyncio.wait_for(concurrent_test(), timeout=5.0)) + elapsed = time.time() - start_time + + # Should complete quickly since operations run concurrently + self.assertLess(elapsed, 2.0, "Async operations took too long - possible event loop blocking") + + def test_exception_handling_in_async_methods(self): + """Test proper exception handling in async methods.""" + from praisonai.praisonai.scheduler.daemon_manager import DaemonManager + from praisonai.praisonai.scheduler.deployment import DeploymentScheduler + + manager = DaemonManager() + scheduler = DeploymentScheduler() + + # Test astop_daemon exception handling + with patch('os.kill', side_effect=OSError("Permission denied")): + async def test_daemon_exception(): + result = await manager.astop_daemon(12345) + return result + + result = asyncio.run(test_daemon_exception()) + # Should handle exception gracefully + self.assertFalse(result) + + # Test adeploy_with_retry exception handling + mock_deployer = Mock() + mock_deployer.deploy.side_effect = RuntimeError("Deployment failed") + scheduler._deployer = mock_deployer + + with patch('asyncio.to_thread', side_effect=RuntimeError("Deployment failed")): + with patch('asyncio.sleep', new_callable=AsyncMock): + async def test_deploy_exception(): + result = await scheduler.adeploy_with_retry(max_retries=1) + return result + + result = asyncio.run(test_deploy_exception()) + # Should handle exception gracefully and continue retries + self.assertFalse(result) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/src/praisonai/tests/unit/test_lazy_imports.py b/src/praisonai/tests/unit/test_lazy_imports.py new file mode 100644 index 000000000..843749849 --- /dev/null +++ b/src/praisonai/tests/unit/test_lazy_imports.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python +""" +Unit tests for lazy import behavior in architectural fixes. + +Tests that heavy ML dependencies (torch, transformers, unsloth, etc.) are not loaded +during package import time, only when actually needed. + +Per AGENTS.md §9: Both smoke tests and real agentic tests required. +""" + +import sys +import time +import unittest +from unittest import mock + + +class TestLazyImports(unittest.TestCase): + """Test lazy import patterns for performance compliance.""" + + def test_trainer_lazy_imports_smoke(self): + """Smoke test: trainer module imports quickly without heavy deps.""" + # Remove any cached imports to get fresh import time + modules_to_remove = [ + 'praisonai.praisonai.train.llm.trainer', + 'torch', 'transformers', 'unsloth', 'trl', 'datasets', 'psutil' + ] + for module in modules_to_remove: + if module in sys.modules: + del sys.modules[module] + + start_time = time.time() + + # Import should be fast without triggering heavy deps + from praisonai.praisonai.train.llm.trainer import TrainModel + + import_time = time.time() - start_time + + # Should import in well under 200ms per AGENTS.md performance target + self.assertLess(import_time, 0.2, + f"Import took {import_time:.3f}s, exceeds 200ms target") + + # Heavy deps should NOT be in globals yet + import praisonai.praisonai.train.llm.trainer as trainer_module + trainer_globals = dir(trainer_module) + + # These should not be available until _lazy_import_training_deps() is called + heavy_deps = ['torch', 'FastLanguageModel', 'SFTTrainer', 'load_dataset'] + for dep in heavy_deps: + self.assertNotIn(dep, trainer_globals, + f"Heavy dependency '{dep}' loaded too early") + + def test_upload_vision_lazy_imports_smoke(self): + """Smoke test: upload_vision module imports quickly without heavy deps.""" + # Remove any cached imports + modules_to_remove = [ + 'praisonai.praisonai.upload_vision', + 'torch', 'unsloth' + ] + for module in modules_to_remove: + if module in sys.modules: + del sys.modules[module] + + start_time = time.time() + + # Import should be fast + from praisonai.praisonai.upload_vision import UploadVisionModel + + import_time = time.time() - start_time + + # Should import quickly + self.assertLess(import_time, 0.2, + f"Import took {import_time:.3f}s, exceeds 200ms target") + + def test_trainer_lazy_loading_mechanism(self): + """Test that lazy import mechanism works correctly when instantiated.""" + # Mock the heavy dependencies to avoid actually importing them + with mock.patch.dict('sys.modules', { + 'torch': mock.MagicMock(), + 'transformers': mock.MagicMock(), + 'unsloth': mock.MagicMock(), + 'trl': mock.MagicMock(), + 'datasets': mock.MagicMock(), + 'psutil': mock.MagicMock() + }): + # Mock the specific imports that would be loaded + mock_torch = mock.MagicMock() + mock_TextStreamer = mock.MagicMock() + mock_FastLanguageModel = mock.MagicMock() + + with mock.patch.multiple('sys.modules', + torch=mock_torch, + **{'transformers.TextStreamer': mock_TextStreamer} + ): + # Mock yaml.safe_load to avoid needing actual config file + with mock.patch('yaml.safe_load', return_value={'model_name': 'test'}): + with mock.patch('builtins.open', mock.mock_open(read_data='model_name: test')): + from praisonai.praisonai.train.llm.trainer import TrainModel + + # Creating instance should trigger lazy import + trainer = TrainModel() + + # Verify trainer was created successfully + self.assertIsNotNone(trainer) + self.assertEqual(trainer.config['model_name'], 'test') + + def test_upload_vision_lazy_loading_mechanism(self): + """Test that upload vision lazy import works correctly.""" + # Mock the heavy dependencies + mock_torch = mock.MagicMock() + mock_FastVisionModel = mock.MagicMock() + + with mock.patch.dict('sys.modules', { + 'torch': mock_torch, + 'unsloth': mock.MagicMock() + }): + # Mock the specific classes that would be imported + with mock.patch('unsloth.FastVisionModel', mock_FastVisionModel): + from praisonai.praisonai.upload_vision import UploadVisionModel + + # Creating instance should trigger lazy import and succeed + vision_model = UploadVisionModel() + self.assertIsNotNone(vision_model) + + def test_import_error_handling_trainer(self): + """Test that import errors are handled gracefully with proper exception chaining.""" + # Mock missing dependencies + with mock.patch.dict('sys.modules', {}, clear=True): + with mock.patch('builtins.__import__', side_effect=ImportError("torch not found")): + from praisonai.praisonai.train.llm.trainer import TrainModel + + with mock.patch('yaml.safe_load', return_value={'model_name': 'test'}): + with mock.patch('builtins.open', mock.mock_open(read_data='model_name: test')): + # Should raise ImportError with helpful message + with self.assertRaises(ImportError) as context: + TrainModel() + + error_msg = str(context.exception) + self.assertIn("Training dependencies not available", error_msg) + self.assertIn("pip install torch transformers unsloth", error_msg) + + # Exception chaining should be preserved (from e) + self.assertIsNotNone(context.exception.__cause__) + + def test_import_error_handling_vision(self): + """Test that vision upload import errors are handled gracefully.""" + # Mock missing dependencies + with mock.patch.dict('sys.modules', {}, clear=True): + with mock.patch('builtins.__import__', side_effect=ImportError("unsloth not found")): + from praisonai.praisonai.upload_vision import UploadVisionModel + + # Should raise ImportError with helpful message + with self.assertRaises(ImportError) as context: + UploadVisionModel() + + error_msg = str(context.exception) + self.assertIn("Vision upload dependencies not available", error_msg) + self.assertIn("pip install torch unsloth", error_msg) + + # Exception chaining should be preserved (from e) + self.assertIsNotNone(context.exception.__cause__) + + def test_package_import_performance_target(self): + """Test that praisonai package meets < 200ms import target.""" + # This is the critical performance test per AGENTS.md + + # Clear any cached modules that might affect timing + praisonai_modules = [m for m in sys.modules.keys() if m.startswith('praisonai')] + for module in praisonai_modules: + if 'test' not in module: # Don't remove test modules + del sys.modules[module] + + start_time = time.time() + + # Import the main package + import praisonai + + total_import_time = time.time() - start_time + + # Must meet AGENTS.md performance requirement + self.assertLess(total_import_time, 0.2, + f"Package import took {total_import_time:.3f}s, exceeds 200ms target") + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/src/praisonai/tests/unit/test_provider_registry_compat.py b/src/praisonai/tests/unit/test_provider_registry_compat.py new file mode 100644 index 000000000..ef4f1a1b8 --- /dev/null +++ b/src/praisonai/tests/unit/test_provider_registry_compat.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python +""" +Unit tests for ProviderRegistry backward compatibility. + +Tests that the ProviderRegistry architectural refactor maintains 100% backward +compatibility with existing APIs while adding new PluginRegistry features. + +Per AGENTS.md §9: Both smoke tests and real agentic tests required. +""" + +import threading +import unittest +from unittest import mock +from typing import Type, Optional + +import praisonai.praisonai.endpoints.registry as registry_module +from praisonai.praisonai.endpoints.providers.base import BaseProvider + + +class MockProvider(BaseProvider): + """Mock provider for testing.""" + + def __init__(self, base_url="http://localhost:8765", api_key=None, **kwargs): + self.base_url = base_url + self.api_key = api_key + self.kwargs = kwargs + + +class TestProviderRegistryBackwardCompatibility(unittest.TestCase): + """Test backward compatibility of ProviderRegistry refactor.""" + + def setUp(self): + """Reset registry state for each test.""" + # Reset the default registry + registry_module._default_registry = None + + def test_provider_registry_smoke(self): + """Smoke test: basic registry operations work.""" + from praisonai.praisonai.endpoints.registry import ProviderRegistry + + registry = ProviderRegistry() + self.assertIsNotNone(registry) + + # Should have expected methods from PluginRegistry inheritance + self.assertTrue(hasattr(registry, 'register')) + self.assertTrue(hasattr(registry, 'resolve')) + self.assertTrue(hasattr(registry, 'list_names')) + + def test_backward_compat_instance_methods(self): + """Test that old ProviderRegistry instance methods still work.""" + from praisonai.praisonai.endpoints.registry import ProviderRegistry + + registry = ProviderRegistry() + + # Register a test provider + registry.register("test", MockProvider) + + # Backward compatibility methods should work + self.assertTrue(hasattr(registry, 'list_types')) + self.assertTrue(hasattr(registry, 'get_class')) + + # list_types() should return provider types + types = registry.list_types() + self.assertIsInstance(types, list) + self.assertIn("test", types) + + # get_class() should return provider class + provider_class = registry.get_class("test") + self.assertEqual(provider_class, MockProvider) + + # get_class() should return None for missing provider + missing_class = registry.get_class("nonexistent") + self.assertIsNone(missing_class) + + def test_module_level_functions_preserved(self): + """Test that all module-level functions still work.""" + # These functions should exist and delegate to default registry + functions_to_test = [ + 'register_provider', + 'get_provider', + 'list_provider_types', + 'get_provider_class' + ] + + for func_name in functions_to_test: + self.assertTrue(hasattr(registry_module, func_name), + f"Missing backward compat function: {func_name}") + + def test_register_provider_module_function(self): + """Test module-level register_provider function.""" + # Register via module function + registry_module.register_provider("test_module", MockProvider) + + # Should be available via other module functions + types = registry_module.list_provider_types() + self.assertIn("test_module", types) + + provider_class = registry_module.get_provider_class("test_module") + self.assertEqual(provider_class, MockProvider) + + def test_get_provider_module_function(self): + """Test module-level get_provider function.""" + # Register a provider + registry_module.register_provider("test_get", MockProvider) + + # Get provider instance via module function + provider = registry_module.get_provider( + "test_get", + base_url="http://test:8080", + api_key="secret123", + extra_param="value" + ) + + self.assertIsInstance(provider, MockProvider) + self.assertEqual(provider.base_url, "http://test:8080") + self.assertEqual(provider.api_key, "secret123") + self.assertEqual(provider.kwargs["extra_param"], "value") + + def test_get_provider_class_error_handling(self): + """Test proper error handling in get_provider_class function.""" + # Test missing provider returns None + result = registry_module.get_provider_class("definitely_missing") + self.assertIsNone(result) + + # Test import error handling + registry = registry_module.get_default_registry() + + # Mock a provider that fails to import + def failing_loader(): + raise ImportError("Mock import failure") + + registry.register("failing_provider", failing_loader) + + # Should raise the import error (not return None) + with self.assertRaises(ImportError): + registry_module.get_provider_class("failing_provider") + + def test_default_registry_singleton_thread_safety(self): + """Test that default registry singleton is thread-safe.""" + results = {} + + def get_registry_in_thread(thread_id): + registry = registry_module.get_default_registry() + results[thread_id] = id(registry) + + # Create multiple threads accessing default registry + threads = [] + for i in range(10): + thread = threading.Thread(target=get_registry_in_thread, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads + for thread in threads: + thread.join() + + # All threads should get the same registry instance + registry_ids = list(results.values()) + self.assertEqual(len(set(registry_ids)), 1, + "Multiple registry instances created - not thread-safe") + + def test_registry_get_method_error_handling(self): + """Test error handling in ProviderRegistry.get() method.""" + from praisonai.praisonai.endpoints.registry import ProviderRegistry + + registry = ProviderRegistry() + + # Missing provider should return None + result = registry.get("missing_provider") + self.assertIsNone(result) + + # Mock a provider that fails to import + def failing_loader(): + raise ImportError("Mock import failure") + + registry.register("failing_provider", failing_loader) + + # Import error should be raised (not swallowed) + with self.assertRaises(ImportError): + registry.get("failing_provider") + + def test_builtin_providers_loaded(self): + """Test that builtin providers are properly loaded.""" + from praisonai.praisonai.endpoints.registry import get_default_registry + + registry = get_default_registry() + available_types = registry.list_names() + + # Should have the expected builtin providers + expected_builtins = ["recipe", "agents-api", "mcp", "tools-mcp", "a2a", "a2u"] + + for builtin in expected_builtins: + self.assertIn(builtin, available_types, + f"Missing builtin provider: {builtin}") + + def test_provider_instantiation_end_to_end(self): + """Test complete provider instantiation workflow.""" + # Register a custom provider + registry_module.register_provider("custom", MockProvider) + + # Get provider instance with various parameters + provider = registry_module.get_provider( + "custom", + base_url="https://custom.api.com", + api_key="sk-custom123", + timeout=30, + retries=3 + ) + + # Verify provider was created correctly + self.assertIsInstance(provider, MockProvider) + self.assertEqual(provider.base_url, "https://custom.api.com") + self.assertEqual(provider.api_key, "sk-custom123") + self.assertEqual(provider.kwargs["timeout"], 30) + self.assertEqual(provider.kwargs["retries"], 3) + + def test_try_create_backward_compatibility(self): + """Test backward compatibility of try_create pattern if it existed.""" + # This tests the pattern mentioned in the original request + from praisonai.praisonai.endpoints.registry import ProviderRegistry + + registry = ProviderRegistry() + registry.register("test_create", MockProvider) + + # Test successful creation + provider = registry.get("test_create", api_key="test123") + self.assertIsNotNone(provider) + self.assertEqual(provider.api_key, "test123") + + # Test failed creation (missing provider) + provider = registry.get("nonexistent") + self.assertIsNone(provider) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file