
Commit 97032e1

Add initial DSPy reliability tests (#1773)
* Iterative fixes and formatting (squashed)
* Update run_tests.yml

Signed-off-by: dbczumar <[email protected]>
1 parent: 4822d47 · commit: 97032e1

File tree

7 files changed: +425 additions, -1 deletion


.github/workflows/run_tests.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -62,7 +62,7 @@ jobs:
         with:
           args: check --fix-only
       - name: Run tests with pytest
-        run: poetry run pytest tests/
+        run: poetry run pytest tests/ --ignore=tests/reliability

   build_poetry:
     name: Build Poetry
```

tests/reliability/README.md

Lines changed: 58 additions & 0 deletions
# DSPy Reliability Tests

This directory contains reliability tests for DSPy programs. The purpose of these tests is to verify that DSPy programs reliably produce expected outputs across multiple large language models (LLMs), regardless of model size or capability. These tests are designed to ensure that DSPy programs maintain robustness and accuracy across diverse LLM configurations.

### Overview

Each test in this directory executes a DSPy program using various LLMs. By running the same tests across different models, these tests help validate that DSPy programs handle a wide range of inputs effectively and produce reliable outputs, even in cases where a model might struggle with the input or task.

### Key Features

- **Diverse LLMs**: Each DSPy program is tested with multiple LLMs, ranging from smaller models to more advanced, high-performance models. This approach allows us to assess the consistency and generality of DSPy program outputs across different model capabilities.
- **Challenging and Adversarial Tests**: Some of the tests are intentionally challenging or adversarial, crafted to push the boundaries of DSPy. These challenging cases allow us to gauge the robustness of DSPy and identify areas for potential improvement.
- **Cross-Model Compatibility**: By testing with different LLMs, we aim to ensure that DSPy programs perform well across model types and configurations, reducing model-specific edge cases and enhancing program versatility.

### Running the Tests

- First, populate the configuration file `reliability_conf.yaml` (located in this directory) with the necessary LiteLLM model/provider names and access credentials for: 1. each LLM you want to test, and 2. the LLM judge that you want to use for assessing the correctness of outputs in certain test cases. These should be placed in the `litellm_params` section for each model in the defined `model_list`. You can also use `litellm_params` to specify values for LLM hyperparameters such as `temperature`. Any model that lacks configured `litellm_params` in the configuration file is ignored during testing.

  The configuration must also specify a DSPy adapter to use when testing, e.g. `"chat"` (for `dspy.ChatAdapter`) or `"json"` (for `dspy.JSONAdapter`).

  An example of `reliability_conf.yaml`:

  ```yaml
  adapter: chat
  model_list:
    # The model to use for judging the correctness of program
    # outputs throughout reliability test suites. We recommend using
    # a high quality model as the judge, such as OpenAI GPT-4o
    - model_name: "judge"
      litellm_params:
        model: "openai/gpt-4o"
        api_key: "<my_openai_api_key>"
    - model_name: "gpt-4o"
      litellm_params:
        model: "openai/gpt-4o"
        api_key: "<my_openai_api_key>"
    - model_name: "claude-3.5-sonnet"
      litellm_params:
        model: "anthropic/claude-3.5"
        api_key: "<my_anthropic_api_key>"
  ```

- Second, run the following command from this directory:

  ```bash
  pytest .
  ```

  This will execute all tests for the configured models and display detailed results for each model configuration. Tests are set up to mark expected failures for known challenging cases where a specific model might struggle, while actual (unexpected) DSPy reliability issues are flagged as failures (see below).

### Known Failing Models

Some tests may be expected to fail with certain models, especially in challenging cases. These known failures are logged but do not affect the overall test result. This setup allows us to keep track of model-specific limitations without obstructing general test outcomes. Models that are known to fail a particular test case are specified using the `@known_failing_models` decorator. For example:

```python
@known_failing_models(["llama-3.2-3b-instruct"])
def test_program_with_complex_deeply_nested_output_structure():
    ...
```
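Such a decorator can be quite small. A minimal sketch (the real implementation lives in `tests/reliability/utils.py`; this version only attaches the model list to the test function, which is exactly what the `pytest_generate_tests` hook in `conftest.py` reads back via `getattr(metafunc.function, "_known_failing_models", [])`):

```python
def known_failing_models(models):
    """Mark a test so that failures with the listed model names are tolerated."""

    def decorator(test_func):
        # conftest.py's pytest_generate_tests hook inspects this attribute to
        # decide which parameterized model runs may ignore failures.
        test_func._known_failing_models = models
        return test_func

    return decorator


@known_failing_models(["llama-3.2-3b-instruct"])
def test_example():
    ...
```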

tests/reliability/__init__.py

Whitespace-only changes.

tests/reliability/conftest.py

Lines changed: 84 additions & 0 deletions
```python
import os

import pytest

import dspy
from tests.conftest import clear_settings
from tests.reliability.utils import parse_reliability_conf_yaml

# Standard list of models that should be used for periodic DSPy reliability testing
MODEL_LIST = [
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4-turbo",
    "gpt-o1-preview",
    "gpt-o1-mini",
    "claude-3.5-sonnet",
    "claude-3.5-haiku",
    "gemini-1.5-pro",
    "gemini-1.5-flash",
    "llama-3.1-405b-instruct",
    "llama-3.1-70b-instruct",
    "llama-3.1-8b-instruct",
    "llama-3.2-3b-instruct",
]


def pytest_generate_tests(metafunc):
    """
    Hook to parameterize reliability test cases with each model defined in the
    reliability tests YAML configuration.
    """
    known_failing_models = getattr(metafunc.function, "_known_failing_models", [])

    if "configure_model" in metafunc.fixturenames:
        params = [(model, model in known_failing_models) for model in MODEL_LIST]
        ids = [f"{model}" for model, _ in params]  # Custom IDs for display
        metafunc.parametrize("configure_model", params, indirect=True, ids=ids)


@pytest.fixture(autouse=True)
def configure_model(request):
    """
    Fixture to configure the DSPy library with a particular configured model and adapter
    before executing a test case.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    conf_path = os.path.join(module_dir, "reliability_conf.yaml")
    reliability_conf = parse_reliability_conf_yaml(conf_path)

    if reliability_conf.adapter.lower() == "chat":
        adapter = dspy.ChatAdapter()
    elif reliability_conf.adapter.lower() == "json":
        adapter = dspy.JSONAdapter()
    else:
        raise ValueError(
            f"Unknown adapter specification '{reliability_conf.adapter}' in reliability_conf.yaml"
        )

    model_name, should_ignore_failure = request.param
    model_params = reliability_conf.models.get(model_name)
    if model_params:
        lm = dspy.LM(**model_params)
        dspy.configure(lm=lm, adapter=adapter)
    else:
        pytest.skip(
            f"Skipping test because no reliability testing YAML configuration was found"
            f" for model {model_name}."
        )

    # Store `should_ignore_failure` flag on the request node for use in post-test handling
    request.node.should_ignore_failure = should_ignore_failure
    request.node.model_name = model_name


@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
    """
    Hook to conditionally ignore failures in a given test case for known failing models.
    """
    outcome = yield
    rep = outcome.get_result()

    should_ignore_failure = getattr(item, "should_ignore_failure", False)

    if should_ignore_failure and rep.failed:
        rep.outcome = "passed"
        rep.wasxfail = "Ignoring failure for known failing model"
```
tests/reliability/reliability_conf.yaml

Lines changed: 75 additions & 0 deletions
```yaml
adapter: chat
model_list:
  # The model to use for judging the correctness of program
  # outputs throughout reliability test suites. We recommend using
  # a high quality model as the judge, such as OpenAI GPT-4o
  - model_name: "judge"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "gpt-4o"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "gpt-4o-mini"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "gpt-4-turbo"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "gpt-o1"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "gpt-o1-mini"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "claude-3.5-sonnet"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "claude-3.5-haiku"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "gemini-1.5-pro"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "gemini-1.5-flash"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "llama-3.1-405b-instruct"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "llama-3.1-70b-instruct"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "llama-3.1-8b-instruct"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
  - model_name: "llama-3.2-3b-instruct"
    litellm_params:
      # model: "<litellm_provider>/<litellm_model_name>"
      # api_key: "<api_key>"
      # api_base: "<api_base>"
```
Lines changed: 89 additions & 0 deletions
```python
from enum import Enum
from typing import List

import pydantic

import dspy
from tests.reliability.utils import assert_program_output_correct, known_failing_models


def test_qa_with_pydantic_answer_model():
    class Answer(pydantic.BaseModel):
        value: str
        certainty: float = pydantic.Field(
            description="A value between 0 and 1 indicating the model's confidence in the answer."
        )
        comments: List[str] = pydantic.Field(
            description="At least two comments providing additional details about the answer."
        )

    class QA(dspy.Signature):
        question: str = dspy.InputField()
        answer: Answer = dspy.OutputField()

    program = dspy.Predict(QA)
    answer = program(question="What is the capital of France?").answer

    assert_program_output_correct(
        program_output=answer.value,
        grading_guidelines="The answer should be Paris. Answer should not contain extraneous information.",
    )
    assert_program_output_correct(
        program_output=answer.comments,
        grading_guidelines="The comments should be relevant to the answer",
    )
    assert answer.certainty >= 0
    assert answer.certainty <= 1
    assert len(answer.comments) >= 2


def test_color_classification_using_enum():
    Color = Enum("Color", ["RED", "GREEN", "BLUE"])

    class Colorful(dspy.Signature):
        text: str = dspy.InputField()
        color: Color = dspy.OutputField()

    program = dspy.Predict(Colorful)
    color = program(text="The sky is blue").color

    assert color == Color.BLUE


def test_entity_extraction_with_multiple_primitive_outputs():
    class ExtractEntityFromDescriptionOutput(pydantic.BaseModel):
        entity_hu: str = pydantic.Field(description="The extracted entity in Hungarian, cleaned and lowercased.")
        entity_en: str = pydantic.Field(description="The English translation of the extracted Hungarian entity.")
        is_inverted: bool = pydantic.Field(
            description="Boolean flag indicating if the input is connected in an inverted way."
        )
        categories: str = pydantic.Field(description="English categories separated by '|' to which the entity belongs.")
        review: bool = pydantic.Field(
            description="Boolean flag indicating low confidence or uncertainty in the extraction."
        )

    class ExtractEntityFromDescription(dspy.Signature):
        """Extract an entity from a Hungarian description, provide its English translation, categories, and an inverted flag."""

        description: str = dspy.InputField(description="The input description in Hungarian.")
        entity: ExtractEntityFromDescriptionOutput = dspy.OutputField(
            description="The extracted entity and its properties."
        )

    program = dspy.ChainOfThought(ExtractEntityFromDescription)

    # Input (Hungarian): "Coffee is a beverage of plant origin, made from coffee beans."
    extracted_entity = program(description="A kávé egy növényi eredetű ital, amelyet a kávébabból készítenek.").entity
    assert_program_output_correct(
        program_output=extracted_entity.entity_hu,
        grading_guidelines="The translation of the text into English should be equivalent to 'coffee'",
    )
    assert_program_output_correct(
        program_output=extracted_entity.entity_en,
        grading_guidelines="The text should be equivalent to 'coffee'",
    )
    assert_program_output_correct(
        program_output=extracted_entity.categories,
        grading_guidelines=(
            "The text should contain English language categories that apply to the word 'coffee'."
            " The categories should be separated by the character '|'."
        ),
    )
```
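The Pydantic output models in these tests do double duty as validators: if an LLM's structured output omits a required field or returns a mistyped value, parsing fails before any grading assertion runs, so malformed outputs surface as reliability failures on their own. A standalone illustration with plain Pydantic v2 (hypothetical data, no LLM or DSPy involved):

```python
from typing import List

import pydantic


class Answer(pydantic.BaseModel):
    value: str
    certainty: float
    comments: List[str]


# A well-formed structured output parses cleanly.
good = Answer.model_validate(
    {"value": "Paris", "certainty": 0.95, "comments": ["Capital of France", "Located in Europe"]}
)
assert 0 <= good.certainty <= 1

# A malformed output (missing the required `comments` field) is rejected.
malformed_rejected = False
try:
    Answer.model_validate({"value": "Paris", "certainty": 0.95})
except pydantic.ValidationError:
    malformed_rejected = True
assert malformed_rejected
```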
