Skip to content

Commit 1e4709a

Browse files
Refactor api.py for structured output (#378)
* Refactor API wrapper with modular design for LLM calls and output handling
  Co-authored-by: ss.shankar505 <[email protected]>
* Refactor APIWrapper: Simplify LLM call logic and improve modularity
  Co-authored-by: ss.shankar505 <[email protected]>
* Refactor output mode handling in APIWrapper with flexible configuration
  Co-authored-by: ss.shankar505 <[email protected]>
* Add comprehensive tests for DocETL output modes with synthetic data
  Co-authored-by: ss.shankar505 <[email protected]>
* Refactor output modes tests with improved pytest structure and DSLRunner
  Co-authored-by: ss.shankar505 <[email protected]>
* Fix runtime errors
* Add nested JSON parsing for string values in API response
  Co-authored-by: ss.shankar505 <[email protected]>
* Handle nested JSON parsing by extracting matching key values
  Co-authored-by: ss.shankar505 <[email protected]>
* Simplify JSON parsing logic in API utility functions
  Co-authored-by: ss.shankar505 <[email protected]>
* Add to tests
* Add documentation for DocETL output modes and configuration options
  Co-authored-by: ss.shankar505 <[email protected]>
* Add docs
---------
Co-authored-by: Cursor Agent <[email protected]>
1 parent 4cee4d7 commit 1e4709a

File tree

6 files changed

+663
-13
lines changed

6 files changed

+663
-13
lines changed

Makefile

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,11 @@ tests:
88
poetry run pytest
99

1010
tests-basic:
11-
poetry run pytest tests/basic
11+
poetry run pytest -s tests/basic
1212
poetry run pytest -s tests/test_api.py
13-
poetry run pytest tests/test_runner_caching.py
13+
poetry run pytest -s tests/test_runner_caching.py
1414
poetry run pytest -s tests/test_pandas_accessors.py
15+
poetry run pytest -s tests/test_output_modes.py
1516

1617
lint:
1718
poetry run ruff check docetl/* --fix

docetl/operations/map.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from docetl.base_schemas import Tool, ToolFunction
1717
from docetl.operations.base import BaseOperation
1818
from docetl.operations.utils import RichLoopBar, strict_render
19+
from docetl.operations.utils.api import OutputMode
1920

2021

2122
class MapOperation(BaseOperation):
@@ -326,12 +327,17 @@ def _process_map_item(
326327
]
327328

328329
def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
330+
structured_mode = (
331+
self.config.get("output", {}).get("mode")
332+
== OutputMode.STRUCTURED_OUTPUT.value
333+
)
329334
output = (
330335
self.runner.api.parse_llm_response(
331336
response,
332337
schema=self.config["output"]["schema"],
333338
tools=self.config.get("tools", None),
334339
manually_fix_errors=self.manually_fix_errors,
340+
use_structured_output=structured_mode,
335341
)[0]
336342
if isinstance(response, ModelResponse)
337343
else response
@@ -381,11 +387,16 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]):
381387
if llm_result.validated:
382388
# Parse the response
383389
if isinstance(llm_result.response, ModelResponse):
390+
structured_mode = (
391+
self.config.get("output", {}).get("mode")
392+
== OutputMode.STRUCTURED_OUTPUT.value
393+
)
384394
outputs = self.runner.api.parse_llm_response(
385395
llm_result.response,
386396
schema=self.config["output"]["schema"],
387397
tools=self.config.get("tools", None),
388398
manually_fix_errors=self.manually_fix_errors,
399+
use_structured_output=structured_mode,
389400
)
390401
else:
391402
outputs = [llm_result.response]
@@ -432,8 +443,14 @@ def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]:
432443
total_cost += llm_result.total_cost
433444

434445
# Parse the LLM response
446+
structured_mode = (
447+
self.config.get("output", {}).get("mode")
448+
== OutputMode.STRUCTURED_OUTPUT.value
449+
)
435450
parsed_output = self.runner.api.parse_llm_response(
436-
llm_result.response, self.config["output"]["schema"]
451+
llm_result.response,
452+
self.config["output"]["schema"],
453+
use_structured_output=structured_mode,
437454
)[0].get("results", [])
438455
items_and_outputs = [
439456
(item, parsed_output[idx] if idx < len(parsed_output) else None)
@@ -709,11 +726,16 @@ def process_prompt(item, prompt_config):
709726
),
710727
op_config=self.config,
711728
)
729+
structured_mode = (
730+
self.config.get("output", {}).get("mode")
731+
== OutputMode.STRUCTURED_OUTPUT.value
732+
)
712733
output = self.runner.api.parse_llm_response(
713734
response.response,
714735
schema=local_output_schema,
715736
tools=prompt_config.get("tools", None),
716737
manually_fix_errors=self.manually_fix_errors,
738+
use_structured_output=structured_mode,
717739
)[0]
718740
return output, prompt, response.total_cost
719741

docetl/operations/reduce.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,9 @@
2525
get_embeddings_for_clustering,
2626
)
2727
from docetl.operations.utils import rich_as_completed, strict_render
28+
29+
# Import OutputMode enum for structured output checks
30+
from docetl.operations.utils.api import OutputMode
2831
from docetl.utils import completion_cost
2932

3033

@@ -765,9 +768,14 @@ def _incremental_reduce(
765768
return current_output, prompts, total_cost
766769

767770
def validation_fn(self, response: Dict[str, Any]):
771+
structured_mode = (
772+
self.config.get("output", {}).get("mode")
773+
== OutputMode.STRUCTURED_OUTPUT.value
774+
)
768775
output = self.runner.api.parse_llm_response(
769776
response,
770777
schema=self.config["output"]["schema"],
778+
use_structured_output=structured_mode,
771779
)[0]
772780
if self.runner.api.validate_output(self.config, output, self.console):
773781
return output, True
@@ -834,10 +842,15 @@ def _increment_fold(
834842
self._update_fold_time(end_time - start_time)
835843

836844
if response.validated:
845+
structured_mode = (
846+
self.config.get("output", {}).get("mode")
847+
== OutputMode.STRUCTURED_OUTPUT.value
848+
)
837849
folded_output = self.runner.api.parse_llm_response(
838850
response.response,
839851
schema=self.config["output"]["schema"],
840852
manually_fix_errors=self.manually_fix_errors,
853+
use_structured_output=structured_mode,
841854
)[0]
842855

843856
folded_output.update(dict(zip(self.config["reduce_key"], key)))
@@ -897,10 +910,15 @@ def _merge_results(
897910
self._update_merge_time(end_time - start_time)
898911

899912
if response.validated:
913+
structured_mode = (
914+
self.config.get("output", {}).get("mode")
915+
== OutputMode.STRUCTURED_OUTPUT.value
916+
)
900917
merged_output = self.runner.api.parse_llm_response(
901918
response.response,
902919
schema=self.config["output"]["schema"],
903920
manually_fix_errors=self.manually_fix_errors,
921+
use_structured_output=structured_mode,
904922
)[0]
905923
merged_output.update(dict(zip(self.config["reduce_key"], key)))
906924
merge_cost = response.total_cost
@@ -1010,10 +1028,15 @@ def _batch_reduce(
10101028
item_cost += response.total_cost
10111029

10121030
if response.validated:
1031+
structured_mode = (
1032+
self.config.get("output", {}).get("mode")
1033+
== OutputMode.STRUCTURED_OUTPUT.value
1034+
)
10131035
output = self.runner.api.parse_llm_response(
10141036
response.response,
10151037
schema=self.config["output"]["schema"],
10161038
manually_fix_errors=self.manually_fix_errors,
1039+
use_structured_output=structured_mode,
10171040
)[0]
10181041
output.update(dict(zip(self.config["reduce_key"], key)))
10191042

0 commit comments

Comments
 (0)