Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 102 additions & 105 deletions src/praisonai/praisonai/endpoints/registry.py
Original file line number Diff line number Diff line change
@@ -1,131 +1,128 @@
"""
Provider Registry

Central registry for provider adapters.
Central registry for provider adapters — unified with the rest of the wrapper.
"""

from typing import Dict, List, Optional, Type
from __future__ import annotations

import threading
from typing import Any, List, Optional, Type

from .providers.base import BaseProvider
from .._registry import PluginRegistry


# Global provider registry
_providers: Dict[str, Type[BaseProvider]] = {}
def _recipe_loader():
from .providers.recipe import RecipeProvider
return RecipeProvider


def register_provider(provider_type: str, provider_class: Type[BaseProvider]) -> None:
"""
Register a provider class.

Args:
provider_type: Provider type identifier
provider_class: Provider class to register
"""
_providers[provider_type] = provider_class


def get_provider(
provider_type: str,
base_url: str = "http://localhost:8765",
api_key: Optional[str] = None,
**kwargs,
) -> Optional[BaseProvider]:
"""
Get a provider instance by type.

Args:
provider_type: Provider type identifier
base_url: Base URL for the provider
api_key: Optional API key
**kwargs: Additional provider-specific arguments

Returns:
Provider instance or None if not found
"""
# Lazy register built-in providers
_ensure_providers_registered()

if provider_type not in _providers:
return None

return _providers[provider_type](base_url=base_url, api_key=api_key, **kwargs)
def _agents_api_loader():
from .providers.agents_api import AgentsAPIProvider
return AgentsAPIProvider


def list_provider_types() -> List[str]:
"""
List all registered provider types.

Returns:
List of provider type identifiers
"""
_ensure_providers_registered()
return list(_providers.keys())
def _mcp_loader():
from .providers.mcp import MCPProvider
return MCPProvider


def get_provider_class(provider_type: str) -> Optional[Type[BaseProvider]]:
"""
Get a provider class by type.

Args:
provider_type: Provider type identifier

Returns:
Provider class or None if not found
"""
_ensure_providers_registered()
return _providers.get(provider_type)


def _ensure_providers_registered() -> None:
"""Ensure built-in providers are registered."""
if _providers:
return

# Lazy import and register built-in providers
from .providers.recipe import RecipeProvider
from .providers.agents_api import AgentsAPIProvider
from .providers.mcp import MCPProvider
def _tools_mcp_loader():
from .providers.tools_mcp import ToolsMCPProvider
return ToolsMCPProvider


def _a2a_loader():
from .providers.a2a import A2AProvider
return A2AProvider


def _a2u_loader():
from .providers.a2u import A2UProvider

register_provider("recipe", RecipeProvider)
register_provider("agents-api", AgentsAPIProvider)
register_provider("mcp", MCPProvider)
register_provider("tools-mcp", ToolsMCPProvider)
register_provider("a2a", A2AProvider)
register_provider("a2u", A2UProvider)


class ProviderRegistry:
"""
Provider registry class for managing provider instances.

This class provides a more object-oriented interface to the provider registry.
"""

def __init__(self):
"""Initialize the registry."""
_ensure_providers_registered()

return A2UProvider


_BUILTIN_PROVIDERS = {
"recipe": _recipe_loader,
"agents-api": _agents_api_loader,
"mcp": _mcp_loader,
"tools-mcp": _tools_mcp_loader,
"a2a": _a2a_loader,
"a2u": _a2u_loader,
}


class ProviderRegistry(PluginRegistry[Type[BaseProvider]]):
"""Endpoint provider registry — unified with the rest of the wrapper."""

def __init__(self) -> None:
super().__init__(
entry_point_group="praisonai.endpoint_providers",
builtins=_BUILTIN_PROVIDERS,
)

def get(
self,
provider_type: str,
base_url: str = "http://localhost:8765",
api_key: Optional[str] = None,
**kwargs,
**kwargs: Any,
) -> Optional[BaseProvider]:
"""Get a provider instance."""
return get_provider(provider_type, base_url, api_key, **kwargs)

def register(self, provider_type: str, provider_class: Type[BaseProvider]) -> None:
"""Register a provider class."""
register_provider(provider_type, provider_class)

try:
cls = self.resolve(provider_type)
except ValueError:
# Distinguish between missing provider vs import failure
if provider_type.lower() not in self.list_all_names():
return None
raise
return cls(base_url=base_url, api_key=api_key, **kwargs)
Comment thread
greptile-apps[bot] marked this conversation as resolved.

# Backward compatibility methods - forward to parent class methods
def list_types(self) -> List[str]:
"""List all provider types."""
return list_provider_types()
"""List available provider types. Backward compatibility alias for list_names()."""
return self.list_names()

def get_class(self, provider_type: str) -> Optional[Type[BaseProvider]]:
"""Get a provider class."""
return get_provider_class(provider_type)
"""Get provider class by type. Backward compatibility alias for resolve()."""
try:
return self.resolve(provider_type)
except ValueError:
return None


_default_registry: Optional[ProviderRegistry] = None
_default_lock = threading.Lock()


def get_default_registry() -> ProviderRegistry:
global _default_registry
if _default_registry is None:
with _default_lock:
if _default_registry is None:
_default_registry = ProviderRegistry()
return _default_registry


# Module-level functions kept for backwards compat — now delegate to the registry
def register_provider(provider_type: str, provider_class: Type[BaseProvider]) -> None:
get_default_registry().register(provider_type, provider_class)


def get_provider(provider_type, base_url="http://localhost:8765", api_key=None, **kwargs):
return get_default_registry().get(provider_type, base_url=base_url, api_key=api_key, **kwargs)


def list_provider_types() -> List[str]:
return get_default_registry().list_names()


def get_provider_class(provider_type: str) -> Optional[Type[BaseProvider]]:
registry = get_default_registry()
try:
return registry.resolve(provider_type)
except ValueError:
# Distinguish between missing provider vs import failure
if provider_type.lower() not in registry.list_all_names():
return None
raise
36 changes: 36 additions & 0 deletions src/praisonai/praisonai/scheduler/daemon_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import signal
import subprocess
import asyncio
from pathlib import Path
from typing import Optional, Dict, List
from datetime import datetime
Expand Down Expand Up @@ -153,6 +154,41 @@ def stop_daemon(self, pid: int, timeout: int = 10) -> bool:

except (OSError, ProcessLookupError):
return False

async def astop_daemon(self, pid: int, timeout: int = 10) -> bool:
"""
Async variant of stop_daemon — never blocks the event loop.

Args:
pid: Process ID
timeout: Timeout in seconds

Returns:
True if stopped successfully
"""
try:
# Try graceful shutdown first (SIGTERM)
os.kill(pid, signal.SIGTERM)

# Wait for process to terminate
for _ in range(timeout * 10):
try:
os.kill(pid, 0) # Check if still alive
await asyncio.sleep(0.1) # cooperative wait
except (OSError, ProcessLookupError):
return True # Process terminated

# Force kill if still alive
try:
os.kill(pid, signal.SIGKILL)
await asyncio.sleep(0.2) # Give it time to die
except (OSError, ProcessLookupError):
pass

return True

except (OSError, ProcessLookupError):
return False

def get_status(self, pid: int) -> Optional[Dict]:
"""
Expand Down
32 changes: 32 additions & 0 deletions src/praisonai/praisonai/scheduler/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
import threading
import time
import asyncio
from typing import Optional, Dict, Any
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -175,6 +176,37 @@ def deploy_once(self) -> bool:
logger.error(f"One-time deployment failed: {e}")
return False

async def adeploy_with_retry(self, max_retries: int = 3) -> bool:
"""
Async variant of deployment retry logic — never blocks the event loop.

Args:
max_retries: Maximum number of retry attempts

Returns:
True if deployment succeeded, False otherwise
"""
deployer = self._get_deployer()

for attempt in range(max_retries):
try:
# Run blocking deploy() call in thread pool to avoid blocking event loop
if await asyncio.to_thread(deployer.deploy):
logger.info(f"Deployment successful on attempt {attempt + 1}")
return True
else:
logger.warning(f"Deployment failed on attempt {attempt + 1}")
except (OSError, RuntimeError, ConnectionError) as e:
logger.exception(f"Deployment error on attempt {attempt + 1}: {e}")
except Exception as e:
logger.exception(f"Unexpected deployment error on attempt {attempt + 1}: {e}")

if attempt < max_retries - 1:
await asyncio.sleep(30) # Wait before retry (cooperative)
Comment thread
greptile-apps[bot] marked this conversation as resolved.

logger.error(f"Deployment failed after {max_retries} attempts")
return False
Comment on lines +179 to +208

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | 🏗️ Heavy lift

Blocking synchronous call defeats async purpose.

Line 193 calls deployer.deploy() which is a synchronous, potentially blocking operation. The DeployHandlerAdapter.deploy() method (lines 34-59) imports DeployHandler and calls handler.handle_deploy(), which likely performs I/O, network requests, or subprocess management. This will block the event loop, defeating the stated goal of the method to "never block the event loop."

To make this truly async, consider one of these approaches:

  1. Run the blocking call in a thread pool (quickest fix):
if await asyncio.to_thread(deployer.deploy):
  1. Create an async DeployerInterface (better long-term):
class DeployerInterface(ABC):
    `@abstractmethod`
    async def adeploy(self) -> bool:
        """Execute deployment asynchronously."""
        pass
  1. Document the blocking behavior if async deployment is deferred to future work.
🧰 Tools
🪛 Ruff (0.15.15)

[warning] 198-198: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/praisonai/praisonai/scheduler/deployment.py` around lines 179 - 205,
adeploy_with_retry currently calls the blocking deployer.deploy() (via
DeployHandlerAdapter.deploy which calls DeployHandler.handle_deploy) and thus
can block the event loop; replace that synchronous call with an asynchronous
thread-execution so the loop stays cooperative (e.g., call the blocking function
via asyncio.to_thread and await it: await asyncio.to_thread(deployer.deploy) and
use its boolean result), or as a longer-term refactor add an async Deployer
interface (e.g., DeployerInterface.adeploy and make DeployHandlerAdapter provide
adeploy that awaits DeployHandler.handle_deploy) and call await
deployer.adeploy() from adeploy_with_retry. Ensure logging and retry logic
remain unchanged.



def create_deployment_scheduler(provider: str = "gcp", config: Optional[Dict[str, Any]] = None) -> DeploymentScheduler:
"""
Expand Down
40 changes: 32 additions & 8 deletions src/praisonai/praisonai/train/llm/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,41 @@
import os
import sys
import yaml
import torch
import shutil
import subprocess
from transformers import TextStreamer
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset, concatenate_datasets
from psutil import virtual_memory
from unsloth.chat_templates import standardize_sharegpt, get_chat_template
from functools import partial


def _lazy_import_training_deps():
"""Import heavy training dependencies only when needed."""
try:
import torch
from transformers import TextStreamer, TrainingArguments
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import standardize_sharegpt, get_chat_template
from trl import SFTTrainer
from datasets import load_dataset, concatenate_datasets
from psutil import virtual_memory
# Make available in global scope for the rest of the module
globals().update({
'torch': torch,
'TextStreamer': TextStreamer,
'FastLanguageModel': FastLanguageModel,
'is_bfloat16_supported': is_bfloat16_supported,
'SFTTrainer': SFTTrainer,
'TrainingArguments': TrainingArguments,
'load_dataset': load_dataset,
'concatenate_datasets': concatenate_datasets,
'virtual_memory': virtual_memory,
'standardize_sharegpt': standardize_sharegpt,
'get_chat_template': get_chat_template,
})
except ImportError as e:
raise ImportError(
f"Training dependencies not available. Install with: "
f"pip install torch transformers unsloth datasets trl psutil. Error: {e}"
) from e

#####################################
# Step 1: Formatting Raw Conversations
#####################################
Expand Down Expand Up @@ -107,6 +130,7 @@ def tokenize_function(examples, hf_tokenizer, max_length):
#####################################
class TrainModel:
def __init__(self, config_path="config.yaml"):
_lazy_import_training_deps()
self.load_config(config_path)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = None
Expand Down
Loading
Loading