Skip to content

Commit b369247

Browse files
Add BasicAgent solver
1 parent ed56310 commit b369247

46 files changed

Lines changed: 2617 additions & 1496 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

project/common/preparedness_turn_completer/preparedness_turn_completer/oai_completions_turn_completer.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import functools
4-
from typing import Any, Iterable, Literal, Unpack
4+
from typing import Any, Literal, Unpack
55

66
import openai
77
import structlog
@@ -15,12 +15,11 @@
1515
from openai.types.completion_usage import CompletionUsage
1616
from preparedness_turn_completer.turn_completer import TurnCompleter
1717
from preparedness_turn_completer.utils import (
18-
DEFAULT_RETRY_CONFIG,
1918
RetryConfig,
2019
get_model_context_window_length,
2120
warn_about_non_empty_params,
2221
)
23-
from pydantic import BaseModel, ConfigDict, field_validator
22+
from pydantic import BaseModel, ConfigDict, Field, field_validator
2423

2524
logger = structlog.stdlib.get_logger(component=__name__)
2625

@@ -34,9 +33,9 @@ def __init__(
3433
temperature: float | None | NotGiven = NOT_GIVEN,
3534
max_tokens: int | None | NotGiven = NOT_GIVEN,
3635
top_p: float | None | NotGiven = NOT_GIVEN,
37-
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
36+
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
3837
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
39-
retry_config: RetryConfig = DEFAULT_RETRY_CONFIG,
38+
retry_config: RetryConfig | None = None,
4039
):
4140
self.model = model
4241
self.reasoning_effort = reasoning_effort
@@ -47,7 +46,7 @@ def __init__(
4746
self.tools = tools
4847
self.tool_choice = tool_choice
4948
self.encoding_name: str
50-
self.retry_config = retry_config
49+
self.retry_config = retry_config or RetryConfig()
5150
try:
5251
self.encoding_name = tiktoken.encoding_name_for_model(model)
5352
except KeyError:
@@ -74,9 +73,9 @@ class Config(TurnCompleter.Config):
7473
temperature: float | None | NotGiven = NOT_GIVEN
7574
max_tokens: int | None | NotGiven = NOT_GIVEN
7675
top_p: float | None | NotGiven = NOT_GIVEN
77-
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN
76+
tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN
7877
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN
79-
retry_config: RetryConfig = DEFAULT_RETRY_CONFIG
78+
retry_config: RetryConfig = Field(default_factory=RetryConfig)
8079

8180
def build(self) -> OpenAICompletionsTurnCompleter:
8281
return OpenAICompletionsTurnCompleter(

project/common/preparedness_turn_completer/preparedness_turn_completer/oai_responses_turn_completer/completer.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import functools
4-
from typing import Any, Iterable, Unpack
4+
from typing import Any, Literal, Unpack
55

66
import openai
77
import structlog
@@ -19,27 +19,34 @@
1919
)
2020
from preparedness_turn_completer.turn_completer import TurnCompleter
2121
from preparedness_turn_completer.utils import (
22-
DEFAULT_RETRY_CONFIG,
2322
RetryConfig,
2423
get_model_context_window_length,
2524
warn_about_non_empty_params,
2625
)
27-
from pydantic import BaseModel, ConfigDict, field_validator
26+
from pydantic import BaseModel, ConfigDict, Field, field_validator
2827

2928
logger = structlog.stdlib.get_logger(component=__name__)
3029

3130

31+
class ReasoningConfig(BaseModel):
32+
"""chz-friendly wrapper around openai.types.shared_params.reasoning:Reasoning"""
33+
34+
effort: Literal["minimal", "low", "medium", "high"] | None = None
35+
generate_summary: Literal["auto", "concise", "detailed"] | None = None
36+
summary: Literal["auto", "concise", "detailed"] | None = None
37+
38+
3239
class OpenAIResponsesTurnCompleter(TurnCompleter):
3340
def __init__(
3441
self,
3542
model: str,
3643
reasoning: Reasoning | None | NotGiven = NOT_GIVEN,
3744
text_format: type[BaseModel] | NotGiven = NOT_GIVEN,
38-
tools: Iterable[ParseableToolParam] | NotGiven = NOT_GIVEN,
45+
tools: list[ParseableToolParam] | NotGiven = NOT_GIVEN,
3946
temperature: float | None | NotGiven = NOT_GIVEN,
4047
max_output_tokens: int | None | NotGiven = NOT_GIVEN,
4148
top_p: float | None | NotGiven = NOT_GIVEN,
42-
retry_config: RetryConfig = DEFAULT_RETRY_CONFIG,
49+
retry_config: RetryConfig | None = None,
4350
):
4451
self.model = model
4552
self.reasoning = reasoning
@@ -49,7 +56,7 @@ def __init__(
4956
self.max_output_tokens = max_output_tokens
5057
self.top_p = top_p
5158
self.encoding_name: str
52-
self.retry_config = retry_config
59+
self.retry_config = retry_config or RetryConfig()
5360
try:
5461
self.encoding_name = tiktoken.encoding_name_for_model(model)
5562
except KeyError:
@@ -71,18 +78,28 @@ class Config(TurnCompleter.Config):
7178
)
7279

7380
model: str
74-
reasoning: Reasoning | None | NotGiven = NOT_GIVEN
81+
reasoning: ReasoningConfig | None | NotGiven = NOT_GIVEN
7582
text_format: type[BaseModel] | NotGiven = NOT_GIVEN
76-
tools: Iterable[ParseableToolParam] | NotGiven = NOT_GIVEN
83+
tools: list[ParseableToolParam] | NotGiven = NOT_GIVEN
7784
temperature: float | None | NotGiven = NOT_GIVEN
7885
max_output_tokens: int | None | NotGiven = NOT_GIVEN
7986
top_p: float | None | NotGiven = NOT_GIVEN
80-
retry_config: RetryConfig = DEFAULT_RETRY_CONFIG
87+
retry_config: RetryConfig = Field(default_factory=RetryConfig)
8188

8289
def build(self) -> OpenAIResponsesTurnCompleter:
90+
reasoning_param: Reasoning | None | NotGiven
91+
if isinstance(self.reasoning, ReasoningConfig):
92+
reasoning_param = Reasoning(
93+
effort=self.reasoning.effort,
94+
generate_summary=self.reasoning.generate_summary,
95+
summary=self.reasoning.summary,
96+
)
97+
else:
98+
reasoning_param = self.reasoning
99+
83100
return OpenAIResponsesTurnCompleter(
84101
model=self.model,
85-
reasoning=self.reasoning,
102+
reasoning=reasoning_param,
86103
text_format=self.text_format,
87104
tools=self.tools,
88105
temperature=self.temperature,

project/common/preparedness_turn_completer/preparedness_turn_completer/oai_responses_turn_completer/converters.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,17 @@
5454
from openai.types.responses.response_output_text import (
5555
AnnotationURLCitation as ResponsesAnnotationURLCitation,
5656
)
57-
from preparedness_turn_completer.oai_responses_turn_completer.type_helpers import (
57+
from preparedness_turn_completer.turn_completer import TurnCompleter
58+
from preparedness_turn_completer.type_helpers import (
5859
ChatCompletionContent,
59-
is_assistant_message,
60+
is_chat_completion_assistant_message_param,
61+
is_chat_completion_tool_message_param,
6062
is_content_array_list,
6163
is_content_part_list,
6264
is_custom_tool_call_param,
6365
is_function_tool_call_param,
6466
is_text_parts_list,
65-
is_tool_message,
6667
)
67-
from preparedness_turn_completer.turn_completer import TurnCompleter
6868

6969
logger = structlog.stdlib.get_logger(component=__name__)
7070

@@ -140,19 +140,22 @@ def _user_completion_to_response_input_items(
140140
def _assistant_completion_to_response_input_items(
141141
message: ChatCompletionMessageParam,
142142
) -> list[ResponseInputItemParam]:
143-
content = message["content"]
144143
role: Literal["assistant"] = "assistant"
145144
input_items: list[ResponseInputItemParam] = []
146-
assert is_assistant_message(message)
147-
assert isinstance(content, str) or is_content_array_list(content), (
148-
f"Expected content to be str or list of content arrays, got {content!r}"
149-
)
150-
if isinstance(content, str):
151-
input_items.append(EasyInputMessageParam(content=content, role=role, type="message"))
152-
elif is_content_array_list(content):
153-
input_items.extend(_content_array_list_to_response_input_items(content))
154-
else:
155-
raise ValueError(f"Expected content to be str or list of content arrays, got {content!r}")
145+
assert is_chat_completion_assistant_message_param(message)
146+
if message.get("content") is not None:
147+
content = message["content"]
148+
assert isinstance(content, str) or is_content_array_list(content), (
149+
f"Expected content to be str or list of content arrays, got {content!r}"
150+
)
151+
if isinstance(content, str):
152+
input_items.append(EasyInputMessageParam(content=content, role=role, type="message"))
153+
elif is_content_array_list(content):
154+
input_items.extend(_content_array_list_to_response_input_items(content))
155+
else:
156+
raise ValueError(
157+
f"Expected content to be str or list of content arrays, got {content!r}"
158+
)
156159

157160
refusal = message.get("refusal", None)
158161
if isinstance(refusal, str):
@@ -225,7 +228,7 @@ def _content_array_list_to_response_input_items(
225228
def _tool_completion_to_response_input_items(
226229
message: ChatCompletionMessageParam,
227230
) -> list[ResponseInputItemParam]:
228-
assert is_tool_message(message)
231+
assert is_chat_completion_tool_message_param(message)
229232
content = message["content"]
230233
output_str: str
231234
if isinstance(content, str):
@@ -235,11 +238,7 @@ def _tool_completion_to_response_input_items(
235238

236239
return [
237240
FunctionCallOutput(
238-
call_id=message["tool_call_id"],
239-
output=output_str,
240-
type="function_call_output",
241-
id=uuid.uuid4().hex,
242-
status="completed",
241+
call_id=message["tool_call_id"], output=output_str, type="function_call_output"
243242
)
244243
]
245244

project/common/preparedness_turn_completer/preparedness_turn_completer/oai_responses_turn_completer/type_helpers.py

Lines changed: 0 additions & 112 deletions
This file was deleted.

0 commit comments

Comments (0)