diff --git a/ps_fuzz/attacks/rag_poisoning.py b/ps_fuzz/attacks/rag_poisoning.py index 2e64360..dc6747e 100644 --- a/ps_fuzz/attacks/rag_poisoning.py +++ b/ps_fuzz/attacks/rag_poisoning.py @@ -52,7 +52,7 @@ def _suppress_loggers(logger_names): MISSING_PACKAGES.append("langchain-community (embeddings)") try: - from langchain.schema import Document + from langchain_core.documents import Document except ImportError: DEPENDENCIES_AVAILABLE = False MISSING_PACKAGES.append("langchain (schema)") diff --git a/ps_fuzz/chat_clients.py b/ps_fuzz/chat_clients.py index 4473963..aa08e22 100644 --- a/ps_fuzz/chat_clients.py +++ b/ps_fuzz/chat_clients.py @@ -1,7 +1,7 @@ from .langchain_integration import get_langchain_chat_models_info from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.outputs.llm_result import LLMResult -from langchain.schema import BaseMessage, HumanMessage, SystemMessage, AIMessage +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage from typing import List, Dict, Any, Optional from abc import ABC, abstractmethod import sys diff --git a/ps_fuzz/langchain_integration.py b/ps_fuzz/langchain_integration.py index 8291a00..85b3ee1 100644 --- a/ps_fuzz/langchain_integration.py +++ b/ps_fuzz/langchain_integration.py @@ -1,5 +1,5 @@ from langchain_core.language_models.chat_models import BaseChatModel -import langchain.chat_models +import langchain_community.chat_models as chat_models_module from typing import Any, Dict, get_origin, Optional import inspect, re @@ -85,16 +85,20 @@ def get_langchain_chat_models_info() -> Dict[str, Dict[str, Any]]: Introspects a langchain library, extracting information about supported chat models and required/optional parameters """ models: Dict[str, ChatModelInfo] = {} - for model_cls_name in langchain.chat_models.__all__: + for model_cls_name in chat_models_module.__all__: if model_cls_name in EXCLUDED_CHAT_MODELS: continue - model_cls = langchain.chat_models.__dict__.get(model_cls_name) + model_cls = chat_models_module.__dict__.get(model_cls_name) if model_cls and issubclass(model_cls, BaseChatModel): model_short_name = camel_to_snake(model_cls.__name__).replace('_chat', '').replace('chat_', '') # Introspect supported model parameters + # Support both Pydantic v1 (__fields__) and v2 (model_fields) params: Dict[str, ChatModelParams] = {} - for param_name, field in model_cls.__fields__.items(): + fields = getattr(model_cls, 'model_fields', None) or getattr(model_cls, '__fields__', {}) + for param_name, field in fields.items(): if param_name in CHAT_MODEL_EXCLUDED_PARAMS: continue - typ = field.outer_type_ + # Pydantic v2 uses field.annotation, v1 uses field.outer_type_ + typ = getattr(field, 'annotation', None) or getattr(field, 'outer_type_', None) + if typ is None: continue if typ not in [str, float, int, bool] and get_origin(typ) not in [str, float, int, bool]: continue doc_lines = _get_class_member_doc(model_cls, param_name) description = ''.join(doc_lines) if doc_lines else None diff --git a/pyproject.toml b/pyproject.toml index 6f76b2b..ca69828 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,12 +26,12 @@ dependencies = [ "setuptools>=61.0", "httpx>=0.24.0,<0.25.0", "openai==1.6.1", - "langchain==0.0.353", - "langchain-community==0.0.7", - "langchain-core==0.1.4", + "langchain>=0.3.0,<0.4.0", + "langchain-community>=0.3.0,<0.4.0", + "langchain-core>=0.3.81,<0.4.0", "argparse==1.4.0", "python-dotenv==1.0.0", - "tqdm==4.66.1", + "tqdm>=4.66.3", "colorama==0.4.6", "prettytable==3.10.0", "pandas==2.2.2", diff --git a/setup.py b/setup.py index 854b0ea..2baf766 100755 --- a/setup.py +++ b/setup.py @@ -23,22 +23,20 @@ "Topic :: Software Development :: Quality Assurance", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11" ], - python_requires='>=3.7', + python_requires='>=3.9', install_requires=[ "httpx>=0.24.0,<0.25.0", "openai==1.6.1", - "langchain==0.0.353", - "langchain-community==0.0.7", - "langchain-core==0.1.4", + "langchain>=0.3.0,<0.4.0", + "langchain-community>=0.3.0,<0.4.0", + "langchain-core>=0.3.81,<0.4.0", "argparse==1.4.0", "python-dotenv==1.0.0", - "tqdm==4.66.1", + "tqdm>=4.66.3", "colorama==0.4.6", "prettytable==3.10.0", "pandas==2.2.2", diff --git a/tests/test_chat_clients.py b/tests/test_chat_clients.py index 475f029..81c9a0e 100644 --- a/tests/test_chat_clients.py +++ b/tests/test_chat_clients.py @@ -8,7 +8,7 @@ from typing import Dict, List from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.outputs import LLMResult, ChatResult, ChatGeneration -from langchain_core.pydantic_v1 import Field +from pydantic import Field # Fake LangChain model class FakeChatModel(BaseChatModel):