Skip to content

Commit 348fedd

Browse files
authored
Fixes for Adapters (Pydantic validation) and SIMBA (candidate_programs form) (#8141)
1 parent 43a8a53 commit 348fedd

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

dspy/adapters/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,7 @@ def parse_value(value, annotation):
155155
if annotation.__origin__ is Union and type(None) in get_args(annotation):
156156
if len(get_args(annotation)) == 2 and str in get_args(annotation):
157157
return str(candidate)
158-
else:
159-
raise e
158+
raise e
160159

161160
def get_annotation_name(annotation):
162161
origin = get_origin(annotation)

dspy/teleprompt/simba.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,10 @@ def register_new_program(prog: dspy.Module, score_list: list[float]):
288288
M = len(winning_programs) - 1
289289
N = self.num_candidates + 1
290290
if M < 1:
291-
# Only one or zero winning programs
292291
program_idxs = [0] * N
293292
else:
294293
program_idxs = [round(i * M / (N - 1)) for i in range(N)]
294+
295295
program_idxs = list(dict.fromkeys(program_idxs))
296296

297297
candidate_programs = [winning_programs[i].deepcopy() for i in program_idxs]
@@ -307,18 +307,22 @@ def register_new_program(prog: dspy.Module, score_list: list[float]):
307307
avg_score = sum(sys_scores) / len(sys_scores) if sys_scores else 0.0
308308
scores.append(avg_score)
309309
if idx_prog != 0:
310-
trial_logs[idx_prog-1]["train_score"] = avg_score
311-
310+
trial_logs[idx_prog - 1]["train_score"] = avg_score
311+
312+
# Build sorted list of {"score", "program"} dicts
313+
assert len(scores) == len(candidate_programs)
314+
candidate_data = [{"score": s, "program": p} for s, p in zip(scores, candidate_programs)]
315+
candidate_data.sort(key=lambda x: x["score"], reverse=True)
316+
312317
best_idx = scores.index(max(scores)) if scores else 0
313318
best_program = candidate_programs[best_idx].deepcopy()
314319
logger.info(
315320
f"Final trainset scores: {scores}, Best: {max(scores) if scores else 'N/A'} "
316321
f"(at index {best_idx if scores else 'N/A'})\n\n\n"
317322
)
318323

319-
# FIXME: Attach all program candidates in decreasing average score to the best program.
320-
best_program.candidate_programs = candidate_programs
321-
best_program.winning_programs = winning_programs
324+
# Attach sorted, scored candidates & logs
325+
best_program.candidate_programs = candidate_data
322326
best_program.trial_logs = trial_logs
323327

324328
return best_program

0 commit comments

Comments
 (0)