Skip to content

Commit 1df37d5

Browse files
authored
Merge branch 'main' into think_bf
2 parents d73c1ac + 00fd928 commit 1df37d5

File tree

3 files changed

+190
-191
lines changed

3 files changed

+190
-191
lines changed

mellea/stdlib/genslot.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pydantic import BaseModel, Field, create_model
1010

1111
from mellea.stdlib.base import Component, TemplateRepresentation
12-
from mellea.stdlib.session import get_session
12+
from mellea.stdlib.session import MelleaSession, get_session
1313

1414
P = ParamSpec("P")
1515
R = TypeVar("R")
@@ -154,7 +154,7 @@ def __init__(self, func: Callable[P, R]):
154154

155155
def __call__(
156156
self,
157-
m=None,
157+
m: MelleaSession | None = None,
158158
model_options: dict | None = None,
159159
*args: P.args,
160160
**kwargs: P.kwargs,
@@ -180,13 +180,11 @@ def __call__(
180180

181181
response_model = create_response_format(self._function._func)
182182

183-
response = m.genslot(
184-
slot_copy, model_options=model_options, format=response_model
185-
)
183+
response = m.act(slot_copy, format=response_model, model_options=model_options)
186184

187185
function_response: FunctionResponse[R] = response_model.model_validate_json(
188-
response.value
189-
) # type: ignore
186+
response.value # type: ignore
187+
)
190188

191189
return function_response.result
192190

mellea/stdlib/sampling.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
self.success = success
4949
self.sample_generations = sample_generations
5050
self.sample_validations = sample_validations
51+
self.sample_actions = sample_actions
5152

5253

5354
class SamplingStrategy(abc.ABC):
@@ -153,7 +154,7 @@ def select_from_failure(
153154
sampled_actions: list[Component],
154155
sampled_results: list[ModelOutputThunk],
155156
sampled_val: list[list[tuple[Requirement, ValidationResult]]],
156-
):
157+
) -> int:
157158
"""This function returns the index of the result that should be selected as `.value` iff the loop budget is exhausted and no success.
158159
159160
Args:
@@ -356,17 +357,17 @@ def select_from_failure(
356357

357358
@staticmethod
358359
def repair(
359-
context: Context,
360+
ctx: Context,
360361
past_actions: list[Component],
361362
past_results: list[ModelOutputThunk],
362363
past_val: list[list[tuple[Requirement, ValidationResult]]],
363364
) -> Component:
364-
assert isinstance(context, LinearContext), (
365+
assert isinstance(ctx, LinearContext), (
365366
" Need linear context to run agentic sampling."
366367
)
367368

368369
# add failed execution to chat history
369-
context.insert_turn(ContextTurn(past_actions[-1], past_results[-1]))
370+
ctx.insert_turn(ContextTurn(past_actions[-1], past_results[-1]))
370371

371372
last_failed_reqs: list[Requirement] = [s[0] for s in past_val[-1] if not s[1]]
372373
last_failed_reqs_str = "* " + "\n* ".join(

0 commit comments

Comments
 (0)