Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ps_fuzz/attacks/rag_poisoning.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _suppress_loggers(logger_names):
MISSING_PACKAGES.append("langchain-community (embeddings)")

try:
from langchain.schema import Document
from langchain_core.documents import Document
except ImportError:
DEPENDENCIES_AVAILABLE = False
MISSING_PACKAGES.append("langchain (schema)")
Expand Down
2 changes: 1 addition & 1 deletion ps_fuzz/chat_clients.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .langchain_integration import get_langchain_chat_models_info
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.outputs.llm_result import LLMResult
from langchain.schema import BaseMessage, HumanMessage, SystemMessage, AIMessage
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod
import sys
Expand Down
14 changes: 9 additions & 5 deletions ps_fuzz/langchain_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from langchain_core.language_models.chat_models import BaseChatModel
import langchain.chat_models
import langchain_community.chat_models as chat_models_module
from typing import Any, Dict, get_origin, Optional
import inspect, re

Expand Down Expand Up @@ -85,16 +85,20 @@ def get_langchain_chat_models_info() -> Dict[str, Dict[str, Any]]:
Introspects a langchain library, extracting information about supported chat models and required/optional parameters
"""
models: Dict[str, ChatModelInfo] = {}
for model_cls_name in langchain.chat_models.__all__:
for model_cls_name in chat_models_module.__all__:
if model_cls_name in EXCLUDED_CHAT_MODELS: continue
model_cls = langchain.chat_models.__dict__.get(model_cls_name)
model_cls = chat_models_module.__dict__.get(model_cls_name)
if model_cls and issubclass(model_cls, BaseChatModel):
model_short_name = camel_to_snake(model_cls.__name__).replace('_chat', '').replace('chat_', '')
# Introspect supported model parameters
# Support both Pydantic v1 (__fields__) and v2 (model_fields)
params: Dict[str, ChatModelParams] = {}
for param_name, field in model_cls.__fields__.items():
fields = getattr(model_cls, 'model_fields', None) or getattr(model_cls, '__fields__', {})
for param_name, field in fields.items():
if param_name in CHAT_MODEL_EXCLUDED_PARAMS: continue
typ = field.outer_type_
# Pydantic v2 uses field.annotation, v1 uses field.outer_type_
typ = getattr(field, 'annotation', None) or getattr(field, 'outer_type_', None)
if typ is None: continue
if typ not in [str, float, int, bool] and get_origin(typ) not in [str, float, int, bool]: continue
doc_lines = _get_class_member_doc(model_cls, param_name)
description = ''.join(doc_lines) if doc_lines else None
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ dependencies = [
"setuptools>=61.0",
"httpx>=0.24.0,<0.25.0",
"openai==1.6.1",
"langchain==0.0.353",
"langchain-community==0.0.7",
"langchain-core==0.1.4",
"langchain>=0.3.0,<0.4.0",
"langchain-community>=0.3.0,<0.4.0",
"langchain-core>=0.3.81,<0.4.0",
"argparse==1.4.0",
"python-dotenv==1.0.0",
"tqdm==4.66.1",
"tqdm>=4.66.3",
"colorama==0.4.6",
"prettytable==3.10.0",
"pandas==2.2.2",
Expand Down
12 changes: 5 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,20 @@
"Topic :: Software Development :: Quality Assurance",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11"
],
python_requires='>=3.7',
python_requires='>=3.9',
install_requires=[
"httpx>=0.24.0,<0.25.0",
"openai==1.6.1",
"langchain==0.0.353",
"langchain-community==0.0.7",
"langchain-core==0.1.4",
"langchain>=0.3.0,<0.4.0",
"langchain-community>=0.3.0,<0.4.0",
"langchain-core>=0.3.81,<0.4.0",
"argparse==1.4.0",
"python-dotenv==1.0.0",
"tqdm==4.66.1",
"tqdm>=4.66.3",
"colorama==0.4.6",
"prettytable==3.10.0",
"pandas==2.2.2",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_chat_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Dict, List
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.outputs import LLMResult, ChatResult, ChatGeneration
from langchain_core.pydantic_v1 import Field
from pydantic import Field

# Fake LangChain model
class FakeChatModel(BaseChatModel):
Expand Down
Loading