diff --git a/deep_research/materializers/__init__.py b/deep_research/materializers/__init__.py index 260eedb3..7009d557 100644 --- a/deep_research/materializers/__init__.py +++ b/deep_research/materializers/__init__.py @@ -2,20 +2,25 @@ Materializers package for the ZenML Deep Research project. This package contains custom ZenML materializers that handle serialization and -deserialization of complex data types used in the research pipeline, particularly -the ResearchState object that tracks the state of the research process. +deserialization of complex data types used in the research pipeline. """ +from .analysis_data_materializer import AnalysisDataMaterializer from .approval_decision_materializer import ApprovalDecisionMaterializer +from .final_report_materializer import FinalReportMaterializer from .prompt_materializer import PromptMaterializer -from .pydantic_materializer import ResearchStateMaterializer -from .reflection_output_materializer import ReflectionOutputMaterializer +from .query_context_materializer import QueryContextMaterializer +from .search_data_materializer import SearchDataMaterializer +from .synthesis_data_materializer import SynthesisDataMaterializer from .tracing_metadata_materializer import TracingMetadataMaterializer __all__ = [ "ApprovalDecisionMaterializer", "PromptMaterializer", - "ReflectionOutputMaterializer", - "ResearchStateMaterializer", "TracingMetadataMaterializer", + "QueryContextMaterializer", + "SearchDataMaterializer", + "SynthesisDataMaterializer", + "AnalysisDataMaterializer", + "FinalReportMaterializer", ] diff --git a/deep_research/materializers/analysis_data_materializer.py b/deep_research/materializers/analysis_data_materializer.py new file mode 100644 index 00000000..79053a79 --- /dev/null +++ b/deep_research/materializers/analysis_data_materializer.py @@ -0,0 +1,396 @@ +"""Materializer for AnalysisData with viewpoint tension diagrams and reflection insights.""" + +import os +from typing import Dict + +from utils.pydantic_models import AnalysisData +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class AnalysisDataMaterializer(PydanticMaterializer): + """Materializer for AnalysisData with viewpoint and reflection visualization.""" + + ASSOCIATED_TYPES = (AnalysisData,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: AnalysisData + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the AnalysisData. + + Args: + data: The AnalysisData to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + visualization_path = os.path.join(self.uri, "analysis_data.html") + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, data: AnalysisData) -> str: + """Generate HTML visualization for the analysis data. + + Args: + data: The AnalysisData to visualize + + Returns: + HTML string + """ + # Viewpoint analysis section + viewpoint_html = "" + if data.viewpoint_analysis: + va = data.viewpoint_analysis + + # Points of agreement + agreement_html = "" + if va.main_points_of_agreement: + agreement_html = "

Main Points of Agreement

" + + # Areas of tension + tensions_html = "" + if va.areas_of_tension: + tensions_html = ( + "

Areas of Tension

" + ) + for tension in va.areas_of_tension: + viewpoints_html = "" + for perspective, view in tension.viewpoints.items(): + viewpoints_html += f""" +
+
{perspective}
+
{view}
+
+ """ + + tensions_html += f""" +
+

{tension.topic}

+
+ {viewpoints_html} +
+
+ """ + tensions_html += "
" + + # Perspective gaps + gaps_html = "" + if va.perspective_gaps: + gaps_html = f""" +
+

Perspective Gaps

+

{va.perspective_gaps}

+
+ """ + + # Integrative insights + insights_html = "" + if va.integrative_insights: + insights_html = f""" +
+

Integrative Insights

+

{va.integrative_insights}

+
+ """ + + viewpoint_html = f""" +
+

Viewpoint Analysis

+ {agreement_html} + {tensions_html} + {gaps_html} + {insights_html} +
+ """ + + # Reflection metadata section + reflection_html = "" + if data.reflection_metadata: + rm = data.reflection_metadata + + # Critique summary + critique_html = "" + if rm.critique_summary: + critique_html = "

Critique Summary

" + + # Additional questions + questions_html = "" + if rm.additional_questions_identified: + questions_html = "

Additional Questions Identified

" + + # Searches performed + searches_html = "" + if rm.searches_performed: + searches_html = "

Searches Performed

" + + # Error handling + error_html = "" + if rm.error: + error_html = f""" +
+

Error Encountered

+

{rm.error}

+
+ """ + + reflection_html = f""" +
+

Reflection Metadata

+
+ {int(rm.improvements_made)} + Improvements Made +
+ {critique_html} + {questions_html} + {searches_html} + {error_html} +
+ """ + + # Handle empty state + if not viewpoint_html and not reflection_html: + content_html = ( + '
No analysis data available yet
' + ) + else: + content_html = viewpoint_html + reflection_html + + html = f""" + + + + Analysis Data Visualization + + + +
+
+

Research Analysis

+
+ + {content_html} +
+ + + """ + + return html diff --git a/deep_research/materializers/final_report_materializer.py b/deep_research/materializers/final_report_materializer.py new file mode 100644 index 00000000..3eb848ba --- /dev/null +++ b/deep_research/materializers/final_report_materializer.py @@ -0,0 +1,269 @@ +"""Materializer for FinalReport with enhanced interactive report visualization.""" + +import os +from datetime import datetime +from typing import Dict + +from utils.pydantic_models import FinalReport +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class FinalReportMaterializer(PydanticMaterializer): + """Materializer for FinalReport with interactive report visualization.""" + + ASSOCIATED_TYPES = (FinalReport,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: FinalReport + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the FinalReport. + + Args: + data: The FinalReport to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Save the actual report + report_path = os.path.join(self.uri, "final_report.html") + with fileio.open(report_path, "w") as f: + f.write(data.report_html) + + # Save a wrapper visualization with metadata + visualization_path = os.path.join( + self.uri, "report_visualization.html" + ) + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return { + report_path: VisualizationType.HTML, + visualization_path: VisualizationType.HTML, + } + + def _generate_visualization_html(self, data: FinalReport) -> str: + """Generate HTML wrapper visualization for the final report. + + Args: + data: The FinalReport to visualize + + Returns: + HTML string + """ + # Format timestamp + timestamp = datetime.fromtimestamp(data.generated_at).strftime( + "%B %d, %Y at %I:%M %p UTC" + ) + + # Extract some statistics from the HTML report if possible + report_length = len(data.report_html) + + html = f""" + + + + Final Research Report - {data.main_query[:50]}... + + + +
+
+

Final Research Report

+ +
+
Research Query
+
{data.main_query}
+
+ + +
+
+ +
+
+ + Open in New Tab + + +
+ +
+ Loading report... +
+ + +
+ + + + + """ + + return html diff --git a/deep_research/materializers/pydantic_materializer.py b/deep_research/materializers/pydantic_materializer.py deleted file mode 100644 index ee01281b..00000000 --- a/deep_research/materializers/pydantic_materializer.py +++ /dev/null @@ -1,764 +0,0 @@ -"""Pydantic materializer for research state objects. - -This module contains an extended version of ZenML's PydanticMaterializer -that adds visualization capabilities for the ResearchState model. -""" - -import os -from typing import Dict - -from utils.pydantic_models import ResearchState -from zenml.enums import ArtifactType, VisualizationType -from zenml.io import fileio -from zenml.materializers import PydanticMaterializer - - -class ResearchStateMaterializer(PydanticMaterializer): - """Materializer for the ResearchState class with visualizations.""" - - ASSOCIATED_TYPES = (ResearchState,) - ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA - - def save_visualizations( - self, data: ResearchState - ) -> Dict[str, VisualizationType]: - """Create and save visualizations for the ResearchState. - - Args: - data: The ResearchState to visualize - - Returns: - Dictionary mapping file paths to visualization types - """ - # Generate an HTML visualization - visualization_path = os.path.join(self.uri, "research_state.html") - - # Create HTML content based on current stage - html_content = self._generate_visualization_html(data) - - # Write the HTML content to a file - with fileio.open(visualization_path, "w") as f: - f.write(html_content) - - # Return the visualization path and type - return {visualization_path: VisualizationType.HTML} - - def _generate_visualization_html(self, state: ResearchState) -> str: - """Generate HTML visualization for the research state. - - Args: - state: The ResearchState to visualize - - Returns: - HTML string - """ - # Base structure for the HTML - html = f""" - - - - Research State: {state.main_query} - - - - -
-

Research State

- - -
-
Initial Query
-
Query Decomposition
-
Information Gathering
-
Information Synthesis
-
Viewpoint Analysis
-
Reflection & Enhancement
-
Final Report
-
- - -
-
-
- - - - - - """ - - # Overview tab content (always shown) - is_active = default_active_tab == "overview" - html += f""" -
-
-

Main Query

-
- """ - - if state.main_query: - html += f"

{state.main_query}

" - else: - html += "

No main query specified

" - - html += """ -
-
-
- """ - - # Sub-questions tab content - if state.sub_questions: - is_active = default_active_tab == "sub-questions" - html += f""" -
-
-

Sub-Questions ({len(state.sub_questions)})

-
- """ - - for i, question in enumerate(state.sub_questions): - html += f""" -
- {i + 1}. {question} -
- """ - - html += """ -
-
-
- """ - - # Search results tab content - if state.search_results: - is_active = default_active_tab == "search-results" - html += f""" -
-
-

Search Results

- """ - - for question, results in state.search_results.items(): - html += f""" -

{question}

-

Found {len(results)} results

-
    - """ - - for result in results: - # Extract domain from URL or use special handling for generated content - if result.url == "tavily-generated-answer": - domain = "Tavily" - else: - domain = "" - try: - from urllib.parse import urlparse - - parsed_url = urlparse(result.url) - domain = parsed_url.netloc - # Strip www. prefix to save space - if domain.startswith("www."): - domain = domain[4:] - except: - domain = ( - result.url.split("/")[2] - if len(result.url.split("/")) > 2 - else "" - ) - # Strip www. prefix to save space - if domain.startswith("www."): - domain = domain[4:] - - html += f""" -
  • - {result.title} ({domain}) -
  • - """ - - html += """ -
- """ - - html += """ -
-
- """ - - # Synthesized information tab content - if state.synthesized_info: - is_active = default_active_tab == "synthesis" - html += f""" -
-
-

Synthesized Information

- """ - - for question, info in state.synthesized_info.items(): - html += f""" -

{question} {info.confidence_level}

-
-

{info.synthesized_answer}

- """ - - if info.key_sources: - html += """ -
-

Key Sources:

-
    - """ - - for source in info.key_sources[:3]: - html += f""" -
  • {source[:50]}...
  • - """ - - if len(info.key_sources) > 3: - html += f"
  • ...and {len(info.key_sources) - 3} more sources
  • " - - html += """ -
-
- """ - - if info.information_gaps: - html += f""" - - """ - - html += """ -
- """ - - html += """ -
-
- """ - - # Viewpoint analysis tab content - if state.viewpoint_analysis: - is_active = default_active_tab == "viewpoints" - html += f""" -
-
-

Viewpoint Analysis

-
- """ - - # Points of agreement - if state.viewpoint_analysis.main_points_of_agreement: - html += """ -

Points of Agreement

-
    - """ - - for point in state.viewpoint_analysis.main_points_of_agreement: - html += f""" -
  • {point}
  • - """ - - html += """ -
- """ - - # Areas of tension - if state.viewpoint_analysis.areas_of_tension: - html += """ -

Areas of Tension

- """ - - for tension in state.viewpoint_analysis.areas_of_tension: - html += f""" -
-

{tension.topic}

-
    - """ - - for viewpoint, description in tension.viewpoints.items(): - html += f""" -
  • {viewpoint}: {description}
  • - """ - - html += """ -
-
- """ - - # Perspective gaps and integrative insights - if state.viewpoint_analysis.perspective_gaps: - html += f""" -

Perspective Gaps

-

{state.viewpoint_analysis.perspective_gaps}

- """ - - if state.viewpoint_analysis.integrative_insights: - html += f""" -

Integrative Insights

-

{state.viewpoint_analysis.integrative_insights}

- """ - - html += """ -
-
-
- """ - - # Reflection & Enhancement tab content - if state.enhanced_info or state.reflection_metadata: - is_active = default_active_tab == "reflection" - html += f""" -
-
-

Reflection & Enhancement

- """ - - # Reflection metadata - if state.reflection_metadata: - html += """ -
- """ - - if state.reflection_metadata.critique_summary: - html += """ -

Critique Summary

-
    - """ - - for critique in state.reflection_metadata.critique_summary: - html += f""" -
  • {critique}
  • - """ - - html += """ -
- """ - - if state.reflection_metadata.additional_questions_identified: - html += """ -

Additional Questions Identified

-
    - """ - - for question in state.reflection_metadata.additional_questions_identified: - html += f""" -
  • {question}
  • - """ - - html += """ -
- """ - - html += f""" - -
- """ - - # Enhanced information - if state.enhanced_info: - html += """ -

Enhanced Information

- """ - - for question, info in state.enhanced_info.items(): - # Show only for questions with improvements - if info.improvements: - html += f""" -
-

{question} {info.confidence_level}

- -
-

Improvements Made:

-
    - """ - - for improvement in info.improvements: - html += f""" -
  • {improvement}
  • - """ - - html += """ -
-
-
- """ - - html += """ -
-
- """ - - # Final report tab - if state.final_report_html: - is_active = default_active_tab == "final-report" - html += f""" -
-
-

Final Report

-

Final HTML report is available but not displayed here. View the HTML artifact to see the complete report.

-
-
- """ - - # Close HTML tags - html += """ -
- - - """ - - return html - - def _get_stage_class(self, state: ResearchState, stage: str) -> str: - """Get CSS class for a stage based on current progress. - - Args: - state: ResearchState object - stage: Stage name - - Returns: - CSS class string - """ - current_stage = state.get_current_stage() - - # These are the stages in order - stages = [ - "empty", - "initial", - "after_query_decomposition", - "after_search", - "after_synthesis", - "after_viewpoint_analysis", - "after_reflection", - "final_report", - ] - - current_index = ( - stages.index(current_stage) if current_stage in stages else 0 - ) - stage_index = stages.index(stage) if stage in stages else 0 - - if stage_index == current_index: - return "active" - elif stage_index < current_index: - return "completed" - else: - return "" - - def _calculate_progress(self, state: ResearchState) -> int: - """Calculate overall progress percentage. - - Args: - state: ResearchState object - - Returns: - Progress percentage (0-100) - """ - # Map stages to progress percentages - stage_percentages = { - "empty": 0, - "initial": 5, - "after_query_decomposition": 20, - "after_search": 40, - "after_synthesis": 60, - "after_viewpoint_analysis": 75, - "after_reflection": 90, - "final_report": 100, - } - - current_stage = state.get_current_stage() - return stage_percentages.get(current_stage, 0) diff --git a/deep_research/materializers/query_context_materializer.py b/deep_research/materializers/query_context_materializer.py new file mode 100644 index 00000000..ebb7088b --- /dev/null +++ b/deep_research/materializers/query_context_materializer.py @@ -0,0 +1,272 @@ +"""Materializer for QueryContext with interactive mind map visualization.""" + +import os +from typing import Dict + +from utils.pydantic_models import QueryContext +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class QueryContextMaterializer(PydanticMaterializer): + """Materializer for QueryContext with mind map visualization.""" + + ASSOCIATED_TYPES = (QueryContext,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: QueryContext + ) -> Dict[str, VisualizationType]: + """Create and save mind map visualization for the QueryContext. + + Args: + data: The QueryContext to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + visualization_path = os.path.join(self.uri, "query_context.html") + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, context: QueryContext) -> str: + """Generate HTML mind map visualization for the query context. + + Args: + context: The QueryContext to visualize + + Returns: + HTML string + """ + # Create sub-questions HTML + sub_questions_html = "" + if context.sub_questions: + for i, sub_q in enumerate(context.sub_questions, 1): + sub_questions_html += f""" +
+
{i}
+
{sub_q}
+
+ """ + else: + sub_questions_html = '
No sub-questions decomposed yet
' + + # Format timestamp + from datetime import datetime + + timestamp = datetime.fromtimestamp( + context.decomposition_timestamp + ).strftime("%Y-%m-%d %H:%M:%S UTC") + + html = f""" + + + + Query Context - {context.main_query[:50]}... + + + +
+
+

Query Decomposition Mind Map

+
Created: {timestamp}
+
+ +
+
+ {context.main_query} +
+ +
+ {sub_questions_html} +
+
+ +
+
+
{len(context.sub_questions)}
+
Sub-Questions
+
+
+
{len(context.main_query.split())}
+
Words in Query
+
+
+
{sum(len(q.split()) for q in context.sub_questions)}
+
Total Sub-Question Words
+
+
+
+ + + """ + + return html diff --git a/deep_research/materializers/reflection_output_materializer.py b/deep_research/materializers/reflection_output_materializer.py deleted file mode 100644 index 1e8b37ae..00000000 --- a/deep_research/materializers/reflection_output_materializer.py +++ /dev/null @@ -1,279 +0,0 @@ -"""Materializer for ReflectionOutput with custom visualization.""" - -import os -from typing import Dict - -from utils.pydantic_models import ReflectionOutput -from zenml.enums import ArtifactType, VisualizationType -from zenml.io import fileio -from zenml.materializers import PydanticMaterializer - - -class ReflectionOutputMaterializer(PydanticMaterializer): - """Materializer for the ReflectionOutput class with visualizations.""" - - ASSOCIATED_TYPES = (ReflectionOutput,) - ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA - - def save_visualizations( - self, data: ReflectionOutput - ) -> Dict[str, VisualizationType]: - """Create and save visualizations for the ReflectionOutput. - - Args: - data: The ReflectionOutput to visualize - - Returns: - Dictionary mapping file paths to visualization types - """ - # Generate an HTML visualization - visualization_path = os.path.join(self.uri, "reflection_output.html") - - # Create HTML content - html_content = self._generate_visualization_html(data) - - # Write the HTML content to a file - with fileio.open(visualization_path, "w") as f: - f.write(html_content) - - # Return the visualization path and type - return {visualization_path: VisualizationType.HTML} - - def _generate_visualization_html(self, output: ReflectionOutput) -> str: - """Generate HTML visualization for the reflection output. - - Args: - output: The ReflectionOutput to visualize - - Returns: - HTML string - """ - html = f""" - - - - Reflection Output - - - -
-

🔍 Reflection & Analysis Output

- - - -
-

- 📝Critique Summary - {} -

- """.format(len(output.critique_summary)) - - if output.critique_summary: - for critique in output.critique_summary: - html += """ -
- """ - - # Handle different critique formats - if isinstance(critique, dict): - for key, value in critique.items(): - html += f""" -
{key}:
-
{value}
- """ - else: - html += f""" -
{critique}
- """ - - html += """ -
- """ - else: - html += """ -

No critique summary available

- """ - - html += """ -
- -
-

- Additional Questions Identified - {} -

- """.format(len(output.additional_questions)) - - if output.additional_questions: - for question in output.additional_questions: - html += f""" -
- {question} -
- """ - else: - html += """ -

No additional questions identified

- """ - - html += """ -
- -
-

📊Research State Summary

-
-

Main Query: {}

-

Current Stage: {}

-

Sub-questions: {}

-

Search Results: {} queries with results

-

Synthesized Info: {} topics synthesized

-
-
- """.format( - output.state.main_query, - output.state.get_current_stage().replace("_", " ").title(), - len(output.state.sub_questions), - len(output.state.search_results), - len(output.state.synthesized_info), - ) - - # Add metadata section - html += """ -
-

This reflection output suggests improvements and additional research directions based on the current research state.

-
-
- - - """ - - return html diff --git a/deep_research/materializers/search_data_materializer.py b/deep_research/materializers/search_data_materializer.py new file mode 100644 index 00000000..7c0ef883 --- /dev/null +++ b/deep_research/materializers/search_data_materializer.py @@ -0,0 +1,394 @@ +"""Materializer for SearchData with cost breakdown charts and search results visualization.""" + +import json +import os +from typing import Dict + +from utils.pydantic_models import SearchData +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class SearchDataMaterializer(PydanticMaterializer): + """Materializer for SearchData with interactive visualizations.""" + + ASSOCIATED_TYPES = (SearchData,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: SearchData + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the SearchData. + + Args: + data: The SearchData to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + visualization_path = os.path.join(self.uri, "search_data.html") + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, data: SearchData) -> str: + """Generate HTML visualization for the search data. + + Args: + data: The SearchData to visualize + + Returns: + HTML string + """ + # Prepare data for charts + cost_data = [ + {"provider": k, "cost": v} for k, v in data.search_costs.items() + ] + + # Create search results HTML + results_html = "" + for sub_q, results in data.search_results.items(): + results_html += f""" +
+

{sub_q}

+
{len(results)} results found
+
+ """ + + for i, result in enumerate(results[:5]): # Show first 5 results + results_html += f""" +
+
{result.title or "Untitled"}
+
{result.snippet or result.content[:200]}...
+ View Source +
+ """ + + if len(results) > 5: + results_html += f'
... and {len(results) - 5} more results
' + + results_html += """ +
+
+ """ + + if not results_html: + results_html = ( + '
No search results yet
' + ) + + # Calculate total cost + total_cost = sum(data.search_costs.values()) + + html = f""" + + + + Search Data Visualization + + + + +
+
+

Search Data Analysis

+ +
+
+
{data.total_searches}
+
Total Searches
+
+
+
{ + len(data.search_results) + }
+
Sub-Questions
+
+
+
{ + sum(len(results) for results in data.search_results.values()) + }
+
Total Results
+
+
+
${total_cost:.4f}
+
Total Cost
+
+
+
+ +
+

Cost Analysis

+ +
+ +
+ +
+

Cost Breakdown by Provider

+ + + + + + + + + + { + "".join( + f''' + + + + + + ''' + for provider, cost in data.search_costs.items() + ) + } + +
ProviderCostPercentage
{provider}${cost:.4f}{(cost / total_cost * 100 if total_cost > 0 else 0):.1f}%
+
+
+ +
+

Search Results

+ {results_html} +
+
+ + + + + """ + + return html diff --git a/deep_research/materializers/synthesis_data_materializer.py b/deep_research/materializers/synthesis_data_materializer.py new file mode 100644 index 00000000..fb5d68f2 --- /dev/null +++ b/deep_research/materializers/synthesis_data_materializer.py @@ -0,0 +1,431 @@ +"""Materializer for SynthesisData with confidence metrics and synthesis quality visualization.""" + +import os +from typing import Dict + +from utils.pydantic_models import SynthesisData +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class SynthesisDataMaterializer(PydanticMaterializer): + """Materializer for SynthesisData with quality metrics visualization.""" + + ASSOCIATED_TYPES = (SynthesisData,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: SynthesisData + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the SynthesisData. + + Args: + data: The SynthesisData to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + visualization_path = os.path.join(self.uri, "synthesis_data.html") + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, data: SynthesisData) -> str: + """Generate HTML visualization for the synthesis data. + + Args: + data: The SynthesisData to visualize + + Returns: + HTML string + """ + # Count confidence levels + confidence_counts = {"high": 0, "medium": 0, "low": 0} + for info in data.synthesized_info.values(): + confidence_counts[info.confidence_level] += 1 + + # Create synthesis cards HTML + synthesis_html = "" + for sub_q, info in data.synthesized_info.items(): + confidence_color = { + "high": "#2dce89", + "medium": "#ffd600", + "low": "#f5365c", + }.get(info.confidence_level, "#666") + + sources_html = "" + if info.key_sources: + sources_html = ( + "
Key Sources:
" + + gaps_html = "" + if info.information_gaps: + gaps_html = f""" +
+ Information Gaps: +

{info.information_gaps}

+
+ """ + + improvements_html = "" + if info.improvements: + improvements_html = "
Suggested Improvements:
" + + # Check if this has enhanced version + enhanced_badge = "" + enhanced_section = "" + if sub_q in data.enhanced_info: + enhanced_badge = 'Enhanced' + enhanced_info = data.enhanced_info[sub_q] + enhanced_section = f""" +
+

Enhanced Answer

+

{enhanced_info.synthesized_answer}

+
+ Confidence: {enhanced_info.confidence_level.upper()} +
+
+ """ + + synthesis_html += f""" +
+
+

{sub_q}

+ {enhanced_badge} +
+ +
+

Original Synthesis

+

{info.synthesized_answer}

+ +
+ Confidence: {info.confidence_level.upper()} +
+ + {sources_html} + {gaps_html} + {improvements_html} +
+ + {enhanced_section} +
+ """ + + if not synthesis_html: + synthesis_html = '
No synthesis data available yet
' + + # Calculate statistics + total_syntheses = len(data.synthesized_info) + total_enhanced = len(data.enhanced_info) + avg_sources = sum( + len(info.key_sources) for info in data.synthesized_info.values() + ) / max(total_syntheses, 1) + + html = f""" + + + + Synthesis Data Visualization + + + + +
+
+

Synthesis Quality Analysis

+
+ +
+
+
{total_syntheses}
+
Total Syntheses
+
+
+
{total_enhanced}
+
Enhanced Syntheses
+
+
+
{avg_sources:.1f}
+
Avg Sources per Synthesis
+
+
+
{confidence_counts["high"]}
+
High Confidence
+
+
+ +
+

Confidence Distribution

+
+ +
+
+ +
+

Synthesized Information

+ {synthesis_html} +
+
+ + + + + """ + + return html diff --git a/deep_research/pipelines/parallel_research_pipeline.py b/deep_research/pipelines/parallel_research_pipeline.py index 669b4824..fabdc204 100644 --- a/deep_research/pipelines/parallel_research_pipeline.py +++ b/deep_research/pipelines/parallel_research_pipeline.py @@ -8,7 +8,6 @@ from steps.process_sub_question_step import process_sub_question_step from steps.pydantic_final_report_step import pydantic_final_report_step from steps.query_decomposition_step import initial_query_decomposition_step -from utils.pydantic_models import ResearchState from zenml import pipeline @@ -56,24 +55,22 @@ def parallelized_deep_research_pipeline( introduction_prompt, ) = initialize_prompts_step(pipeline_version="1.0.0") - # Initialize the research state with the main query - state = ResearchState(main_query=query) - # Step 1: Decompose the query into sub-questions, limiting to max_sub_questions - decomposed_state = initial_query_decomposition_step( - state=state, + query_context = initial_query_decomposition_step( + main_query=query, query_decomposition_prompt=query_decomposition_prompt, max_sub_questions=max_sub_questions, langfuse_project_name=langfuse_project_name, ) # Fan out: Process each sub-question in parallel - # Collect artifacts to establish dependencies for the merge step - after = [] + # Collect step names to establish dependencies for the merge step + parallel_step_names = [] for i in range(max_sub_questions): # Process the i-th sub-question (if it exists) - sub_state = process_sub_question_step( - state=decomposed_state, + step_name = f"process_question_{i + 1}" + search_data, synthesis_data = process_sub_question_step( + query_context=query_context, search_query_prompt=search_query_prompt, synthesis_prompt=synthesis_prompt, question_index=i, @@ -81,66 +78,87 @@ def parallelized_deep_research_pipeline( search_mode=search_mode, num_results_per_search=num_results_per_search, langfuse_project_name=langfuse_project_name, - id=f"process_question_{i + 1}", + id=step_name, + after="initial_query_decomposition_step", ) - after.append(sub_state) + parallel_step_names.append(step_name) # Fan in: Merge results from all parallel processing # The 'after' parameter ensures this step runs after all processing steps - # It doesn't directly use the processed_states input - merged_state = merge_sub_question_results_step( - original_state=decomposed_state, - step_prefix="process_question_", - output_name="output", - after=after, # This creates the dependency + merged_search_data, merged_synthesis_data = ( + merge_sub_question_results_step( + step_prefix="process_question_", + after=parallel_step_names, # Wait for all parallel steps to complete + ) ) # Continue with subsequent steps - analyzed_state = cross_viewpoint_analysis_step( - state=merged_state, + analysis_data = cross_viewpoint_analysis_step( + query_context=query_context, + synthesis_data=merged_synthesis_data, viewpoint_analysis_prompt=viewpoint_analysis_prompt, langfuse_project_name=langfuse_project_name, + after="merge_sub_question_results_step", ) # New 3-step reflection flow with optional human approval # Step 1: Generate reflection and recommendations (no searches yet) - reflection_output = generate_reflection_step( - state=analyzed_state, + analysis_with_reflection, recommended_queries = generate_reflection_step( + query_context=query_context, + synthesis_data=merged_synthesis_data, + analysis_data=analysis_data, reflection_prompt=reflection_prompt, langfuse_project_name=langfuse_project_name, + 
after="cross_viewpoint_analysis_step", ) # Step 2: Get approval for recommended searches approval_decision = get_research_approval_step( - reflection_output=reflection_output, + query_context=query_context, + synthesis_data=merged_synthesis_data, + analysis_data=analysis_with_reflection, + recommended_queries=recommended_queries, require_approval=require_approval, timeout=approval_timeout, max_queries=max_additional_searches, + after="generate_reflection_step", ) # Step 3: Execute approved searches (if any) - reflected_state = execute_approved_searches_step( - reflection_output=reflection_output, - approval_decision=approval_decision, - additional_synthesis_prompt=additional_synthesis_prompt, - search_provider=search_provider, - search_mode=search_mode, - num_results_per_search=num_results_per_search, - langfuse_project_name=langfuse_project_name, + enhanced_search_data, enhanced_synthesis_data, enhanced_analysis_data = ( + execute_approved_searches_step( + query_context=query_context, + search_data=merged_search_data, + synthesis_data=merged_synthesis_data, + analysis_data=analysis_with_reflection, + recommended_queries=recommended_queries, + approval_decision=approval_decision, + additional_synthesis_prompt=additional_synthesis_prompt, + search_provider=search_provider, + search_mode=search_mode, + num_results_per_search=num_results_per_search, + langfuse_project_name=langfuse_project_name, + after="get_research_approval_step", + ) ) # Use our new Pydantic-based final report step - # This returns a tuple (state, html_report) - final_state, final_report = pydantic_final_report_step( - state=reflected_state, + pydantic_final_report_step( + query_context=query_context, + search_data=enhanced_search_data, + synthesis_data=enhanced_synthesis_data, + analysis_data=enhanced_analysis_data, conclusion_generation_prompt=conclusion_generation_prompt, executive_summary_prompt=executive_summary_prompt, introduction_prompt=introduction_prompt, langfuse_project_name=langfuse_project_name, + after="execute_approved_searches_step", ) # Collect tracing metadata for the entire pipeline run - _, tracing_metadata = collect_tracing_metadata_step( - state=final_state, + collect_tracing_metadata_step( + query_context=query_context, + search_data=enhanced_search_data, langfuse_project_name=langfuse_project_name, + after="pydantic_final_report_step", ) diff --git a/deep_research/steps/approval_step.py b/deep_research/steps/approval_step.py index c74277c9..93566049 100644 --- a/deep_research/steps/approval_step.py +++ b/deep_research/steps/approval_step.py @@ -1,26 +1,62 @@ import logging import time -from typing import Annotated +from typing import Annotated, List from materializers.approval_decision_materializer import ( ApprovalDecisionMaterializer, ) from utils.approval_utils import ( format_approval_request, - summarize_research_progress, ) -from utils.pydantic_models import ApprovalDecision, ReflectionOutput -from zenml import add_tags, log_metadata, step +from utils.pydantic_models import ( + AnalysisData, + ApprovalDecision, + QueryContext, + SynthesisData, +) +from zenml import log_metadata, step from zenml.client import Client logger = logging.getLogger(__name__) +def summarize_research_progress_from_artifacts( + synthesis_data: SynthesisData, analysis_data: AnalysisData +) -> dict: + """Summarize research progress from the new artifact structure.""" + completed_count = len(synthesis_data.synthesized_info) + + # Calculate confidence levels from synthesis data + confidence_levels = [] + for info in 
synthesis_data.synthesized_info.values(): + confidence_levels.append(info.confidence_level) + + # Calculate average confidence (high=1.0, medium=0.5, low=0.0) + confidence_map = {"high": 1.0, "medium": 0.5, "low": 0.0} + avg_confidence = sum( + confidence_map.get(c.lower(), 0.5) for c in confidence_levels + ) / max(len(confidence_levels), 1) + + low_confidence_count = sum( + 1 for c in confidence_levels if c.lower() == "low" + ) + + return { + "completed_count": completed_count, + "avg_confidence": round(avg_confidence, 2), + "low_confidence_count": low_confidence_count, + } + + @step( - enable_cache=False, output_materializers=ApprovalDecisionMaterializer + enable_cache=False, + output_materializers={"approval_decision": ApprovalDecisionMaterializer}, ) # Never cache approval decisions def get_research_approval_step( - reflection_output: ReflectionOutput, + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, + recommended_queries: List[str], require_approval: bool = True, alerter_type: str = "slack", timeout: int = 3600, @@ -33,7 +69,10 @@ def get_research_approval_step( automatically approves all queries. Args: - reflection_output: Output from the reflection generation step + query_context: Context containing the main query and sub-questions + synthesis_data: Synthesized information from research + analysis_data: Analysis including viewpoints and critique + recommended_queries: List of recommended additional queries require_approval: Whether to require human approval alerter_type: Type of alerter to use (slack, email, etc.) timeout: Timeout in seconds for approval response @@ -45,7 +84,7 @@ def get_research_approval_step( start_time = time.time() # Limit queries to max_queries - limited_queries = reflection_output.recommended_queries[:max_queries] + limited_queries = recommended_queries[:max_queries] # If approval not required, auto-approve all if not require_approval: @@ -61,9 +100,7 @@ def get_research_approval_step( "execution_time_seconds": execution_time, "approval_required": False, "approval_method": "AUTO_APPROVED", - "num_queries_recommended": len( - reflection_output.recommended_queries - ), + "num_queries_recommended": len(recommended_queries), "num_queries_approved": len(limited_queries), "max_queries_allowed": max_queries, "approval_status": "approved", @@ -108,11 +145,29 @@ def get_research_approval_step( ) # Prepare approval request - progress_summary = summarize_research_progress(reflection_output.state) + progress_summary = summarize_research_progress_from_artifacts( + synthesis_data, analysis_data + ) + + # Extract critique points from analysis data + critique_points = [] + if analysis_data.critique_summary: + # Convert critique summary to list of dicts for compatibility + for i, critique in enumerate( + analysis_data.critique_summary.split("\n") + ): + if critique.strip(): + critique_points.append( + { + "issue": critique.strip(), + "importance": "high" if i < 3 else "medium", + } + ) + message = format_approval_request( - main_query=reflection_output.state.main_query, + main_query=query_context.main_query, progress_summary=progress_summary, - critique_points=reflection_output.critique_summary, + critique_points=critique_points, proposed_queries=limited_queries, timeout=timeout, ) @@ -140,9 +195,7 @@ def get_research_approval_step( "approval_required": require_approval, "approval_method": "NO_ALERTER_AUTO_APPROVED", "alerter_type": "none", - "num_queries_recommended": len( - reflection_output.recommended_queries - ), + 
"num_queries_recommended": len(recommended_queries), "num_queries_approved": len(limited_queries), "max_queries_allowed": max_queries, "approval_status": "auto_approved", @@ -202,7 +255,7 @@ def get_research_approval_step( "approval_method": "DISCORD_APPROVED", "alerter_type": alerter_type, "num_queries_recommended": len( - reflection_output.recommended_queries + recommended_queries ), "num_queries_approved": len(limited_queries), "max_queries_allowed": max_queries, @@ -230,7 +283,7 @@ def get_research_approval_step( "approval_method": "DISCORD_REJECTED", "alerter_type": alerter_type, "num_queries_recommended": len( - reflection_output.recommended_queries + recommended_queries ), "num_queries_approved": 0, "max_queries_allowed": max_queries, @@ -260,9 +313,7 @@ def get_research_approval_step( "approval_required": require_approval, "approval_method": "ALERTER_ERROR", "alerter_type": alerter_type, - "num_queries_recommended": len( - reflection_output.recommended_queries - ), + "num_queries_recommended": len(recommended_queries), "num_queries_approved": 0, "max_queries_allowed": max_queries, "approval_status": "error", @@ -289,9 +340,7 @@ def get_research_approval_step( "execution_time_seconds": execution_time, "approval_required": require_approval, "approval_method": "ERROR", - "num_queries_recommended": len( - reflection_output.recommended_queries - ), + "num_queries_recommended": len(recommended_queries), "num_queries_approved": 0, "max_queries_allowed": max_queries, "approval_status": "error", @@ -301,7 +350,7 @@ def get_research_approval_step( ) # Add tag to the approval decision artifact - add_tags(tags=["hitl"], artifact="approval_decision") + # add_tags(tags=["hitl"], artifact_name="approval_decision", infer_artifact=True) return ApprovalDecision( approved=False, diff --git a/deep_research/steps/collect_tracing_metadata_step.py b/deep_research/steps/collect_tracing_metadata_step.py index d7e2e8a6..bfaa70e8 100644 --- a/deep_research/steps/collect_tracing_metadata_step.py +++ b/deep_research/steps/collect_tracing_metadata_step.py @@ -1,15 +1,15 @@ """Step to collect tracing metadata from Langfuse for the pipeline run.""" import logging -from typing import Annotated, Tuple +from typing import Annotated, Dict -from materializers.pydantic_materializer import ResearchStateMaterializer from materializers.tracing_metadata_materializer import ( TracingMetadataMaterializer, ) from utils.pydantic_models import ( PromptTypeMetrics, - ResearchState, + QueryContext, + SearchData, TracingMetadata, ) from utils.tracing_metadata_utils import ( @@ -18,7 +18,7 @@ get_trace_stats, get_traces_by_name, ) -from zenml import add_tags, get_step_context, step +from zenml import get_step_context, step logger = logging.getLogger(__name__) @@ -26,28 +26,26 @@ @step( enable_cache=False, output_materializers={ - "state": ResearchStateMaterializer, "tracing_metadata": TracingMetadataMaterializer, }, ) def collect_tracing_metadata_step( - state: ResearchState, + query_context: QueryContext, + search_data: SearchData, langfuse_project_name: str, -) -> Tuple[ - Annotated[ResearchState, "state"], - Annotated[TracingMetadata, "tracing_metadata"], -]: +) -> Annotated[TracingMetadata, "tracing_metadata"]: """Collect tracing metadata from Langfuse for the current pipeline run. This step gathers comprehensive metrics about token usage, costs, and performance for the entire pipeline run, providing insights into resource consumption. 
Args: - state: The final research state + query_context: The query context (for reference) + search_data: The search data containing cost information langfuse_project_name: Langfuse project name for accessing traces Returns: - Tuple of (ResearchState, TracingMetadata) - the state is passed through unchanged + TracingMetadata with comprehensive cost and performance metrics """ ctx = get_step_context() pipeline_run_name = ctx.pipeline_run.name @@ -74,7 +72,9 @@ def collect_tracing_metadata_step( logger.warning( f"No trace found for pipeline run: {pipeline_run_name}" ) - return state, metadata + # Still add search costs before returning + _add_search_costs_to_metadata(metadata, search_data) + return metadata trace = traces[0] @@ -179,26 +179,8 @@ def collect_tracing_metadata_step( except Exception as e: logger.warning(f"Failed to collect prompt-level metrics: {str(e)}") - # Add search costs from the state - if hasattr(state, "search_costs") and state.search_costs: - metadata.search_costs = state.search_costs.copy() - logger.info(f"Added search costs: {metadata.search_costs}") - - if hasattr(state, "search_cost_details") and state.search_cost_details: - metadata.search_cost_details = state.search_cost_details.copy() - - # Count queries by provider - search_queries_count = {} - for detail in state.search_cost_details: - provider = detail.get("provider", "unknown") - search_queries_count[provider] = ( - search_queries_count.get(provider, 0) + 1 - ) - metadata.search_queries_count = search_queries_count - - logger.info( - f"Added {len(metadata.search_cost_details)} search cost detail entries" - ) + # Add search costs from the SearchData artifact + _add_search_costs_to_metadata(metadata, search_data) total_search_cost = sum(metadata.search_costs.values()) logger.info( @@ -216,25 +198,55 @@ def collect_tracing_metadata_step( f"Failed to collect tracing metadata for pipeline run {pipeline_run_name}: {str(e)}" ) # Return metadata with whatever we could collect - # Still try to get search costs even if Langfuse failed - if hasattr(state, "search_costs") and state.search_costs: - metadata.search_costs = state.search_costs.copy() - if hasattr(state, "search_cost_details") and state.search_cost_details: - metadata.search_cost_details = state.search_cost_details.copy() - # Count queries by provider - search_queries_count = {} - for detail in state.search_cost_details: - provider = detail.get("provider", "unknown") - search_queries_count[provider] = ( - search_queries_count.get(provider, 0) + 1 - ) - metadata.search_queries_count = search_queries_count + _add_search_costs_to_metadata(metadata, search_data) - # Add tags to the artifacts - add_tags(tags=["state"], artifact="state") - add_tags( - tags=["exa", "tavily", "llm", "cost"], artifact="tracing_metadata" - ) + # Add tags to the artifact + # add_tags( + # tags=["exa", "tavily", "llm", "cost", "tracing"], + # artifact_name="tracing_metadata", + # infer_artifact=True, + # ) - return state, metadata + return metadata + + +def _add_search_costs_to_metadata( + metadata: TracingMetadata, search_data: SearchData +) -> None: + """Add search costs from SearchData to TracingMetadata. 
+ + Args: + metadata: The TracingMetadata object to update + search_data: The SearchData containing cost information + """ + if search_data.search_costs: + metadata.search_costs = search_data.search_costs.copy() + logger.info(f"Added search costs: {metadata.search_costs}") + + if search_data.search_cost_details: + # Convert SearchCostDetail objects to dicts for backward compatibility + metadata.search_cost_details = [ + { + "provider": detail.provider, + "query": detail.query, + "cost": detail.cost, + "timestamp": detail.timestamp, + "step": detail.step, + "sub_question": detail.sub_question, + } + for detail in search_data.search_cost_details + ] + + # Count queries by provider + search_queries_count: Dict[str, int] = {} + for detail in search_data.search_cost_details: + provider = detail.provider + search_queries_count[provider] = ( + search_queries_count.get(provider, 0) + 1 + ) + metadata.search_queries_count = search_queries_count + + logger.info( + f"Added {len(metadata.search_cost_details)} search cost detail entries" + ) diff --git a/deep_research/steps/cross_viewpoint_step.py b/deep_research/steps/cross_viewpoint_step.py index 0c24b6e3..9212608e 100644 --- a/deep_research/steps/cross_viewpoint_step.py +++ b/deep_research/steps/cross_viewpoint_step.py @@ -3,25 +3,28 @@ import time from typing import Annotated, List -from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.analysis_data_materializer import AnalysisDataMaterializer from utils.helper_functions import ( safe_json_loads, ) from utils.llm_utils import run_llm_completion from utils.pydantic_models import ( + AnalysisData, Prompt, - ResearchState, + QueryContext, + SynthesisData, ViewpointAnalysis, ViewpointTension, ) -from zenml import add_tags, log_metadata, step +from zenml import log_metadata, step logger = logging.getLogger(__name__) -@step(output_materializers=ResearchStateMaterializer) +@step(output_materializers={"analysis_data": AnalysisDataMaterializer}) def cross_viewpoint_analysis_step( - state: ResearchState, + query_context: QueryContext, + synthesis_data: SynthesisData, viewpoint_analysis_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", viewpoint_categories: List[str] = [ @@ -33,27 +36,32 @@ def cross_viewpoint_analysis_step( "historical", ], langfuse_project_name: str = "deep-research", -) -> Annotated[ResearchState, "analyzed_state"]: +) -> Annotated[AnalysisData, "analysis_data"]: """Analyze synthesized information across different viewpoints. 
Args: - state: The current research state + query_context: The query context with main query and sub-questions + synthesis_data: The synthesized information to analyze viewpoint_analysis_prompt: Prompt for viewpoint analysis llm_model: The model to use for viewpoint analysis viewpoint_categories: Categories of viewpoints to analyze + langfuse_project_name: Project name for tracing Returns: - Updated research state with viewpoint analysis + AnalysisData containing viewpoint analysis """ start_time = time.time() logger.info( - f"Performing cross-viewpoint analysis on {len(state.synthesized_info)} sub-questions" + f"Performing cross-viewpoint analysis on {len(synthesis_data.synthesized_info)} sub-questions" ) + # Initialize analysis data + analysis_data = AnalysisData() + # Prepare input for viewpoint analysis analysis_input = { - "main_query": state.main_query, - "sub_questions": state.sub_questions, + "main_query": query_context.main_query, + "sub_questions": query_context.sub_questions, "synthesized_information": { question: { "synthesized_answer": info.synthesized_answer, @@ -61,7 +69,7 @@ def cross_viewpoint_analysis_step( "confidence_level": info.confidence_level, "information_gaps": info.information_gaps, } - for question, info in state.synthesized_info.items() + for question, info in synthesis_data.synthesized_info.items() }, "viewpoint_categories": viewpoint_categories, } @@ -113,8 +121,8 @@ def cross_viewpoint_analysis_step( logger.info("Completed viewpoint analysis") - # Update the state with the viewpoint analysis - state.update_viewpoint_analysis(viewpoint_analysis) + # Update the analysis data with the viewpoint analysis + analysis_data.viewpoint_analysis = viewpoint_analysis # Calculate execution time execution_time = time.time() - start_time @@ -133,7 +141,9 @@ def cross_viewpoint_analysis_step( "viewpoint_analysis": { "execution_time_seconds": execution_time, "llm_model": llm_model, - "num_sub_questions_analyzed": len(state.synthesized_info), + "num_sub_questions_analyzed": len( + synthesis_data.synthesized_info + ), "viewpoint_categories_requested": viewpoint_categories, "num_agreement_points": len( viewpoint_analysis.main_points_of_agreement @@ -172,7 +182,7 @@ def cross_viewpoint_analysis_step( # Log artifact metadata log_metadata( metadata={ - "state_with_viewpoint_analysis": { + "analysis_data_characteristics": { "has_viewpoint_analysis": True, "total_viewpoints_analyzed": sum( tension_categories.values() @@ -184,13 +194,14 @@ def cross_viewpoint_analysis_step( else None, } }, + artifact_name="analysis_data", infer_artifact=True, ) # Add tags to the artifact - add_tags(tags=["state", "viewpoint"], artifact="analyzed_state") + # add_tags(tags=["analysis", "viewpoint"], artifact_name="analysis_data", infer_artifact=True) - return state + return analysis_data except Exception as e: logger.error(f"Error performing viewpoint analysis: {e}") @@ -204,8 +215,8 @@ def cross_viewpoint_analysis_step( integrative_insights="No insights available due to analysis failure.", ) - # Update the state with the fallback analysis - state.update_viewpoint_analysis(fallback_analysis) + # Update the analysis data with the fallback analysis + analysis_data.viewpoint_analysis = fallback_analysis # Log error metadata execution_time = time.time() - start_time @@ -214,7 +225,9 @@ def cross_viewpoint_analysis_step( "viewpoint_analysis": { "execution_time_seconds": execution_time, "llm_model": llm_model, - "num_sub_questions_analyzed": len(state.synthesized_info), + "num_sub_questions_analyzed": len( + 
synthesis_data.synthesized_info + ), "viewpoint_categories_requested": viewpoint_categories, "analysis_success": False, "error_message": str(e), @@ -224,6 +237,10 @@ def cross_viewpoint_analysis_step( ) # Add tags to the artifact - add_tags(tags=["state", "viewpoint"], artifact="analyzed_state") + # add_tags( + # tags=["analysis", "viewpoint", "fallback"], + # artifact_name="analysis_data", + # infer_artifact=True, + # ) - return state + return analysis_data diff --git a/deep_research/steps/execute_approved_searches_step.py b/deep_research/steps/execute_approved_searches_step.py index db1b46e9..f1eb8625 100644 --- a/deep_research/steps/execute_approved_searches_step.py +++ b/deep_research/steps/execute_approved_searches_step.py @@ -1,47 +1,45 @@ import json import logging import time -from typing import Annotated +from typing import Annotated, List, Tuple -from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.analysis_data_materializer import AnalysisDataMaterializer +from materializers.search_data_materializer import SearchDataMaterializer +from materializers.synthesis_data_materializer import SynthesisDataMaterializer from utils.llm_utils import ( find_most_relevant_string, get_structured_llm_output, is_text_relevant, ) from utils.pydantic_models import ( + AnalysisData, ApprovalDecision, Prompt, - ReflectionMetadata, - ReflectionOutput, - ResearchState, + QueryContext, + SearchCostDetail, + SearchData, + SynthesisData, SynthesizedInfo, ) from utils.search_utils import search_and_extract_results -from zenml import add_tags, log_metadata, step +from zenml import log_metadata, step logger = logging.getLogger(__name__) -def create_enhanced_info_copy(synthesized_info): - """Create a deep copy of synthesized info for enhancement.""" - return { - k: SynthesizedInfo( - synthesized_answer=v.synthesized_answer, - key_sources=v.key_sources.copy(), - confidence_level=v.confidence_level, - information_gaps=v.information_gaps, - improvements=v.improvements.copy() - if hasattr(v, "improvements") - else [], - ) - for k, v in synthesized_info.items() +@step( + output_materializers={ + "enhanced_search_data": SearchDataMaterializer, + "enhanced_synthesis_data": SynthesisDataMaterializer, + "updated_analysis_data": AnalysisDataMaterializer, } - - -@step(output_materializers=ResearchStateMaterializer) +) def execute_approved_searches_step( - reflection_output: ReflectionOutput, + query_context: QueryContext, + search_data: SearchData, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, + recommended_queries: List[str], approval_decision: ApprovalDecision, additional_synthesis_prompt: Prompt, num_results_per_search: int = 3, @@ -50,32 +48,51 @@ def execute_approved_searches_step( search_provider: str = "tavily", search_mode: str = "auto", langfuse_project_name: str = "deep-research", -) -> Annotated[ResearchState, "updated_state"]: - """Execute approved searches and enhance the research state. +) -> Tuple[ + Annotated[SearchData, "enhanced_search_data"], + Annotated[SynthesisData, "enhanced_synthesis_data"], + Annotated[AnalysisData, "updated_analysis_data"], +]: + """Execute approved searches and enhance the research artifacts. This step receives the approval decision and only executes searches that were approved by the human reviewer (or auto-approved). 
Args: - reflection_output: Output from the reflection generation step + query_context: The query context with main query and sub-questions + search_data: The existing search data + synthesis_data: The existing synthesis data + analysis_data: The analysis data with viewpoint and reflection metadata + recommended_queries: The recommended queries from reflection approval_decision: Human approval decision + additional_synthesis_prompt: Prompt for synthesis enhancement num_results_per_search: Number of results to fetch per search cap_search_length: Maximum length of content to process from search results llm_model: The model to use for synthesis enhancement - prompts_bundle: Bundle containing all prompts for the pipeline search_provider: Search provider to use search_mode: Search mode for the provider + langfuse_project_name: Project name for tracing Returns: - Updated research state with enhanced information and reflection metadata + Tuple of enhanced SearchData, SynthesisData, and updated AnalysisData """ start_time = time.time() logger.info( f"Processing approval decision: {approval_decision.approval_method}" ) - state = reflection_output.state - enhanced_info = create_enhanced_info_copy(state.synthesized_info) + # Create copies of the data to enhance + enhanced_search_data = SearchData( + search_results=search_data.search_results.copy(), + search_costs=search_data.search_costs.copy(), + search_cost_details=search_data.search_cost_details.copy(), + total_searches=search_data.total_searches, + ) + + enhanced_synthesis_data = SynthesisData( + synthesized_info=synthesis_data.synthesized_info.copy(), + enhanced_info={}, # Will be populated with enhanced versions + ) # Track improvements count improvements_count = 0 @@ -87,39 +104,10 @@ def execute_approved_searches_step( ): logger.info("No additional searches approved") - # Add any additional questions as new synthesized entries (from reflection) - for new_question in reflection_output.additional_questions: - if ( - new_question not in state.sub_questions - and new_question not in enhanced_info - ): - enhanced_info[new_question] = SynthesizedInfo( - synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", - key_sources=[], - confidence_level="low", - information_gaps="This question requires additional research.", - ) - - # Create metadata indicating no additional research - reflection_metadata = ReflectionMetadata( - critique_summary=[ - c.get("issue", "") for c in reflection_output.critique_summary - ], - additional_questions_identified=reflection_output.additional_questions, - searches_performed=[], - improvements_made=improvements_count, - ) - - # Add approval decision info to metadata - if hasattr(reflection_metadata, "__dict__"): - reflection_metadata.__dict__["user_decision"] = ( - approval_decision.approval_method - ) - reflection_metadata.__dict__["reviewer_notes"] = ( - approval_decision.reviewer_notes - ) - - state.update_after_reflection(enhanced_info, reflection_metadata) + # Update reflection metadata with no searches + if analysis_data.reflection_metadata: + analysis_data.reflection_metadata.searches_performed = [] + analysis_data.reflection_metadata.improvements_made = 0.0 # Log metadata for no approved searches execution_time = time.time() - start_time @@ -133,9 +121,7 @@ def execute_approved_searches_step( else "no_queries", "num_queries_approved": 0, "num_searches_executed": 0, - "num_additional_questions": len( - reflection_output.additional_questions - ), + 
"num_recommended": len(recommended_queries), "improvements_made": improvements_count, "search_provider": search_provider, "llm_model": llm_model, @@ -143,10 +129,20 @@ def execute_approved_searches_step( } ) - # Add tags to the artifact - add_tags(tags=["state", "enhanced"], artifact="updated_state") - - return state + # Add tags to the artifacts + # add_tags( + # tags=["search", "not-enhanced"], artifact_name="enhanced_search_data", infer_artifact=True + # ) + # add_tags( + # tags=["synthesis", "not-enhanced"], + # artifact_name="enhanced_synthesis_data", + # infer_artifact=True, + # ) + # add_tags( + # tags=["analysis", "no-searches"], artifact_name="updated_analysis_data", infer_artifact=True + # ) + + return enhanced_search_data, enhanced_synthesis_data, analysis_data # Execute approved searches logger.info( @@ -175,53 +171,72 @@ def execute_approved_searches_step( and search_cost > 0 ): # Update total costs - state.search_costs["exa"] = ( - state.search_costs.get("exa", 0.0) + search_cost + enhanced_search_data.search_costs["exa"] = ( + enhanced_search_data.search_costs.get("exa", 0.0) + + search_cost ) # Add detailed cost entry - state.search_cost_details.append( - { - "provider": "exa", - "query": query, - "cost": search_cost, - "timestamp": time.time(), - "step": "execute_approved_searches", - "purpose": "reflection_enhancement", - } + enhanced_search_data.search_cost_details.append( + SearchCostDetail( + provider="exa", + query=query, + cost=search_cost, + timestamp=time.time(), + step="execute_approved_searches", + sub_question=None, # These are reflection queries + ) ) logger.info( f"Exa search cost for approved query: ${search_cost:.4f}" ) + # Update total searches + enhanced_search_data.total_searches += 1 + # Extract raw contents raw_contents = [result.content for result in search_results] # Find the most relevant sub-question for this query most_relevant_question = find_most_relevant_string( query, - state.sub_questions, + query_context.sub_questions, llm_model, project=langfuse_project_name, ) if ( most_relevant_question - and most_relevant_question in enhanced_info + and most_relevant_question in synthesis_data.synthesized_info ): + # Store the search results under the relevant question + if ( + most_relevant_question + in enhanced_search_data.search_results + ): + enhanced_search_data.search_results[ + most_relevant_question + ].extend(search_results) + else: + enhanced_search_data.search_results[ + most_relevant_question + ] = search_results + # Enhance the synthesis with new information + original_synthesis = synthesis_data.synthesized_info[ + most_relevant_question + ] + enhancement_input = { - "original_synthesis": enhanced_info[ - most_relevant_question - ].synthesized_answer, + "original_synthesis": original_synthesis.synthesized_answer, "new_information": raw_contents, "critique": [ item - for item in reflection_output.critique_summary - if is_text_relevant( - item.get("issue", ""), most_relevant_question - ) - ], + for item in analysis_data.reflection_metadata.critique_summary + if is_text_relevant(item, most_relevant_question) + ] + if analysis_data.reflection_metadata + else [], } # Use the utility function for enhancement @@ -230,9 +245,7 @@ def execute_approved_searches_step( system_prompt=str(additional_synthesis_prompt), model=llm_model, fallback_response={ - "enhanced_synthesis": enhanced_info[ - most_relevant_question - ].synthesized_answer, + "enhanced_synthesis": original_synthesis.synthesized_answer, "improvements_made": ["Failed to enhance synthesis"], 
"remaining_limitations": "Enhancement process failed.", }, @@ -243,20 +256,31 @@ def execute_approved_searches_step( enhanced_synthesis and "enhanced_synthesis" in enhanced_synthesis ): - # Update the synthesized answer - enhanced_info[ + # Create enhanced synthesis info + enhanced_info = SynthesizedInfo( + synthesized_answer=enhanced_synthesis[ + "enhanced_synthesis" + ], + key_sources=original_synthesis.key_sources + + [r.url for r in search_results[:2]], + confidence_level="high" + if original_synthesis.confidence_level == "medium" + else original_synthesis.confidence_level, + information_gaps=enhanced_synthesis.get( + "remaining_limitations", "" + ), + improvements=original_synthesis.improvements + + enhanced_synthesis.get("improvements_made", []), + ) + + # Store in enhanced_info + enhanced_synthesis_data.enhanced_info[ most_relevant_question - ].synthesized_answer = enhanced_synthesis[ - "enhanced_synthesis" - ] + ] = enhanced_info - # Add improvements improvements = enhanced_synthesis.get( "improvements_made", [] ) - enhanced_info[most_relevant_question].improvements.extend( - improvements - ) improvements_count += len(improvements) # Track enhancement for metadata @@ -274,44 +298,19 @@ def execute_approved_searches_step( } ) - # Add any additional questions as new synthesized entries - for new_question in reflection_output.additional_questions: - if ( - new_question not in state.sub_questions - and new_question not in enhanced_info - ): - enhanced_info[new_question] = SynthesizedInfo( - synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", - key_sources=[], - confidence_level="low", - information_gaps="This question requires additional research.", - ) - - # Create final metadata with approval info - reflection_metadata = ReflectionMetadata( - critique_summary=[ - c.get("issue", "") for c in reflection_output.critique_summary - ], - additional_questions_identified=reflection_output.additional_questions, - searches_performed=approval_decision.selected_queries, - improvements_made=improvements_count, - ) - - # Add approval decision info to metadata - if hasattr(reflection_metadata, "__dict__"): - reflection_metadata.__dict__["user_decision"] = ( - approval_decision.approval_method + # Update reflection metadata with search info + if analysis_data.reflection_metadata: + analysis_data.reflection_metadata.searches_performed = ( + approval_decision.selected_queries ) - reflection_metadata.__dict__["reviewer_notes"] = ( - approval_decision.reviewer_notes + analysis_data.reflection_metadata.improvements_made = float( + improvements_count ) logger.info( f"Completed approved searches with {improvements_count} improvements" ) - state.update_after_reflection(enhanced_info, reflection_metadata) - # Calculate metrics for metadata execution_time = time.time() - start_time total_results = sum( @@ -332,9 +331,7 @@ def execute_approved_searches_step( "execution_time_seconds": execution_time, "approval_method": approval_decision.approval_method, "approval_status": "approved", - "num_queries_recommended": len( - reflection_output.recommended_queries - ), + "num_queries_recommended": len(recommended_queries), "num_queries_approved": len( approval_decision.selected_queries ), @@ -344,14 +341,13 @@ def execute_approved_searches_step( "total_search_results": total_results, "questions_enhanced": questions_enhanced, "improvements_made": improvements_count, - "num_additional_questions": len( - reflection_output.additional_questions - ), 
"search_provider": search_provider, "search_mode": search_mode, "llm_model": llm_model, "success": True, - "total_search_cost": state.search_costs.get("exa", 0.0), + "total_search_cost": enhanced_search_data.search_costs.get( + "exa", 0.0 + ), } } ) @@ -359,44 +355,69 @@ def execute_approved_searches_step( # Log artifact metadata log_metadata( metadata={ - "enhanced_state_after_approval": { - "total_questions": len(enhanced_info), - "questions_with_improvements": sum( - 1 - for info in enhanced_info.values() - if info.improvements + "search_data_characteristics": { + "new_searches": len(approval_decision.selected_queries), + "total_searches": enhanced_search_data.total_searches, + "additional_cost": enhanced_search_data.search_costs.get( + "exa", 0.0 + ) + - search_data.search_costs.get("exa", 0.0), + } + }, + artifact_name="enhanced_search_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + "synthesis_data_characteristics": { + "questions_enhanced": questions_enhanced, + "total_enhancements": len( + enhanced_synthesis_data.enhanced_info ), - "total_improvements": sum( - len(info.improvements) - for info in enhanced_info.values() + "improvements_made": improvements_count, + } + }, + artifact_name="enhanced_synthesis_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + "analysis_data_characteristics": { + "searches_performed": len( + approval_decision.selected_queries ), "approval_method": approval_decision.approval_method, } }, + artifact_name="updated_analysis_data", infer_artifact=True, ) - # Add tags to the artifact - add_tags(tags=["state", "enhanced"], artifact="updated_state") + # Add tags to the artifacts + # add_tags(tags=["search", "enhanced"], artifact_name="enhanced_search_data", infer_artifact=True) + # add_tags( + # tags=["synthesis", "enhanced"], artifact_name="enhanced_synthesis_data", infer_artifact=True + # ) + # add_tags( + # tags=["analysis", "with-searches"], + # artifact_name="updated_analysis_data", + # infer_artifact=True, + # ) - return state + return enhanced_search_data, enhanced_synthesis_data, analysis_data except Exception as e: logger.error(f"Error during approved search execution: {e}") - # Create error metadata - error_metadata = ReflectionMetadata( - error=f"Approved search execution failed: {str(e)}", - critique_summary=[ - c.get("issue", "") for c in reflection_output.critique_summary - ], - additional_questions_identified=reflection_output.additional_questions, - searches_performed=[], - improvements_made=0, - ) - - # Update the state with the original synthesized info as enhanced info - state.update_after_reflection(state.synthesized_info, error_metadata) + # Update reflection metadata with error + if analysis_data.reflection_metadata: + analysis_data.reflection_metadata.error = ( + f"Approved search execution failed: {str(e)}" + ) + analysis_data.reflection_metadata.searches_performed = [] + analysis_data.reflection_metadata.improvements_made = 0.0 # Log error metadata execution_time = time.time() - start_time @@ -419,7 +440,11 @@ def execute_approved_searches_step( } ) - # Add tags to the artifact - add_tags(tags=["state", "enhanced"], artifact="updated_state") + # Add tags to the artifacts + # add_tags(tags=["search", "error"], artifact_name="enhanced_search_data", infer_artifact=True) + # add_tags( + # tags=["synthesis", "error"], artifact_name="enhanced_synthesis_data", infer_artifact=True + # ) + # add_tags(tags=["analysis", "error"], artifact_name="updated_analysis_data", infer_artifact=True) - return state + return 
enhanced_search_data, enhanced_synthesis_data, analysis_data diff --git a/deep_research/steps/generate_reflection_step.py b/deep_research/steps/generate_reflection_step.py index 9081cc84..61cf4637 100644 --- a/deep_research/steps/generate_reflection_step.py +++ b/deep_research/steps/generate_reflection_step.py @@ -1,25 +1,38 @@ import json import logging import time -from typing import Annotated +from typing import Annotated, List, Tuple -from materializers.reflection_output_materializer import ( - ReflectionOutputMaterializer, -) +from materializers.analysis_data_materializer import AnalysisDataMaterializer from utils.llm_utils import get_structured_llm_output -from utils.pydantic_models import Prompt, ReflectionOutput, ResearchState -from zenml import add_tags, log_metadata, step +from utils.pydantic_models import ( + AnalysisData, + Prompt, + QueryContext, + ReflectionMetadata, + SynthesisData, +) +from zenml import log_metadata, step logger = logging.getLogger(__name__) -@step(output_materializers=ReflectionOutputMaterializer) +@step( + output_materializers={ + "analysis_data": AnalysisDataMaterializer, + } +) def generate_reflection_step( - state: ResearchState, + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, reflection_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", -) -> Annotated[ReflectionOutput, "reflection_output"]: +) -> Tuple[ + Annotated[AnalysisData, "analysis_data"], + Annotated[List[str], "recommended_queries"], +]: """ Generate reflection and recommendations WITHOUT executing searches. @@ -27,12 +40,15 @@ def generate_reflection_step( for additional research that could improve the quality of the results. Args: - state: The current research state + query_context: The query context with main query and sub-questions + synthesis_data: The synthesized information + analysis_data: The analysis data with viewpoint analysis reflection_prompt: Prompt for generating reflection llm_model: The model to use for reflection + langfuse_project_name: Project name for tracing Returns: - ReflectionOutput containing the state, recommendations, and critique + Tuple of updated AnalysisData and recommended queries """ start_time = time.time() logger.info("Generating reflection on research") @@ -45,28 +61,28 @@ def generate_reflection_step( "confidence_level": info.confidence_level, "information_gaps": info.information_gaps, } - for question, info in state.synthesized_info.items() + for question, info in synthesis_data.synthesized_info.items() } viewpoint_analysis_dict = None - if state.viewpoint_analysis: + if analysis_data.viewpoint_analysis: # Convert the viewpoint analysis to a dict for the LLM tension_list = [] - for tension in state.viewpoint_analysis.areas_of_tension: + for tension in analysis_data.viewpoint_analysis.areas_of_tension: tension_list.append( {"topic": tension.topic, "viewpoints": tension.viewpoints} ) viewpoint_analysis_dict = { - "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, + "main_points_of_agreement": analysis_data.viewpoint_analysis.main_points_of_agreement, "areas_of_tension": tension_list, - "perspective_gaps": state.viewpoint_analysis.perspective_gaps, - "integrative_insights": state.viewpoint_analysis.integrative_insights, + "perspective_gaps": analysis_data.viewpoint_analysis.perspective_gaps, + "integrative_insights": analysis_data.viewpoint_analysis.integrative_insights, } reflection_input = { - "main_query": 
state.main_query, - "sub_questions": state.sub_questions, + "main_query": query_context.main_query, + "sub_questions": query_context.sub_questions, "synthesized_information": synthesized_info_dict, } @@ -92,14 +108,21 @@ def generate_reflection_step( project=langfuse_project_name, ) - # Prepare return value - reflection_output = ReflectionOutput( - state=state, - recommended_queries=reflection_result.get( - "recommended_search_queries", [] - ), - critique_summary=reflection_result.get("critique", []), - additional_questions=reflection_result.get("additional_questions", []), + # Extract results + recommended_queries = reflection_result.get( + "recommended_search_queries", [] + ) + critique_summary = reflection_result.get("critique", []) + additional_questions = reflection_result.get("additional_questions", []) + + # Update analysis data with reflection metadata + analysis_data.reflection_metadata = ReflectionMetadata( + critique_summary=[ + str(c) for c in critique_summary + ], # Convert to strings + additional_questions_identified=additional_questions, + searches_performed=[], # Will be populated by execute_approved_searches_step + improvements_made=0.0, # Will be updated later ) # Calculate execution time @@ -107,7 +130,8 @@ def generate_reflection_step( # Count confidence levels in synthesized info confidence_levels = [ - info.confidence_level for info in state.synthesized_info.values() + info.confidence_level + for info in synthesis_data.synthesized_info.values() ] confidence_distribution = { "high": confidence_levels.count("high"), @@ -121,20 +145,18 @@ def generate_reflection_step( "reflection_generation": { "execution_time_seconds": execution_time, "llm_model": llm_model, - "num_sub_questions_analyzed": len(state.sub_questions), - "num_synthesized_answers": len(state.synthesized_info), - "viewpoint_analysis_included": bool(viewpoint_analysis_dict), - "num_critique_points": len(reflection_output.critique_summary), - "num_additional_questions": len( - reflection_output.additional_questions - ), - "num_recommended_queries": len( - reflection_output.recommended_queries + "num_sub_questions_analyzed": len(query_context.sub_questions), + "num_synthesized_answers": len( + synthesis_data.synthesized_info ), + "viewpoint_analysis_included": bool(viewpoint_analysis_dict), + "num_critique_points": len(critique_summary), + "num_additional_questions": len(additional_questions), + "num_recommended_queries": len(recommended_queries), "confidence_distribution": confidence_distribution, "has_information_gaps": any( info.information_gaps - for info in state.synthesized_info.values() + for info in synthesis_data.synthesized_info.values() ), } } @@ -143,24 +165,33 @@ def generate_reflection_step( # Log artifact metadata log_metadata( metadata={ - "reflection_output_characteristics": { - "has_recommendations": bool( - reflection_output.recommended_queries - ), - "has_critique": bool(reflection_output.critique_summary), - "has_additional_questions": bool( - reflection_output.additional_questions - ), - "total_recommendations": len( - reflection_output.recommended_queries - ) - + len(reflection_output.additional_questions), + "analysis_data_characteristics": { + "has_reflection_metadata": True, + "has_viewpoint_analysis": analysis_data.viewpoint_analysis + is not None, + "num_critique_points": len(critique_summary), + "num_additional_questions": len(additional_questions), + } + }, + artifact_name="analysis_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + 
"recommended_queries_characteristics": { + "num_queries": len(recommended_queries), + "has_recommendations": bool(recommended_queries), } }, + artifact_name="recommended_queries", infer_artifact=True, ) - # Add tags to the artifact - add_tags(tags=["reflection", "critique"], artifact="reflection_output") + # Add tags to the artifacts + # add_tags(tags=["analysis", "reflection"], artifact_name="analysis_data", infer_artifact=True) + # add_tags( + # tags=["recommendations", "queries"], artifact_name="recommended_queries", infer_artifact=True + # ) - return reflection_output + return analysis_data, recommended_queries diff --git a/deep_research/steps/initialize_prompts_step.py b/deep_research/steps/initialize_prompts_step.py index f8df4c75..fe47c395 100644 --- a/deep_research/steps/initialize_prompts_step.py +++ b/deep_research/steps/initialize_prompts_step.py @@ -10,7 +10,7 @@ from materializers.prompt_materializer import PromptMaterializer from utils import prompts from utils.pydantic_models import Prompt -from zenml import add_tags, step +from zenml import step logger = logging.getLogger(__name__) @@ -119,26 +119,27 @@ def initialize_prompts_step( logger.info(f"Loaded 9 individual prompts") - # add tags to all prompts - add_tags(tags=["prompt", "search"], artifact="search_query_prompt") - add_tags( - tags=["prompt", "generation"], artifact="query_decomposition_prompt" - ) - add_tags(tags=["prompt", "generation"], artifact="synthesis_prompt") - add_tags( - tags=["prompt", "generation"], artifact="viewpoint_analysis_prompt" - ) - add_tags(tags=["prompt", "generation"], artifact="reflection_prompt") - add_tags( - tags=["prompt", "generation"], artifact="additional_synthesis_prompt" - ) - add_tags( - tags=["prompt", "generation"], artifact="conclusion_generation_prompt" - ) - add_tags( - tags=["prompt", "generation"], artifact="executive_summary_prompt" - ) - add_tags(tags=["prompt", "generation"], artifact="introduction_prompt") + # # add tags to all prompts + # add_tags(tags=["prompt", "search"], artifact_name="search_query_prompt", infer_artifact=True) + + # add_tags( + # tags=["prompt", "generation"], artifact_name="query_decomposition_prompt", infer_artifact=True + # ) + # add_tags(tags=["prompt", "generation"], artifact_name="synthesis_prompt", infer_artifact=True) + # add_tags( + # tags=["prompt", "generation"], artifact_name="viewpoint_analysis_prompt", infer_artifact=True + # ) + # add_tags(tags=["prompt", "generation"], artifact_name="reflection_prompt", infer_artifact=True) + # add_tags( + # tags=["prompt", "generation"], artifact_name="additional_synthesis_prompt", infer_artifact=True + # ) + # add_tags( + # tags=["prompt", "generation"], artifact_name="conclusion_generation_prompt", infer_artifact=True + # ) + # add_tags( + # tags=["prompt", "generation"], artifact_name="executive_summary_prompt", infer_artifact=True + # ) + # add_tags(tags=["prompt", "generation"], artifact_name="introduction_prompt", infer_artifact=True) return ( search_query_prompt, diff --git a/deep_research/steps/iterative_reflection_step.py b/deep_research/steps/iterative_reflection_step.py deleted file mode 100644 index 593c26d9..00000000 --- a/deep_research/steps/iterative_reflection_step.py +++ /dev/null @@ -1,391 +0,0 @@ -import json -import logging -import time -from typing import Annotated - -from materializers.pydantic_materializer import ResearchStateMaterializer -from utils.llm_utils import ( - find_most_relevant_string, - get_structured_llm_output, - is_text_relevant, -) -from utils.prompts import 
ADDITIONAL_SYNTHESIS_PROMPT, REFLECTION_PROMPT -from utils.pydantic_models import ( - ReflectionMetadata, - ResearchState, - SynthesizedInfo, -) -from utils.search_utils import search_and_extract_results -from zenml import add_tags, log_metadata, step - -logger = logging.getLogger(__name__) - - -@step(output_materializers=ResearchStateMaterializer) -def iterative_reflection_step( - state: ResearchState, - max_additional_searches: int = 2, - num_results_per_search: int = 3, - cap_search_length: int = 20000, - llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", - reflection_prompt: str = REFLECTION_PROMPT, - additional_synthesis_prompt: str = ADDITIONAL_SYNTHESIS_PROMPT, -) -> Annotated[ResearchState, "reflected_state"]: - """Perform iterative reflection on the research, identifying gaps and improving it. - - Args: - state: The current research state - max_additional_searches: Maximum number of additional searches to perform - num_results_per_search: Number of results to fetch per search - cap_search_length: Maximum length of content to process from search results - llm_model: The model to use for reflection - reflection_prompt: System prompt for the reflection - additional_synthesis_prompt: System prompt for incorporating additional information - - Returns: - Updated research state with enhanced information and reflection metadata - """ - start_time = time.time() - logger.info("Starting iterative reflection on research") - - # Prepare input for reflection - synthesized_info_dict = { - question: { - "synthesized_answer": info.synthesized_answer, - "key_sources": info.key_sources, - "confidence_level": info.confidence_level, - "information_gaps": info.information_gaps, - } - for question, info in state.synthesized_info.items() - } - - viewpoint_analysis_dict = None - if state.viewpoint_analysis: - # Convert the viewpoint analysis to a dict for the LLM - tension_list = [] - for tension in state.viewpoint_analysis.areas_of_tension: - tension_list.append( - {"topic": tension.topic, "viewpoints": tension.viewpoints} - ) - - viewpoint_analysis_dict = { - "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, - "areas_of_tension": tension_list, - "perspective_gaps": state.viewpoint_analysis.perspective_gaps, - "integrative_insights": state.viewpoint_analysis.integrative_insights, - } - - reflection_input = { - "main_query": state.main_query, - "sub_questions": state.sub_questions, - "synthesized_information": synthesized_info_dict, - } - - if viewpoint_analysis_dict: - reflection_input["viewpoint_analysis"] = viewpoint_analysis_dict - - # Get reflection critique - try: - logger.info(f"Generating self-critique via {llm_model}") - - # Define fallback for reflection result - fallback_reflection = { - "critique": [], - "additional_questions": [], - "recommended_search_queries": [], - } - - # Use utility function to get structured output - reflection_result = get_structured_llm_output( - prompt=json.dumps(reflection_input), - system_prompt=reflection_prompt, - model=llm_model, - fallback_response=fallback_reflection, - ) - - # Make a deep copy of the synthesized info to create enhanced_info - enhanced_info = { - k: SynthesizedInfo( - synthesized_answer=v.synthesized_answer, - key_sources=v.key_sources.copy(), - confidence_level=v.confidence_level, - information_gaps=v.information_gaps, - improvements=v.improvements.copy() - if hasattr(v, "improvements") - else [], - ) - for k, v in state.synthesized_info.items() - } - - # Perform additional searches based on 
recommendations - search_queries = reflection_result.get( - "recommended_search_queries", [] - ) - if max_additional_searches > 0 and search_queries: - # Limit to max_additional_searches - search_queries = search_queries[:max_additional_searches] - - for query in search_queries: - logger.info(f"Performing additional search: {query}") - # Execute the search using the utility function - search_results, search_cost = search_and_extract_results( - query=query, - max_results=num_results_per_search, - cap_content_length=cap_search_length, - ) - - # Extract raw contents - raw_contents = [result.content for result in search_results] - - # Find the most relevant sub-question for this query - most_relevant_question = find_most_relevant_string( - query, state.sub_questions, llm_model - ) - - # Track search costs if using Exa (default provider) - # Note: This step doesn't have a search_provider parameter, so we check the default - from utils.search_utils import SearchEngineConfig - - config = SearchEngineConfig() - if ( - config.default_provider.lower() in ["exa", "both"] - and search_cost > 0 - ): - # Update total costs - state.search_costs["exa"] = ( - state.search_costs.get("exa", 0.0) + search_cost - ) - - # Add detailed cost entry - state.search_cost_details.append( - { - "provider": "exa", - "query": query, - "cost": search_cost, - "timestamp": time.time(), - "step": "iterative_reflection", - "purpose": "gap_filling", - "relevant_question": most_relevant_question, - } - ) - logger.info( - f"Exa search cost for reflection query: ${search_cost:.4f}" - ) - - if ( - most_relevant_question - and most_relevant_question in enhanced_info - ): - # Enhance the synthesis with new information - enhancement_input = { - "original_synthesis": enhanced_info[ - most_relevant_question - ].synthesized_answer, - "new_information": raw_contents, - "critique": [ - item - for item in reflection_result.get("critique", []) - if is_text_relevant( - item.get("issue", ""), most_relevant_question - ) - ], - } - - # Use the utility function for enhancement - enhanced_synthesis = get_structured_llm_output( - prompt=json.dumps(enhancement_input), - system_prompt=additional_synthesis_prompt, - model=llm_model, - fallback_response={ - "enhanced_synthesis": enhanced_info[ - most_relevant_question - ].synthesized_answer, - "improvements_made": [ - "Failed to enhance synthesis" - ], - "remaining_limitations": "Enhancement process failed.", - }, - ) - - if ( - enhanced_synthesis - and "enhanced_synthesis" in enhanced_synthesis - ): - # Update the synthesized answer - enhanced_info[ - most_relevant_question - ].synthesized_answer = enhanced_synthesis[ - "enhanced_synthesis" - ] - - # Add improvements - improvements = enhanced_synthesis.get( - "improvements_made", [] - ) - enhanced_info[ - most_relevant_question - ].improvements.extend(improvements) - - # Add any additional questions as new synthesized entries - for new_question in reflection_result.get("additional_questions", []): - if ( - new_question not in state.sub_questions - and new_question not in enhanced_info - ): - enhanced_info[new_question] = SynthesizedInfo( - synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", - key_sources=[], - confidence_level="low", - information_gaps="This question requires additional research.", - ) - - # Prepare metadata about the reflection process - reflection_metadata = ReflectionMetadata( - critique_summary=[ - item.get("issue", "") - for item in 
reflection_result.get("critique", []) - ], - additional_questions_identified=reflection_result.get( - "additional_questions", [] - ), - searches_performed=search_queries, - improvements_made=sum( - [len(info.improvements) for info in enhanced_info.values()] - ), - ) - - logger.info( - f"Completed iterative reflection with {reflection_metadata.improvements_made} improvements" - ) - - # Update the state with enhanced info and metadata - state.update_after_reflection(enhanced_info, reflection_metadata) - - # Calculate execution time - execution_time = time.time() - start_time - - # Count questions that were enhanced - questions_enhanced = 0 - for question, enhanced in enhanced_info.items(): - if question in state.synthesized_info: - original = state.synthesized_info[question] - if enhanced.synthesized_answer != original.synthesized_answer: - questions_enhanced += 1 - - # Calculate confidence level changes - confidence_improvements = {"improved": 0, "unchanged": 0, "new": 0} - for question, enhanced in enhanced_info.items(): - if question in state.synthesized_info: - original = state.synthesized_info[question] - original_level = original.confidence_level.lower() - enhanced_level = enhanced.confidence_level.lower() - - level_map = {"low": 0, "medium": 1, "high": 2} - if enhanced_level in level_map and original_level in level_map: - if level_map[enhanced_level] > level_map[original_level]: - confidence_improvements["improved"] += 1 - else: - confidence_improvements["unchanged"] += 1 - else: - confidence_improvements["new"] += 1 - - # Log metadata - log_metadata( - metadata={ - "iterative_reflection": { - "execution_time_seconds": execution_time, - "llm_model": llm_model, - "max_additional_searches": max_additional_searches, - "searches_performed": len(search_queries), - "num_critique_points": len( - reflection_result.get("critique", []) - ), - "num_additional_questions": len( - reflection_result.get("additional_questions", []) - ), - "questions_enhanced": questions_enhanced, - "total_improvements": reflection_metadata.improvements_made, - "confidence_improvements": confidence_improvements, - "has_viewpoint_analysis": bool(viewpoint_analysis_dict), - "total_search_cost": state.search_costs.get("exa", 0.0), - } - } - ) - - # Log model metadata for cross-pipeline tracking - log_metadata( - metadata={ - "improvement_metrics": { - "confidence_improvements": confidence_improvements, - "total_improvements": reflection_metadata.improvements_made, - } - }, - infer_model=True, - ) - - # Log artifact metadata - log_metadata( - metadata={ - "enhanced_state_characteristics": { - "total_questions": len(enhanced_info), - "questions_with_improvements": sum( - 1 - for info in enhanced_info.values() - if info.improvements - ), - "high_confidence_count": sum( - 1 - for info in enhanced_info.values() - if info.confidence_level.lower() == "high" - ), - "medium_confidence_count": sum( - 1 - for info in enhanced_info.values() - if info.confidence_level.lower() == "medium" - ), - "low_confidence_count": sum( - 1 - for info in enhanced_info.values() - if info.confidence_level.lower() == "low" - ), - } - }, - infer_artifact=True, - ) - - # Add tags to the artifact - add_tags(tags=["state", "reflected"], artifact="reflected_state") - - return state - - except Exception as e: - logger.error(f"Error during iterative reflection: {e}") - - # Create error metadata - error_metadata = ReflectionMetadata( - error=f"Reflection failed: {str(e)}" - ) - - # Update the state with the original synthesized info as enhanced info - # and 
the error metadata - state.update_after_reflection(state.synthesized_info, error_metadata) - - # Log error metadata - execution_time = time.time() - start_time - log_metadata( - metadata={ - "iterative_reflection": { - "execution_time_seconds": execution_time, - "llm_model": llm_model, - "max_additional_searches": max_additional_searches, - "searches_performed": 0, - "status": "failed", - "error_message": str(e), - } - } - ) - - # Add tags to the artifact - add_tags(tags=["state", "reflected"], artifact="reflected_state") - - return state diff --git a/deep_research/steps/merge_results_step.py b/deep_research/steps/merge_results_step.py index d334c610..4802c98f 100644 --- a/deep_research/steps/merge_results_step.py +++ b/deep_research/steps/merge_results_step.py @@ -1,35 +1,38 @@ -import copy import logging import time -from typing import Annotated +from typing import Annotated, Tuple -from materializers.pydantic_materializer import ResearchStateMaterializer -from utils.pydantic_models import ResearchState -from zenml import add_tags, get_step_context, log_metadata, step +from materializers.search_data_materializer import SearchDataMaterializer +from materializers.synthesis_data_materializer import SynthesisDataMaterializer +from utils.pydantic_models import SearchData, SynthesisData +from zenml import get_step_context, log_metadata, step from zenml.client import Client logger = logging.getLogger(__name__) -@step(output_materializers=ResearchStateMaterializer) +@step( + output_materializers={ + "merged_search_data": SearchDataMaterializer, + "merged_synthesis_data": SynthesisDataMaterializer, + } +) def merge_sub_question_results_step( - original_state: ResearchState, step_prefix: str = "process_question_", - output_name: str = "output", -) -> Annotated[ResearchState, "merged_state"]: +) -> Tuple[ + Annotated[SearchData, "merged_search_data"], + Annotated[SynthesisData, "merged_synthesis_data"], +]: """Merge results from individual sub-question processing steps. This step collects the results from the parallel sub-question processing steps - and combines them into a single, comprehensive state object. + and combines them into single SearchData and SynthesisData artifacts. 
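    Example:
        Illustrative wiring in a pipeline definition, relying on the 'after'
        parameter mentioned in the Note below; the fan-out step names and their
        count are assumptions for this sketch, not taken from this diff:

            merged_search, merged_synthesis = merge_sub_question_results_step(
                step_prefix="process_question_",
                after=[f"process_question_{i}" for i in range(num_questions)],
            )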
Args: - original_state: The original research state with all sub-questions step_prefix: The prefix used in step IDs for the parallel processing steps - output_name: The name of the output artifact from the processing steps Returns: - Annotated[ResearchState, "merged_state"]: A merged ResearchState with combined - results from all sub-questions + Tuple of merged SearchData and SynthesisData artifacts Note: This step is typically configured with the 'after' parameter in the pipeline @@ -38,23 +41,16 @@ def merge_sub_question_results_step( """ start_time = time.time() - # Start with the original state that has all sub-questions - merged_state = copy.deepcopy(original_state) - - # Initialize empty dictionaries for the results - merged_state.search_results = {} - merged_state.synthesized_info = {} - - # Initialize search cost tracking - merged_state.search_costs = {} - merged_state.search_cost_details = [] + # Initialize merged artifacts + merged_search_data = SearchData() + merged_synthesis_data = SynthesisData() # Get pipeline run information to access outputs try: ctx = get_step_context() if not ctx or not ctx.pipeline_run: logger.error("Could not get pipeline run context") - return merged_state + return merged_search_data, merged_synthesis_data run_name = ctx.pipeline_run.name client = Client() @@ -80,77 +76,45 @@ def merge_sub_question_results_step( f"Processing results from step: {step_name} (index: {index})" ) - # Get the output artifact - if output_name in step_info.outputs: - output_artifacts = step_info.outputs[output_name] - if output_artifacts: - output_artifact = output_artifacts[0] - sub_state = output_artifact.load() - - # Check if the sub-state has valid data - if ( - hasattr(sub_state, "sub_questions") - and sub_state.sub_questions - ): - sub_question = sub_state.sub_questions[0] + # Get the search_data artifact + if "search_data" in step_info.outputs: + search_artifacts = step_info.outputs["search_data"] + if search_artifacts: + search_artifact = search_artifacts[0] + sub_search_data = search_artifact.load() + + # Merge search data + merged_search_data.merge(sub_search_data) + + # Track processed questions + for sub_q in sub_search_data.search_results: + processed_questions.add(sub_q) logger.info( - f"Found results for sub-question: {sub_question}" + f"Merged search results for: {sub_q}" ) - parallel_steps_processed += 1 - processed_questions.add(sub_question) - - # Merge search results - if ( - hasattr(sub_state, "search_results") - and sub_question - in sub_state.search_results - ): - merged_state.search_results[ - sub_question - ] = sub_state.search_results[ - sub_question - ] - logger.info( - f"Added search results for: {sub_question}" - ) - - # Merge synthesized info - if ( - hasattr(sub_state, "synthesized_info") - and sub_question - in sub_state.synthesized_info - ): - merged_state.synthesized_info[ - sub_question - ] = sub_state.synthesized_info[ - sub_question - ] - logger.info( - f"Added synthesized info for: {sub_question}" - ) - - # Merge search costs - if hasattr(sub_state, "search_costs"): - for ( - provider, - cost, - ) in sub_state.search_costs.items(): - merged_state.search_costs[ - provider - ] = ( - merged_state.search_costs.get( - provider, 0.0 - ) - + cost - ) - - # Merge search cost details - if hasattr( - sub_state, "search_cost_details" - ): - merged_state.search_cost_details.extend( - sub_state.search_cost_details - ) + + # Get the synthesis_data artifact + if "synthesis_data" in step_info.outputs: + synthesis_artifacts = step_info.outputs[ + 
"synthesis_data" + ] + if synthesis_artifacts: + synthesis_artifact = synthesis_artifacts[0] + sub_synthesis_data = synthesis_artifact.load() + + # Merge synthesis data + merged_synthesis_data.merge(sub_synthesis_data) + + # Track processed questions + for ( + sub_q + ) in sub_synthesis_data.synthesized_info: + logger.info( + f"Merged synthesis info for: {sub_q}" + ) + + parallel_steps_processed += 1 + except (ValueError, IndexError, KeyError, AttributeError) as e: logger.warning(f"Error processing step {step_name}: {e}") continue @@ -164,24 +128,22 @@ def merge_sub_question_results_step( ) # Log search cost summary - if merged_state.search_costs: - total_cost = sum(merged_state.search_costs.values()) + if merged_search_data.search_costs: + total_cost = sum(merged_search_data.search_costs.values()) logger.info( - f"Total search costs merged: ${total_cost:.4f} across {len(merged_state.search_cost_details)} queries" + f"Total search costs merged: ${total_cost:.4f} across {len(merged_search_data.search_cost_details)} queries" ) - for provider, cost in merged_state.search_costs.items(): + for provider, cost in merged_search_data.search_costs.items(): logger.info(f" {provider}: ${cost:.4f}") - # Check for any missing sub-questions - for sub_q in merged_state.sub_questions: - if sub_q not in processed_questions: - logger.warning(f"Missing results for sub-question: {sub_q}") - except Exception as e: logger.error(f"Error during merge step: {e}") # Final check for empty results - if not merged_state.search_results or not merged_state.synthesized_info: + if ( + not merged_search_data.search_results + or not merged_synthesis_data.synthesized_info + ): logger.warning( "No results were found or merged from parallel processing steps!" ) @@ -189,80 +151,81 @@ def merge_sub_question_results_step( # Calculate execution time execution_time = time.time() - start_time - # Calculate metrics - missing_questions = [ - q for q in merged_state.sub_questions if q not in processed_questions - ] - # Count total search results across all questions total_search_results = sum( - len(results) for results in merged_state.search_results.values() + len(results) for results in merged_search_data.search_results.values() ) # Get confidence distribution for merged results confidence_distribution = {"high": 0, "medium": 0, "low": 0} - for info in merged_state.synthesized_info.values(): + for info in merged_synthesis_data.synthesized_info.values(): level = info.confidence_level.lower() if level in confidence_distribution: confidence_distribution[level] += 1 - # Calculate completeness ratio - completeness_ratio = ( - len(processed_questions) / len(merged_state.sub_questions) - if merged_state.sub_questions - else 0 - ) - # Log metadata log_metadata( metadata={ "merge_results": { "execution_time_seconds": execution_time, - "total_sub_questions": len(merged_state.sub_questions), "parallel_steps_processed": parallel_steps_processed, "questions_successfully_merged": len(processed_questions), - "missing_questions_count": len(missing_questions), - "missing_questions": missing_questions[:5] - if missing_questions - else [], # Limit to 5 for metadata "total_search_results": total_search_results, "confidence_distribution": confidence_distribution, "merge_success": bool( - merged_state.search_results - and merged_state.synthesized_info + merged_search_data.search_results + and merged_synthesis_data.synthesized_info + ), + "total_search_costs": merged_search_data.search_costs, + "total_search_queries": len( + 
merged_search_data.search_cost_details + ), + "total_exa_cost": merged_search_data.search_costs.get( + "exa", 0.0 ), - "total_search_costs": merged_state.search_costs, - "total_search_queries": len(merged_state.search_cost_details), - "total_exa_cost": merged_state.search_costs.get("exa", 0.0), } } ) - # Log model metadata for cross-pipeline tracking + # Log artifact metadata log_metadata( metadata={ - "research_quality": { - "completeness_ratio": completeness_ratio, + "search_data_characteristics": { + "total_searches": merged_search_data.total_searches, + "search_results_count": len(merged_search_data.search_results), + "total_cost": sum(merged_search_data.search_costs.values()), } }, - infer_model=True, + artifact_name="merged_search_data", + infer_artifact=True, ) - # Log artifact metadata log_metadata( metadata={ - "merged_state_characteristics": { - "has_search_results": bool(merged_state.search_results), - "has_synthesized_info": bool(merged_state.synthesized_info), - "search_results_count": len(merged_state.search_results), - "synthesized_info_count": len(merged_state.synthesized_info), - "completeness_ratio": completeness_ratio, + "synthesis_data_characteristics": { + "synthesized_info_count": len( + merged_synthesis_data.synthesized_info + ), + "enhanced_info_count": len( + merged_synthesis_data.enhanced_info + ), + "confidence_distribution": confidence_distribution, } }, + artifact_name="merged_synthesis_data", infer_artifact=True, ) - # Add tags to the artifact - add_tags(tags=["state", "merged"], artifact="merged_state") - - return merged_state + # Add tags to the artifacts + # add_tags( + # tags=["search", "merged"], + # artifact_name="merged_search_data", + # infer_artifact=True, + # ) + # add_tags( + # tags=["synthesis", "merged"], + # artifact_name="merged_synthesis_data", + # infer_artifact=True, + # ) + + return merged_search_data, merged_synthesis_data diff --git a/deep_research/steps/process_sub_question_step.py b/deep_research/steps/process_sub_question_step.py index bc1b12ac..7ff14ab1 100644 --- a/deep_research/steps/process_sub_question_step.py +++ b/deep_research/steps/process_sub_question_step.py @@ -1,8 +1,7 @@ -import copy import logging import time import warnings -from typing import Annotated +from typing import Annotated, Tuple # Suppress Pydantic serialization warnings from ZenML artifact metadata # These occur when ZenML stores timestamp metadata as floats but models expect ints @@ -10,21 +9,34 @@ "ignore", message=".*PydanticSerializationUnexpectedValue.*" ) -from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.search_data_materializer import SearchDataMaterializer +from materializers.synthesis_data_materializer import SynthesisDataMaterializer from utils.llm_utils import synthesize_information -from utils.pydantic_models import Prompt, ResearchState, SynthesizedInfo +from utils.pydantic_models import ( + Prompt, + QueryContext, + SearchCostDetail, + SearchData, + SynthesisData, + SynthesizedInfo, +) from utils.search_utils import ( generate_search_query, search_and_extract_results, ) -from zenml import add_tags, log_metadata, step +from zenml import log_metadata, step logger = logging.getLogger(__name__) -@step(output_materializers=ResearchStateMaterializer) +@step( + output_materializers={ + "search_data": SearchDataMaterializer, + "synthesis_data": SynthesisDataMaterializer, + } +) def process_sub_question_step( - state: ResearchState, + query_context: QueryContext, search_query_prompt: Prompt, synthesis_prompt: 
Prompt, question_index: int, @@ -35,14 +47,17 @@ def process_sub_question_step( search_provider: str = "tavily", search_mode: str = "auto", langfuse_project_name: str = "deep-research", -) -> Annotated[ResearchState, "output"]: +) -> Tuple[ + Annotated[SearchData, "search_data"], + Annotated[SynthesisData, "synthesis_data"], +]: """Process a single sub-question if it exists at the given index. This step combines the gathering and synthesis steps for a single sub-question. It's designed to be run in parallel for each sub-question. Args: - state: The original research state with all sub-questions + query_context: The query context with main query and sub-questions search_query_prompt: Prompt for generating search queries synthesis_prompt: Prompt for synthesizing search results question_index: The index of the sub-question to process @@ -52,25 +67,19 @@ def process_sub_question_step( cap_search_length: Maximum length of content to process from search results search_provider: Search provider to use (tavily, exa, or both) search_mode: Search mode for Exa provider (neural, keyword, or auto) + langfuse_project_name: Project name for tracing Returns: - A new ResearchState containing only the processed sub-question's results + Tuple of SearchData and SynthesisData for the processed sub-question """ start_time = time.time() - # Create a copy of the state to avoid modifying the original - sub_state = copy.deepcopy(state) - - # Clear all existing data except the main query - sub_state.search_results = {} - sub_state.synthesized_info = {} - sub_state.enhanced_info = {} - sub_state.viewpoint_analysis = None - sub_state.reflection_metadata = None - sub_state.final_report_html = "" + # Initialize empty artifacts + search_data = SearchData() + synthesis_data = SynthesisData() # Check if this index exists in sub-questions - if question_index >= len(state.sub_questions): + if question_index >= len(query_context.sub_questions): logger.info( f"No sub-question at index {question_index}, skipping processing" ) @@ -81,25 +90,25 @@ def process_sub_question_step( "question_index": question_index, "status": "skipped", "reason": "index_out_of_range", - "total_sub_questions": len(state.sub_questions), + "total_sub_questions": len(query_context.sub_questions), } } ) - # Return an empty state since there's no question to process - sub_state.sub_questions = [] - # Add tags to the artifact - add_tags(tags=["state", "sub-question"], artifact="output") - return sub_state + # Return empty artifacts + # add_tags( + # tags=["search", "synthesis", "skipped"], artifact_name="search_data", infer_artifact=True + # ) + # add_tags( + # tags=["search", "synthesis", "skipped"], artifact_name="synthesis_data", infer_artifact=True + # ) + return search_data, synthesis_data # Get the target sub-question - sub_question = state.sub_questions[question_index] + sub_question = query_context.sub_questions[question_index] logger.info( f"Processing sub-question {question_index + 1}: {sub_question}" ) - # Store only this sub-question in the sub-state - sub_state.sub_questions = [sub_question] - # === INFORMATION GATHERING === search_phase_start = time.time() @@ -126,6 +135,10 @@ def process_sub_question_step( search_mode=search_mode, ) + # Update search data + search_data.search_results[sub_question] = results_list + search_data.total_searches = 1 + # Track search costs if using Exa if ( search_provider @@ -133,29 +146,23 @@ def process_sub_question_step( and search_cost > 0 ): # Update total costs - sub_state.search_costs["exa"] = ( - 
sub_state.search_costs.get("exa", 0.0) + search_cost - ) + search_data.search_costs["exa"] = search_cost # Add detailed cost entry - sub_state.search_cost_details.append( - { - "provider": "exa", - "query": search_query, - "cost": search_cost, - "timestamp": time.time(), - "step": "process_sub_question", - "sub_question": sub_question, - "question_index": question_index, - } + search_data.search_cost_details.append( + SearchCostDetail( + provider="exa", + query=search_query, + cost=search_cost, + timestamp=time.time(), + step="process_sub_question", + sub_question=sub_question, + ) ) logger.info( f"Exa search cost for sub-question {question_index}: ${search_cost:.4f}" ) - search_results = {sub_question: results_list} - sub_state.update_search_results(search_results) - search_phase_time = time.time() - search_phase_start # === INFORMATION SYNTHESIS === @@ -184,23 +191,21 @@ def process_sub_question_step( ) # Create SynthesizedInfo object - synthesized_info = { - sub_question: SynthesizedInfo( - synthesized_answer=synthesis_result.get( - "synthesized_answer", f"Synthesis for '{sub_question}' failed." - ), - key_sources=synthesis_result.get("key_sources", sources[:1]), - confidence_level=synthesis_result.get("confidence_level", "low"), - information_gaps=synthesis_result.get( - "information_gaps", - "Synthesis process encountered technical difficulties.", - ), - improvements=synthesis_result.get("improvements", []), - ) - } + synthesized_info = SynthesizedInfo( + synthesized_answer=synthesis_result.get( + "synthesized_answer", f"Synthesis for '{sub_question}' failed." + ), + key_sources=synthesis_result.get("key_sources", sources[:1]), + confidence_level=synthesis_result.get("confidence_level", "low"), + information_gaps=synthesis_result.get( + "information_gaps", + "Synthesis process encountered technical difficulties.", + ), + improvements=synthesis_result.get("improvements", []), + ) - # Update the state with synthesized information - sub_state.update_synthesized_info(synthesized_info) + # Update synthesis data + synthesis_data.synthesized_info[sub_question] = synthesized_info synthesis_phase_time = time.time() - synthesis_phase_start total_execution_time = time.time() - start_time @@ -270,22 +275,41 @@ def process_sub_question_step( infer_model=True, ) - # Log artifact metadata for the output state + # Log artifact metadata for the output artifacts log_metadata( metadata={ - "sub_state_characteristics": { - "has_search_results": bool(sub_state.search_results), - "has_synthesized_info": bool(sub_state.synthesized_info), - "sub_question_processed": sub_question, + "search_data_characteristics": { + "sub_question": sub_question, + "num_results": len(results_list), + "search_provider": search_provider, + "search_cost": search_cost if search_cost > 0 else None, + } + }, + artifact_name="search_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + "synthesis_data_characteristics": { + "sub_question": sub_question, "confidence_level": synthesis_result.get( "confidence_level", "low" ), + "has_information_gaps": bool( + synthesis_result.get("information_gaps") + ), + "num_key_sources": len( + synthesis_result.get("key_sources", []) + ), } }, + artifact_name="synthesis_data", infer_artifact=True, ) - # Add tags to the artifact - add_tags(tags=["state", "sub-question"], artifact="output") + # Add tags to the artifacts + # add_tags(tags=["search", "sub-question"], artifact_name="search_data", infer_artifact=True) + # add_tags(tags=["synthesis", "sub-question"], 
artifact_name="synthesis_data", infer_artifact=True) - return sub_state + return search_data, synthesis_data diff --git a/deep_research/steps/pydantic_final_report_step.py b/deep_research/steps/pydantic_final_report_step.py index d61e848b..88cf2221 100644 --- a/deep_research/steps/pydantic_final_report_step.py +++ b/deep_research/steps/pydantic_final_report_step.py @@ -1,7 +1,7 @@ -"""Final report generation step using Pydantic models and materializers. +"""Final report generation step using artifact-based approach. This module provides a ZenML pipeline step for generating the final HTML research report -using Pydantic models and improved materializers. +using the new artifact-based approach. """ import html @@ -11,7 +11,7 @@ import time from typing import Annotated, Tuple -from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.final_report_materializer import FinalReportMaterializer from utils.helper_functions import ( extract_html_from_content, remove_reasoning_from_output, @@ -22,8 +22,15 @@ SUB_QUESTION_TEMPLATE, VIEWPOINT_ANALYSIS_TEMPLATE, ) -from utils.pydantic_models import Prompt, ResearchState -from zenml import add_tags, log_metadata, step +from utils.pydantic_models import ( + AnalysisData, + FinalReport, + Prompt, + QueryContext, + SearchData, + SynthesisData, +) +from zenml import log_metadata, step from zenml.types import HTMLString logger = logging.getLogger(__name__) @@ -99,53 +106,84 @@ def format_text_with_code_blocks(text: str) -> str: if not text: return "" - # First escape HTML - escaped_text = html.escape(text) - - # Handle code blocks (wrap content in ``` or ```) - pattern = r"```(?:\w*\n)?(.*?)```" - - def code_block_replace(match): - code_content = match.group(1) - # Strip extra newlines at beginning and end - code_content = code_content.strip("\n") - return f"
<pre><code>{code_content}</code></pre>"
-
-    # Replace code blocks
-    formatted_text = re.sub(
-        pattern, code_block_replace, escaped_text, flags=re.DOTALL
-    )
+    # Handle code blocks
+    lines = text.split("\n")
+    formatted_lines = []
+    in_code_block = False
+    code_language = ""
+    code_lines = []
+
+    for line in lines:
+        # Check for code block start
+        if line.strip().startswith("```"):
+            if in_code_block:
+                # End of code block
+                code_content = "\n".join(code_lines)
+                formatted_lines.append(
+                    f'<pre><code class="language-{code_language}">{html.escape(code_content)}</code></pre>'
+                )
+                code_lines = []
+                in_code_block = False
+                code_language = ""
+            else:
+                # Start of code block
+                in_code_block = True
+                # Extract language if specified
+                code_language = line.strip()[3:].strip() or "plaintext"
+        elif in_code_block:
+            code_lines.append(line)
+        else:
+            # Process inline code
+            line = re.sub(r"`([^`]+)`", r"<code>\1</code>", html.escape(line))
+            # Process bullet points
+            if line.strip().startswith("•") or line.strip().startswith("-"):
+                line = re.sub(r"^(\s*)[•-]\s*", r"\1", line)
+                formatted_lines.append(f"<li>{line.strip()}</li>")
+            elif line.strip():
+                formatted_lines.append(f"<p>{line}</p>")
+
+    # Handle case where code block wasn't closed
+    if in_code_block and code_lines:
+        code_content = "\n".join(code_lines)
+        formatted_lines.append(
+            f'<pre><code class="language-{code_language}">{html.escape(code_content)}</code></pre>'
+        )
-    # Convert regular newlines to <br> tags (but not inside <pre> blocks)
-    parts = []
-    in_pre = False
-    for line in formatted_text.split("\n"):
-        if "<pre>" in line:
-            in_pre = True
-            parts.append(line)
-        elif "</pre>" in line:
-            in_pre = False
-            parts.append(line)
-        elif in_pre:
-            # Inside a code block, preserve newlines
-            parts.append(line)
+    # Wrap list items in ul tags
+    result = []
+    in_list = False
+    for line in formatted_lines:
+        if line.startswith("<li>"):
+            if not in_list:
+                result.append("<ul>")
+                in_list = True
+            result.append(line)
+        else:
+            if in_list:
+                result.append("</ul>")
+                in_list = False
+            result.append(line)
+
+    if in_list:
+        result.append("</ul>")
-    return "".join(parts)
+    return "\n".join(result)
 
 
 def generate_executive_summary(
-    state: ResearchState,
+    query_context: QueryContext,
+    synthesis_data: SynthesisData,
+    analysis_data: AnalysisData,
     executive_summary_prompt: Prompt,
     llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B",
     langfuse_project_name: str = "deep-research",
 ) -> str:
-    """Generate an executive summary using LLM based on research findings.
+    """Generate an executive summary using LLM based on the complete research findings.
 
     Args:
-        state: The current research state
+        query_context: The query context with main query and sub-questions
+        synthesis_data: The synthesis data with all synthesized information
+        analysis_data: The analysis data with viewpoint analysis
         executive_summary_prompt: Prompt for generating executive summary
         llm_model: The model to use for generation
         langfuse_project_name: Name of the Langfuse project for tracking
@@ -155,45 +193,48 @@ def generate_executive_summary(
     """
     logger.info("Generating executive summary using LLM")
 
-    # Prepare the context with all research findings
-    context = f"Main Research Query: {state.main_query}\n\n"
-
-    # Add synthesized findings for each sub-question
-    for i, sub_question in enumerate(state.sub_questions, 1):
-        info = state.enhanced_info.get(
-            sub_question
-        ) or state.synthesized_info.get(sub_question)
-        if info:
-            context += f"Sub-question {i}: {sub_question}\n"
-            context += f"Answer Summary: {info.synthesized_answer[:500]}...\n"
-            context += f"Confidence: {info.confidence_level}\n"
-            context += f"Key Sources: {', '.join(info.key_sources[:3]) if info.key_sources else 'N/A'}\n\n"
-
-    # Add viewpoint analysis insights if available
-    if state.viewpoint_analysis:
-        context += "Key Areas of Agreement:\n"
-        for agreement in state.viewpoint_analysis.main_points_of_agreement[:3]:
-            context += f"- {agreement}\n"
-        context += "\nKey Tensions:\n"
-        for tension in state.viewpoint_analysis.areas_of_tension[:2]:
-            context += f"- {tension.topic}\n"
-
-    # Use the executive summary prompt
-    try:
-        executive_summary_prompt_str = str(executive_summary_prompt)
-        logger.info("Successfully retrieved executive_summary_prompt")
-    except Exception as e:
-        logger.error(f"Failed to get executive_summary_prompt: {e}")
-        return generate_fallback_executive_summary(state)
+    # Prepare the context
+    summary_input = {
+        "main_query": query_context.main_query,
+        "sub_questions": query_context.sub_questions,
+        "key_findings": {},
+        "viewpoint_analysis": None,
+    }
+
+    # Include key findings from synthesis data
+    # Prefer enhanced info if available
+    info_source = (
+        synthesis_data.enhanced_info
+        if synthesis_data.enhanced_info
+        else synthesis_data.synthesized_info
+    )
+
+    for question in query_context.sub_questions:
+        if question in info_source:
+            info = info_source[question]
+            summary_input["key_findings"][question] = {
+                "answer": info.synthesized_answer,
+                "confidence": info.confidence_level,
+                "gaps": info.information_gaps,
+            }
+
+    # Include viewpoint analysis if available
+    if analysis_data.viewpoint_analysis:
+        va = analysis_data.viewpoint_analysis
+        summary_input["viewpoint_analysis"] = {
+            "agreements": va.main_points_of_agreement,
+            "tensions": len(va.areas_of_tension),
+            "insights": va.integrative_insights,
+        }
 
     try:
         # Call LLM to generate executive summary
         result = run_llm_completion(
-            prompt=context,
-            system_prompt=executive_summary_prompt_str,
+
prompt=json.dumps(summary_input), + system_prompt=str(executive_summary_prompt), model=llm_model, temperature=0.7, - max_tokens=800, + max_tokens=600, project=langfuse_project_name, tags=["executive_summary_generation"], ) @@ -206,15 +247,19 @@ def generate_executive_summary( return content else: logger.warning("Failed to generate executive summary via LLM") - return generate_fallback_executive_summary(state) + return generate_fallback_executive_summary( + query_context, synthesis_data + ) except Exception as e: logger.error(f"Error generating executive summary: {e}") - return generate_fallback_executive_summary(state) + return generate_fallback_executive_summary( + query_context, synthesis_data + ) def generate_introduction( - state: ResearchState, + query_context: QueryContext, introduction_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", @@ -222,7 +267,7 @@ def generate_introduction( """Generate an introduction using LLM based on research query and sub-questions. Args: - state: The current research state + query_context: The query context with main query and sub-questions introduction_prompt: Prompt for generating introduction llm_model: The model to use for generation langfuse_project_name: Name of the Langfuse project for tracking @@ -233,24 +278,16 @@ def generate_introduction( logger.info("Generating introduction using LLM") # Prepare the context - context = f"Main Research Query: {state.main_query}\n\n" + context = f"Main Research Query: {query_context.main_query}\n\n" context += "Sub-questions being explored:\n" - for i, sub_question in enumerate(state.sub_questions, 1): + for i, sub_question in enumerate(query_context.sub_questions, 1): context += f"{i}. {sub_question}\n" - # Get the introduction prompt - try: - introduction_prompt_str = str(introduction_prompt) - logger.info("Successfully retrieved introduction_prompt") - except Exception as e: - logger.error(f"Failed to get introduction_prompt: {e}") - return generate_fallback_introduction(state) - try: # Call LLM to generate introduction result = run_llm_completion( prompt=context, - system_prompt=introduction_prompt_str, + system_prompt=str(introduction_prompt), model=llm_model, temperature=0.7, max_tokens=600, @@ -266,22 +303,29 @@ def generate_introduction( return content else: logger.warning("Failed to generate introduction via LLM") - return generate_fallback_introduction(state) + return generate_fallback_introduction(query_context) except Exception as e: logger.error(f"Error generating introduction: {e}") - return generate_fallback_introduction(state) + return generate_fallback_introduction(query_context) -def generate_fallback_executive_summary(state: ResearchState) -> str: +def generate_fallback_executive_summary( + query_context: QueryContext, synthesis_data: SynthesisData +) -> str: """Generate a fallback executive summary when LLM fails.""" - summary = f"

<p>This report examines the question: {html.escape(state.main_query)}</p>" - summary += f"<p>

    The research explored {len(state.sub_questions)} key dimensions of this topic, " + summary = f"

<p>This report examines the question: {html.escape(query_context.main_query)}</p>" + summary += f"<p>

The research explored {len(query_context.sub_questions)} key dimensions of this topic, " summary += "synthesizing findings from multiple sources to provide a comprehensive analysis.</p>

    " # Add confidence overview confidence_counts = {"high": 0, "medium": 0, "low": 0} - for info in state.enhanced_info.values(): + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) + for info in info_source.values(): level = info.confidence_level.lower() if level in confidence_counts: confidence_counts[level] += 1 @@ -292,10 +336,10 @@ def generate_fallback_executive_summary(state: ResearchState) -> str: return summary -def generate_fallback_introduction(state: ResearchState) -> str: +def generate_fallback_introduction(query_context: QueryContext) -> str: """Generate a fallback introduction when LLM fails.""" - intro = f"

<p>This report addresses the research query: {html.escape(state.main_query)}</p>" - intro += f"<p>

    The research was conducted by breaking down the main query into {len(state.sub_questions)} " + intro = f"

<p>This report addresses the research query: {html.escape(query_context.main_query)}</p>" + intro += f"<p>

    The research was conducted by breaking down the main query into {len(query_context.sub_questions)} " intro += ( "sub-questions to explore different aspects of the topic in depth. " ) @@ -304,7 +348,9 @@ def generate_fallback_introduction(state: ResearchState) -> str: def generate_conclusion( - state: ResearchState, + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, conclusion_generation_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", @@ -312,9 +358,12 @@ def generate_conclusion( """Generate a comprehensive conclusion using LLM based on all research findings. Args: - state: The ResearchState containing all research findings + query_context: The query context with main query and sub-questions + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis conclusion_generation_prompt: Prompt for generating conclusion llm_model: The model to use for conclusion generation + langfuse_project_name: Name of the Langfuse project for tracking Returns: str: HTML-formatted conclusion content @@ -323,15 +372,21 @@ def generate_conclusion( # Prepare input data for conclusion generation conclusion_input = { - "main_query": state.main_query, - "sub_questions": state.sub_questions, + "main_query": query_context.main_query, + "sub_questions": query_context.sub_questions, "enhanced_info": {}, } # Include enhanced information for each sub-question - for question in state.sub_questions: - if question in state.enhanced_info: - info = state.enhanced_info[question] + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) + + for question in query_context.sub_questions: + if question in info_source: + info = info_source[question] conclusion_input["enhanced_info"][question] = { "synthesized_answer": info.synthesized_answer, "confidence_level": info.confidence_level, @@ -339,93 +394,99 @@ def generate_conclusion( "key_sources": info.key_sources, "improvements": getattr(info, "improvements", []), } - elif question in state.synthesized_info: - # Fallback to synthesized info if enhanced info not available - info = state.synthesized_info[question] - conclusion_input["enhanced_info"][question] = { - "synthesized_answer": info.synthesized_answer, - "confidence_level": info.confidence_level, - "information_gaps": info.information_gaps, - "key_sources": info.key_sources, - "improvements": [], - } - # Include viewpoint analysis if available - if state.viewpoint_analysis: + # Include viewpoint analysis + if analysis_data.viewpoint_analysis: + va = analysis_data.viewpoint_analysis conclusion_input["viewpoint_analysis"] = { - "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, + "main_points_of_agreement": va.main_points_of_agreement, "areas_of_tension": [ - {"topic": tension.topic, "viewpoints": tension.viewpoints} - for tension in state.viewpoint_analysis.areas_of_tension + {"topic": t.topic, "viewpoints": t.viewpoints} + for t in va.areas_of_tension ], - "perspective_gaps": state.viewpoint_analysis.perspective_gaps, - "integrative_insights": state.viewpoint_analysis.integrative_insights, + "integrative_insights": va.integrative_insights, } # Include reflection metadata if available - if state.reflection_metadata: - conclusion_input["reflection_metadata"] = { - "critique_summary": state.reflection_metadata.critique_summary, - 
"additional_questions_identified": state.reflection_metadata.additional_questions_identified, - "improvements_made": state.reflection_metadata.improvements_made, + if analysis_data.reflection_metadata: + rm = analysis_data.reflection_metadata + conclusion_input["reflection_insights"] = { + "improvements_made": rm.improvements_made, + "additional_questions_identified": rm.additional_questions_identified, } try: - # Use the conclusion generation prompt - conclusion_prompt_str = str(conclusion_generation_prompt) - - # Generate conclusion using LLM - conclusion_html = run_llm_completion( - prompt=json.dumps(conclusion_input, indent=2), - system_prompt=conclusion_prompt_str, + # Call LLM to generate conclusion + result = run_llm_completion( + prompt=json.dumps(conclusion_input), + system_prompt=str(conclusion_generation_prompt), model=llm_model, - clean_output=True, - max_tokens=1500, # Sufficient for comprehensive conclusion + temperature=0.7, + max_tokens=800, project=langfuse_project_name, + tags=["conclusion_generation"], ) - # Clean up any formatting issues - conclusion_html = conclusion_html.strip() + if result: + content = remove_reasoning_from_output(result) + # Clean up the HTML + content = extract_html_from_content(content) + logger.info("Successfully generated LLM-based conclusion") + return content + else: + logger.warning("Failed to generate conclusion via LLM") + return generate_fallback_conclusion(query_context, synthesis_data) + + except Exception as e: + logger.error(f"Error generating conclusion: {e}") + return generate_fallback_conclusion(query_context, synthesis_data) - # Remove any h2 tags with "Conclusion" text that LLM might have added - # Since we already have a Conclusion header in the template - conclusion_html = re.sub( - r"]*>\s*Conclusion\s*\s*", - "", - conclusion_html, - flags=re.IGNORECASE, - ) - conclusion_html = re.sub( - r"]*>\s*Conclusion\s*\s*", - "", - conclusion_html, - flags=re.IGNORECASE, - ) - # Also remove plain text "Conclusion" at the start if it exists - conclusion_html = re.sub( - r"^Conclusion\s*\n*", - "", - conclusion_html.strip(), - flags=re.IGNORECASE, - ) +def generate_fallback_conclusion( + query_context: QueryContext, synthesis_data: SynthesisData +) -> str: + """Generate a fallback conclusion when LLM fails. + + Args: + query_context: The query context with main query and sub-questions + synthesis_data: The synthesis data with all synthesized information + + Returns: + str: Basic HTML-formatted conclusion + """ + conclusion = f"

<p>This research has explored the question: {html.escape(query_context.main_query)}</p>" + conclusion += f"<p>

Through systematic investigation of {len(query_context.sub_questions)} sub-questions, " + conclusion += ( + "we have gathered insights from multiple sources and perspectives.</p>

    " + ) - if not conclusion_html.startswith("

    "): - # Wrap in paragraph tags if not already formatted - conclusion_html = f"

{conclusion_html}</p>
    " + # Add a summary of confidence levels + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) + high_confidence = sum( + 1 + for info in info_source.values() + if info.confidence_level.lower() == "high" + ) - logger.info("Successfully generated LLM-based conclusion") - return conclusion_html + if high_confidence > 0: + conclusion += f"

<p>The research yielded {high_confidence} high-confidence findings out of " + conclusion += f"{len(info_source)} total areas investigated.</p>

    " - except Exception as e: - logger.warning(f"Failed to generate LLM conclusion: {e}") - # Return a basic fallback conclusion - return f"""

<p>This report has explored {html.escape(state.main_query)} through a structured research approach, examining {len(state.sub_questions)} focused sub-questions and synthesizing information from diverse sources. The findings provide a comprehensive understanding of the topic, highlighting key aspects, perspectives, and current knowledge.</p>

    -

<p>While some information gaps remain, as noted in the respective sections, this research provides a solid foundation for understanding the topic and its implications.</p>

    """ + conclusion += "

<p>Further research may be beneficial to address remaining information gaps " + conclusion += "and explore emerging questions identified during this investigation.</p>

    " + + return conclusion def generate_report_from_template( - state: ResearchState, + query_context: QueryContext, + search_data: SearchData, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, conclusion_generation_prompt: Prompt, executive_summary_prompt: Prompt, introduction_prompt: Prompt, @@ -435,25 +496,29 @@ def generate_report_from_template( """Generate a final HTML report from a static template. Instead of using an LLM to generate HTML, this function uses predefined HTML - templates and populates them with data from the research state. + templates and populates them with data from the research artifacts. Args: - state: The current research state + query_context: The query context with main query and sub-questions + search_data: The search data (for source information) + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis conclusion_generation_prompt: Prompt for generating conclusion executive_summary_prompt: Prompt for generating executive summary introduction_prompt: Prompt for generating introduction llm_model: The model to use for conclusion generation + langfuse_project_name: Name of the Langfuse project for tracking Returns: str: The HTML content of the report """ logger.info( - f"Generating templated HTML report for query: {state.main_query}" + f"Generating templated HTML report for query: {query_context.main_query}" ) # Generate table of contents for sub-questions sub_questions_toc = "" - for i, question in enumerate(state.sub_questions, 1): + for i, question in enumerate(query_context.sub_questions, 1): safe_id = f"question-{i}" sub_questions_toc += ( f'
  • {html.escape(question)}
  • \n' @@ -461,7 +526,7 @@ def generate_report_from_template( # Add viewpoint analysis to TOC if available additional_sections_toc = "" - if state.viewpoint_analysis: + if analysis_data.viewpoint_analysis: additional_sections_toc += ( '
  • Viewpoint Analysis
  • \n' ) @@ -470,11 +535,41 @@ def generate_report_from_template( sub_questions_html = "" all_sources = set() - for i, question in enumerate(state.sub_questions, 1): - info = state.enhanced_info.get(question, None) + # Determine which info source to use (merge original with enhanced) + # Start with the original synthesized info + info_source = synthesis_data.synthesized_info.copy() + + # Override with enhanced info where available + if synthesis_data.enhanced_info: + info_source.update(synthesis_data.enhanced_info) + + # Debug logging + logger.info( + f"Synthesis data has enhanced_info: {bool(synthesis_data.enhanced_info)}" + ) + logger.info( + f"Synthesis data has synthesized_info: {bool(synthesis_data.synthesized_info)}" + ) + logger.info(f"Info source has {len(info_source)} entries") + logger.info(f"Processing {len(query_context.sub_questions)} sub-questions") + + # Log the keys in info_source for debugging + if info_source: + logger.info( + f"Keys in info_source: {list(info_source.keys())[:3]}..." + ) # First 3 keys + logger.info( + f"Sub-questions from query_context: {query_context.sub_questions[:3]}..." + ) # First 3 + + for i, question in enumerate(query_context.sub_questions, 1): + info = info_source.get(question, None) # Skip if no information is available if not info: + logger.warning( + f"No synthesis info found for question {i}: {question}" + ) continue # Process confidence level @@ -527,65 +622,73 @@ def generate_report_from_template( confidence_upper=confidence_upper, confidence_icon=confidence_icon, answer=format_text_with_code_blocks(info.synthesized_answer), - info_gaps_html=info_gaps_html, key_sources_html=key_sources_html, + info_gaps_html=info_gaps_html, ) sub_questions_html += sub_question_html # Generate viewpoint analysis HTML if available viewpoint_analysis_html = "" - if state.viewpoint_analysis: - # Format points of agreement - agreements_html = "" - for point in state.viewpoint_analysis.main_points_of_agreement: - agreements_html += f"
  • {html.escape(point)}
  • \n" - - # Format areas of tension + if analysis_data.viewpoint_analysis: + va = analysis_data.viewpoint_analysis + # Format tensions tensions_html = "" - for tension in state.viewpoint_analysis.areas_of_tension: - viewpoints_html = "" - for title, content in tension.viewpoints.items(): - # Create category-specific styling - category_class = f"category-{title.lower()}" - category_title = title.capitalize() - - viewpoints_html += f""" -
    - {category_title} -

    {html.escape(content)}

    -
    - """ - + for tension in va.areas_of_tension: + viewpoints_list = "\n".join( + [ + f"
  • {html.escape(viewpoint)}: {html.escape(description)}
  • " + for viewpoint, description in tension.viewpoints.items() + ] + ) tensions_html += f""" -
    +

    {html.escape(tension.topic)}

    -
    - {viewpoints_html} -
    +
      + {viewpoints_list} +
    """ - # Format the viewpoint analysis section using the template + # Format agreements (just the list items) + agreements_html = "" + if va.main_points_of_agreement: + agreements_html = "\n".join( + [ + f"
  • {html.escape(point)}
  • " + for point in va.main_points_of_agreement + ] + ) + + # Get perspective gaps if available + perspective_gaps = "" + if hasattr(va, "perspective_gaps") and va.perspective_gaps: + perspective_gaps = va.perspective_gaps + else: + perspective_gaps = "No significant perspective gaps identified." + + # Get integrative insights + integrative_insights = "" + if va.integrative_insights: + integrative_insights = format_text_with_code_blocks( + va.integrative_insights + ) + viewpoint_analysis_html = VIEWPOINT_ANALYSIS_TEMPLATE.format( agreements_html=agreements_html, tensions_html=tensions_html, - perspective_gaps=format_text_with_code_blocks( - state.viewpoint_analysis.perspective_gaps - ), - integrative_insights=format_text_with_code_blocks( - state.viewpoint_analysis.integrative_insights - ), + perspective_gaps=perspective_gaps, + integrative_insights=integrative_insights, ) - # Generate references HTML - references_html = "
    - """ - - html += """ - - """ - else: - html += f""" -
    -

    {i + 1}. {sub_question}

    -

    No information available for this question.

    -
    - """ - - # Add conclusion section html += """ + +

    Conclusion

    -

    This report has explored the research query through multiple sub-questions, providing synthesized information based on available sources. While limitations exist in some areas, the report provides a structured analysis of the topic.

    +

    This research has provided insights into the various aspects of the main query through systematic investigation.

    +

    The findings represent a synthesis of available information, with varying levels of confidence across different areas.

    - """ - - # Add sources if available - sources_set = set() - for info in state.enhanced_info.values(): - if info.key_sources: - sources_set.update(info.key_sources) - - if sources_set: - html += """ -
    -

    References

    - -
    - """ - else: - html += """ -
    + +

    References

    -

    No references available.

    +

    Sources were gathered from various search providers and synthesized to create this report.

    - """ - - # Close the HTML structure - html += """
    - - """ +""" return html @step( output_materializers={ - "state": ResearchStateMaterializer, + "final_report": FinalReportMaterializer, } ) def pydantic_final_report_step( - state: ResearchState, + query_context: QueryContext, + search_data: SearchData, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, conclusion_generation_prompt: Prompt, executive_summary_prompt: Prompt, introduction_prompt: Prompt, @@ -991,34 +964,41 @@ def pydantic_final_report_step( llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", ) -> Tuple[ - Annotated[ResearchState, "state"], + Annotated[FinalReport, "final_report"], Annotated[HTMLString, "report_html"], ]: - """Generate the final research report in HTML format using Pydantic models. + """Generate the final research report in HTML format using artifact-based approach. - This step uses the Pydantic models and materializers to generate a final - HTML report and return both the updated state and the HTML report as - separate artifacts. + This step uses the individual artifacts to generate a final HTML report. Args: - state: The current research state (Pydantic model) + query_context: The query context with main query and sub-questions + search_data: The search data (for source information) + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis and reflection metadata conclusion_generation_prompt: Prompt for generating conclusions executive_summary_prompt: Prompt for generating executive summary introduction_prompt: Prompt for generating introduction use_static_template: Whether to use a static template instead of LLM generation llm_model: The model to use for report generation with provider prefix + langfuse_project_name: Name of the Langfuse project for tracking Returns: - A tuple containing the updated research state and the HTML report + A tuple containing the FinalReport artifact and the HTML report string """ start_time = time.time() - logger.info("Generating final research report using Pydantic models") + logger.info( + "Generating final research report using artifact-based approach" + ) if use_static_template: # Use the static HTML template approach logger.info("Using static HTML template for report generation") html_content = generate_report_from_template( - state, + query_context, + search_data, + synthesis_data, + analysis_data, conclusion_generation_prompt, executive_summary_prompt, introduction_prompt, @@ -1026,232 +1006,101 @@ def pydantic_final_report_step( langfuse_project_name, ) - # Update the state with the final report HTML - state.set_final_report(html_content) + # Create the FinalReport artifact + final_report = FinalReport( + report_html=html_content, + main_query=query_context.main_query, + ) - # Collect metadata about the report + # Calculate execution time execution_time = time.time() - start_time - # Count sources - all_sources = set() - for info in state.enhanced_info.values(): - if info.key_sources: - all_sources.update(info.key_sources) - - # Count confidence levels + # Calculate report metrics + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) confidence_distribution = {"high": 0, "medium": 0, "low": 0} - for info in state.enhanced_info.values(): + for info in info_source.values(): level = info.confidence_level.lower() if level in confidence_distribution: confidence_distribution[level] += 1 - # Log metadata - 
log_metadata( - metadata={ - "report_generation": { - "execution_time_seconds": execution_time, - "generation_method": "static_template", - "llm_model": llm_model, - "report_length_chars": len(html_content), - "num_sub_questions": len(state.sub_questions), - "num_sources": len(all_sources), - "has_viewpoint_analysis": bool(state.viewpoint_analysis), - "has_reflection": bool(state.reflection_metadata), - "confidence_distribution": confidence_distribution, - "fallback_report": False, - } - } - ) - - # Log model metadata for cross-pipeline tracking - log_metadata( - metadata={ - "research_quality": { - "confidence_distribution": confidence_distribution, - } - }, - infer_model=True, - ) - - # Log artifact metadata for the HTML report - log_metadata( - metadata={ - "html_report_characteristics": { - "size_bytes": len(html_content.encode("utf-8")), - "has_toc": "toc" in html_content.lower(), - "has_executive_summary": "executive summary" - in html_content.lower(), - "has_conclusion": "conclusion" in html_content.lower(), - "has_references": "references" in html_content.lower(), - } - }, - infer_artifact=True, - artifact_name="report_html", - ) - - logger.info( - "Final research report generated successfully with static template" + # Count various elements in the report + num_sources = len( + set( + source + for info in info_source.values() + for source in info.key_sources + ) ) - # Add tags to the artifacts - add_tags(tags=["state", "final"], artifact="state") - add_tags(tags=["report", "html"], artifact="report_html") - return state, HTMLString(html_content) - - # Otherwise use the LLM-generated approach - # Convert Pydantic model to dict for LLM input - report_input = { - "main_query": state.main_query, - "sub_questions": state.sub_questions, - "synthesized_information": state.enhanced_info, - } - - if state.viewpoint_analysis: - report_input["viewpoint_analysis"] = state.viewpoint_analysis - - if state.reflection_metadata: - report_input["reflection_metadata"] = state.reflection_metadata - - # Generate the report - try: - logger.info(f"Calling {llm_model} to generate final report") - - # Use a default report generation prompt - report_prompt = "Generate a comprehensive HTML research report based on the provided research data. Include proper HTML structure with sections for executive summary, introduction, findings, and conclusion." 
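The confidence bookkeeping above (prefer enhanced_info over synthesized_info, then tally confidence levels) recurs in this step and in the fallback generators earlier in the file. A minimal sketch of that pattern pulled into a helper; summarize_confidence is a hypothetical name and not part of this changeset, only the SynthesisData fields it reads come from the code above:

from collections import Counter
from typing import Dict

def summarize_confidence(synthesis_data) -> Dict[str, int]:
    """Tally high/medium/low confidence answers (illustrative sketch only)."""
    # Prefer reflection-enhanced answers when present, mirroring the step above.
    info_source = (
        synthesis_data.enhanced_info
        if synthesis_data.enhanced_info
        else synthesis_data.synthesized_info
    )
    counts = Counter(
        info.confidence_level.lower() for info in info_source.values()
    )
    return {level: counts.get(level, 0) for level in ("high", "medium", "low")}

The result has the same shape as the confidence_distribution dictionary logged via log_metadata above.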
- - # Use the utility function to run LLM completion - html_content = run_llm_completion( - prompt=json.dumps(report_input), - system_prompt=report_prompt, - model=llm_model, - clean_output=False, # Don't clean in case of breaking HTML formatting - max_tokens=4000, # Increased token limit for detailed report generation - project=langfuse_project_name, + has_viewpoint_analysis = analysis_data.viewpoint_analysis is not None + has_reflection_insights = ( + analysis_data.reflection_metadata is not None + and analysis_data.reflection_metadata.improvements_made > 0 ) - # Clean up any JSON wrapper or other artifacts - html_content = remove_reasoning_from_output(html_content) - - # Process the HTML content to remove code block markers and fix common issues - html_content = clean_html_output(html_content) - - # Basic validation of HTML content - if not html_content.strip().startswith("<"): - logger.warning( - "Generated content does not appear to be valid HTML" - ) - # Try to extract HTML if it might be wrapped in code blocks or JSON - html_content = extract_html_from_content(html_content) - - # Update the state with the final report HTML - state.set_final_report(html_content) - - # Collect metadata about the report - execution_time = time.time() - start_time - - # Count sources - all_sources = set() - for info in state.enhanced_info.values(): - if info.key_sources: - all_sources.update(info.key_sources) - - # Count confidence levels - confidence_distribution = {"high": 0, "medium": 0, "low": 0} - for info in state.enhanced_info.values(): - level = info.confidence_level.lower() - if level in confidence_distribution: - confidence_distribution[level] += 1 - - # Log metadata + # Log step metadata log_metadata( metadata={ - "report_generation": { + "final_report_generation": { "execution_time_seconds": execution_time, - "generation_method": "llm_generated", + "use_static_template": use_static_template, "llm_model": llm_model, - "report_length_chars": len(html_content), - "num_sub_questions": len(state.sub_questions), - "num_sources": len(all_sources), - "has_viewpoint_analysis": bool(state.viewpoint_analysis), - "has_reflection": bool(state.reflection_metadata), + "main_query_length": len(query_context.main_query), + "num_sub_questions": len(query_context.sub_questions), + "num_synthesized_answers": len(info_source), + "has_enhanced_info": bool(synthesis_data.enhanced_info), "confidence_distribution": confidence_distribution, - "fallback_report": False, + "num_unique_sources": num_sources, + "has_viewpoint_analysis": has_viewpoint_analysis, + "has_reflection_insights": has_reflection_insights, + "report_length_chars": len(html_content), + "report_generation_success": True, } } ) - # Log model metadata for cross-pipeline tracking + # Log artifact metadata log_metadata( metadata={ - "research_quality": { - "confidence_distribution": confidence_distribution, + "final_report_characteristics": { + "report_length": len(html_content), + "main_query": query_context.main_query, + "num_sections": len(query_context.sub_questions) + + (1 if has_viewpoint_analysis else 0), + "has_executive_summary": True, + "has_introduction": True, + "has_conclusion": True, } }, - infer_model=True, + artifact_name="final_report", + infer_artifact=True, ) - logger.info("Final research report generated successfully") - # Add tags to the artifacts - add_tags(tags=["state", "final"], artifact="state") - add_tags(tags=["report", "html"], artifact="report_html") - return state, HTMLString(html_content) - - except Exception as e: - 
logger.error(f"Error generating final report: {e}") - # Generate a minimal fallback report - fallback_html = _generate_fallback_report(state) - - # Process the fallback HTML to ensure it's clean - fallback_html = clean_html_output(fallback_html) - - # Update the state with the fallback report - state.set_final_report(fallback_html) - - # Collect metadata about the fallback report - execution_time = time.time() - start_time - - # Count sources - all_sources = set() - for info in state.enhanced_info.values(): - if info.key_sources: - all_sources.update(info.key_sources) - - # Count confidence levels - confidence_distribution = {"high": 0, "medium": 0, "low": 0} - for info in state.enhanced_info.values(): - level = info.confidence_level.lower() - if level in confidence_distribution: - confidence_distribution[level] += 1 + # Add tags to the artifact + # add_tags(tags=["report", "final", "html"], artifact_name="final_report", infer_artifact=True) - # Log metadata for fallback report - log_metadata( - metadata={ - "report_generation": { - "execution_time_seconds": execution_time, - "generation_method": "fallback", - "llm_model": llm_model, - "report_length_chars": len(fallback_html), - "num_sub_questions": len(state.sub_questions), - "num_sources": len(all_sources), - "has_viewpoint_analysis": bool(state.viewpoint_analysis), - "has_reflection": bool(state.reflection_metadata), - "confidence_distribution": confidence_distribution, - "fallback_report": True, - "error_message": str(e), - } - } + logger.info( + f"Successfully generated final report ({len(html_content)} characters)" ) + return final_report, HTMLString(html_content) - # Log model metadata for cross-pipeline tracking - log_metadata( - metadata={ - "research_quality": { - "confidence_distribution": confidence_distribution, - } - }, - infer_model=True, + else: + # Handle non-static template case (future implementation) + logger.warning( + "Non-static template generation not yet implemented, falling back to static template" + ) + return pydantic_final_report_step( + query_context=query_context, + search_data=search_data, + synthesis_data=synthesis_data, + analysis_data=analysis_data, + conclusion_generation_prompt=conclusion_generation_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + use_static_template=True, + llm_model=llm_model, + langfuse_project_name=langfuse_project_name, ) - - # Add tags to the artifacts - add_tags(tags=["state", "final"], artifact="state") - add_tags(tags=["report", "html"], artifact="report_html") - return state, HTMLString(fallback_html) diff --git a/deep_research/steps/query_decomposition_step.py b/deep_research/steps/query_decomposition_step.py index ec0aa0d5..053a8fd4 100644 --- a/deep_research/steps/query_decomposition_step.py +++ b/deep_research/steps/query_decomposition_step.py @@ -2,35 +2,36 @@ import time from typing import Annotated -from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.query_context_materializer import QueryContextMaterializer from utils.llm_utils import get_structured_llm_output -from utils.pydantic_models import Prompt, ResearchState -from zenml import add_tags, log_metadata, step +from utils.pydantic_models import Prompt, QueryContext +from zenml import log_metadata, step logger = logging.getLogger(__name__) -@step(output_materializers=ResearchStateMaterializer) +@step(output_materializers=QueryContextMaterializer) def initial_query_decomposition_step( - state: ResearchState, + 
main_query: str, query_decomposition_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", max_sub_questions: int = 8, langfuse_project_name: str = "deep-research", -) -> Annotated[ResearchState, "updated_state"]: +) -> Annotated[QueryContext, "query_context"]: """Break down a complex research query into specific sub-questions. Args: - state: The current research state + main_query: The main research query to decompose query_decomposition_prompt: Prompt for query decomposition llm_model: The reasoning model to use with provider prefix max_sub_questions: Maximum number of sub-questions to generate + langfuse_project_name: Project name for tracing Returns: - Updated research state with sub-questions + QueryContext containing the main query and decomposed sub-questions """ start_time = time.time() - logger.info(f"Decomposing research query: {state.main_query}") + logger.info(f"Decomposing research query: {main_query}") # Get the prompt content system_prompt = str(query_decomposition_prompt) @@ -48,22 +49,22 @@ def initial_query_decomposition_step( # Define fallback questions fallback_questions = [ { - "sub_question": f"What is {state.main_query}?", + "sub_question": f"What is {main_query}?", "reasoning": "Basic understanding of the topic", }, { - "sub_question": f"What are the key aspects of {state.main_query}?", + "sub_question": f"What are the key aspects of {main_query}?", "reasoning": "Exploring important dimensions", }, { - "sub_question": f"What are the implications of {state.main_query}?", + "sub_question": f"What are the implications of {main_query}?", "reasoning": "Understanding broader impact", }, ] # Use utility function to get structured output decomposed_questions = get_structured_llm_output( - prompt=state.main_query, + prompt=main_query, system_prompt=updated_system_prompt, model=llm_model, fallback_response=fallback_questions, @@ -84,8 +85,10 @@ def initial_query_decomposition_step( for i, question in enumerate(sub_questions, 1): logger.info(f" {i}. 
{question}") - # Update the state with the new sub-questions - state.update_sub_questions(sub_questions) + # Create the QueryContext + query_context = QueryContext( + main_query=main_query, sub_questions=sub_questions + ) # Log step metadata execution_time = time.time() - start_time @@ -97,7 +100,7 @@ def initial_query_decomposition_step( "llm_model": llm_model, "max_sub_questions_requested": max_sub_questions, "fallback_used": False, - "main_query_length": len(state.main_query), + "main_query_length": len(main_query), "sub_questions": sub_questions, } } @@ -113,36 +116,40 @@ def initial_query_decomposition_step( infer_model=True, ) - # Log artifact metadata for the output state + # Log artifact metadata for the output query context log_metadata( metadata={ - "state_characteristics": { - "total_sub_questions": len(state.sub_questions), - "has_search_results": bool(state.search_results), - "has_synthesized_info": bool(state.synthesized_info), + "query_context_characteristics": { + "main_query": main_query, + "num_sub_questions": len(sub_questions), + "timestamp": query_context.decomposition_timestamp, } }, infer_artifact=True, ) # Add tags to the artifact - add_tags(tags=["state", "decomposed"], artifact="updated_state") + # add_tags(tags=["query", "decomposed"], artifact_name="query_context", infer_artifact=True) - return state + return query_context except Exception as e: logger.error(f"Error decomposing query: {e}") - # Return fallback questions in the state + # Return fallback questions fallback_questions = [ - f"What is {state.main_query}?", - f"What are the key aspects of {state.main_query}?", - f"What are the implications of {state.main_query}?", + f"What is {main_query}?", + f"What are the key aspects of {main_query}?", + f"What are the implications of {main_query}?", ] fallback_questions = fallback_questions[:max_sub_questions] logger.info(f"Using {len(fallback_questions)} fallback questions:") for i, question in enumerate(fallback_questions, 1): logger.info(f" {i}. 
{question}") - state.update_sub_questions(fallback_questions) + + # Create QueryContext with fallback questions + query_context = QueryContext( + main_query=main_query, sub_questions=fallback_questions + ) # Log metadata for fallback scenario execution_time = time.time() - start_time @@ -155,7 +162,7 @@ def initial_query_decomposition_step( "max_sub_questions_requested": max_sub_questions, "fallback_used": True, "error_message": str(e), - "main_query_length": len(state.main_query), + "main_query_length": len(main_query), "sub_questions": fallback_questions, } } @@ -172,6 +179,8 @@ def initial_query_decomposition_step( ) # Add tags to the artifact - add_tags(tags=["state", "decomposed"], artifact="updated_state") + # add_tags( + # tags=["query", "decomposed", "fallback"], artifact_name="query_context", infer_artifact=True + # ) - return state + return query_context diff --git a/deep_research/tests/test_approval_utils.py b/deep_research/tests/test_approval_utils.py index fe859b0f..f1dd15a5 100644 --- a/deep_research/tests/test_approval_utils.py +++ b/deep_research/tests/test_approval_utils.py @@ -6,9 +6,7 @@ format_critique_summary, format_query_list, parse_approval_response, - summarize_research_progress, ) -from utils.pydantic_models import ResearchState, SynthesizedInfo def test_parse_approval_responses(): @@ -77,34 +75,6 @@ def test_format_approval_request(): assert "Missing data" in message -def test_summarize_research_progress(): - """Test research progress summarization.""" - state = ResearchState( - main_query="test", - synthesized_info={ - "q1": SynthesizedInfo( - synthesized_answer="a1", confidence_level="high" - ), - "q2": SynthesizedInfo( - synthesized_answer="a2", confidence_level="medium" - ), - "q3": SynthesizedInfo( - synthesized_answer="a3", confidence_level="low" - ), - "q4": SynthesizedInfo( - synthesized_answer="a4", confidence_level="low" - ), - }, - ) - - summary = summarize_research_progress(state) - - assert summary["completed_count"] == 4 - # (1.0 + 0.5 + 0.0 + 0.0) / 4 = 1.5 / 4 = 0.375, rounded to 0.38 - assert summary["avg_confidence"] == 0.38 - assert summary["low_confidence_count"] == 2 - - def test_format_critique_summary(): """Test critique summary formatting.""" # Test with no critiques diff --git a/deep_research/tests/test_artifact_models.py b/deep_research/tests/test_artifact_models.py new file mode 100644 index 00000000..415862b5 --- /dev/null +++ b/deep_research/tests/test_artifact_models.py @@ -0,0 +1,210 @@ +"""Tests for the new artifact models.""" + +import time + +import pytest +from utils.pydantic_models import ( + AnalysisData, + FinalReport, + QueryContext, + ReflectionMetadata, + SearchCostDetail, + SearchData, + SearchResult, + SynthesisData, + SynthesizedInfo, + ViewpointAnalysis, +) + + +class TestQueryContext: + """Test the QueryContext artifact.""" + + def test_query_context_creation(self): + """Test creating a QueryContext.""" + query = QueryContext( + main_query="What is quantum computing?", + sub_questions=["What are qubits?", "How do quantum gates work?"], + ) + + assert query.main_query == "What is quantum computing?" 
+ assert len(query.sub_questions) == 2 + assert query.decomposition_timestamp > 0 + + def test_query_context_immutable(self): + """Test that QueryContext is immutable.""" + query = QueryContext(main_query="Test query", sub_questions=[]) + + # Should raise error when trying to modify + with pytest.raises(Exception): # Pydantic will raise validation error + query.main_query = "Modified query" + + def test_query_context_defaults(self): + """Test QueryContext with defaults.""" + query = QueryContext(main_query="Test") + assert query.sub_questions == [] + assert query.decomposition_timestamp > 0 + + +class TestSearchData: + """Test the SearchData artifact.""" + + def test_search_data_creation(self): + """Test creating SearchData.""" + search_data = SearchData() + + assert search_data.search_results == {} + assert search_data.search_costs == {} + assert search_data.search_cost_details == [] + assert search_data.total_searches == 0 + + def test_search_data_with_results(self): + """Test SearchData with actual results.""" + result = SearchResult( + url="https://example.com", + content="Test content", + title="Test Title", + ) + + cost_detail = SearchCostDetail( + provider="exa", + query="test query", + cost=0.01, + timestamp=time.time(), + step="process_sub_question", + ) + + search_data = SearchData( + search_results={"Question 1": [result]}, + search_costs={"exa": 0.01}, + search_cost_details=[cost_detail], + total_searches=1, + ) + + assert len(search_data.search_results) == 1 + assert search_data.search_costs["exa"] == 0.01 + assert len(search_data.search_cost_details) == 1 + assert search_data.total_searches == 1 + + def test_search_data_merge(self): + """Test merging SearchData instances.""" + # Create first instance + data1 = SearchData( + search_results={ + "Q1": [SearchResult(url="url1", content="content1")] + }, + search_costs={"exa": 0.01}, + total_searches=1, + ) + + # Create second instance + data2 = SearchData( + search_results={ + "Q1": [SearchResult(url="url2", content="content2")], + "Q2": [SearchResult(url="url3", content="content3")], + }, + search_costs={"exa": 0.02, "tavily": 0.01}, + total_searches=2, + ) + + # Merge + data1.merge(data2) + + # Check results + assert len(data1.search_results["Q1"]) == 2 # Merged Q1 results + assert "Q2" in data1.search_results # Added Q2 + assert data1.search_costs["exa"] == 0.03 # Combined costs + assert data1.search_costs["tavily"] == 0.01 # New provider + assert data1.total_searches == 3 + + +class TestSynthesisData: + """Test the SynthesisData artifact.""" + + def test_synthesis_data_creation(self): + """Test creating SynthesisData.""" + synthesis = SynthesisData() + + assert synthesis.synthesized_info == {} + assert synthesis.enhanced_info == {} + + def test_synthesis_data_with_info(self): + """Test SynthesisData with synthesized info.""" + synth_info = SynthesizedInfo( + synthesized_answer="Test answer", + key_sources=["source1", "source2"], + confidence_level="high", + ) + + synthesis = SynthesisData(synthesized_info={"Q1": synth_info}) + + assert "Q1" in synthesis.synthesized_info + assert synthesis.synthesized_info["Q1"].confidence_level == "high" + + def test_synthesis_data_merge(self): + """Test merging SynthesisData instances.""" + info1 = SynthesizedInfo(synthesized_answer="Answer 1") + info2 = SynthesizedInfo(synthesized_answer="Answer 2") + + data1 = SynthesisData(synthesized_info={"Q1": info1}) + data2 = SynthesisData(synthesized_info={"Q2": info2}) + + data1.merge(data2) + + assert "Q1" in data1.synthesized_info + assert "Q2" 
in data1.synthesized_info + + +class TestAnalysisData: + """Test the AnalysisData artifact.""" + + def test_analysis_data_creation(self): + """Test creating AnalysisData.""" + analysis = AnalysisData() + + assert analysis.viewpoint_analysis is None + assert analysis.reflection_metadata is None + + def test_analysis_data_with_viewpoint(self): + """Test AnalysisData with viewpoint analysis.""" + viewpoint = ViewpointAnalysis( + main_points_of_agreement=["Point 1", "Point 2"], + perspective_gaps="Some gaps", + ) + + analysis = AnalysisData(viewpoint_analysis=viewpoint) + + assert analysis.viewpoint_analysis is not None + assert len(analysis.viewpoint_analysis.main_points_of_agreement) == 2 + + def test_analysis_data_with_reflection(self): + """Test AnalysisData with reflection metadata.""" + reflection = ReflectionMetadata( + critique_summary=["Critique 1"], improvements_made=3.0 + ) + + analysis = AnalysisData(reflection_metadata=reflection) + + assert analysis.reflection_metadata is not None + assert analysis.reflection_metadata.improvements_made == 3.0 + + +class TestFinalReport: + """Test the FinalReport artifact.""" + + def test_final_report_creation(self): + """Test creating FinalReport.""" + report = FinalReport() + + assert report.report_html == "" + assert report.generated_at > 0 + assert report.main_query == "" + + def test_final_report_with_content(self): + """Test FinalReport with HTML content.""" + html = "Test Report" + report = FinalReport(report_html=html, main_query="What is AI?") + + assert report.report_html == html + assert report.main_query == "What is AI?" + assert report.generated_at > 0 diff --git a/deep_research/tests/test_pydantic_final_report_step.py b/deep_research/tests/test_pydantic_final_report_step.py index c0f13530..b4dcd956 100644 --- a/deep_research/tests/test_pydantic_final_report_step.py +++ b/deep_research/tests/test_pydantic_final_report_step.py @@ -9,9 +9,14 @@ import pytest from steps.pydantic_final_report_step import pydantic_final_report_step from utils.pydantic_models import ( + AnalysisData, + FinalReport, + Prompt, + QueryContext, ReflectionMetadata, - ResearchState, + SearchData, SearchResult, + SynthesisData, SynthesizedInfo, ViewpointAnalysis, ViewpointTension, @@ -20,15 +25,15 @@ @pytest.fixture -def sample_research_state() -> ResearchState: - """Create a sample research state for testing.""" - # Create a basic research state - state = ResearchState(main_query="What are the impacts of climate change?") - - # Add sub-questions - state.update_sub_questions(["Economic impacts", "Environmental impacts"]) +def sample_artifacts(): + """Create sample artifacts for testing.""" + # Create QueryContext + query_context = QueryContext( + main_query="What are the impacts of climate change?", + sub_questions=["Economic impacts", "Environmental impacts"], + ) - # Add search results + # Create SearchData search_results: Dict[str, List[SearchResult]] = { "Economic impacts": [ SearchResult( @@ -37,11 +42,19 @@ def sample_research_state() -> ResearchState: snippet="Overview of economic impacts", content="Detailed content about economic impacts of climate change", ) - ] + ], + "Environmental impacts": [ + SearchResult( + url="https://example.com/environment", + title="Environmental Impacts", + snippet="Environmental impact overview", + content="Content about environmental impacts", + ) + ], } - state.update_search_results(search_results) + search_data = SearchData(search_results=search_results) - # Add synthesized info + # Create SynthesisData synthesized_info: 
Dict[str, SynthesizedInfo] = { "Economic impacts": SynthesizedInfo( synthesized_answer="Climate change will have significant economic impacts...", @@ -54,12 +67,12 @@ def sample_research_state() -> ResearchState: confidence_level="high", ), } - state.update_synthesized_info(synthesized_info) - - # Add enhanced info (same as synthesized for this test) - state.enhanced_info = state.synthesized_info + synthesis_data = SynthesisData( + synthesized_info=synthesized_info, + enhanced_info=synthesized_info, # Same as synthesized for this test + ) - # Add viewpoint analysis + # Create AnalysisData viewpoint_analysis = ViewpointAnalysis( main_points_of_agreement=[ "Climate change is happening", @@ -77,9 +90,7 @@ def sample_research_state() -> ResearchState: perspective_gaps="Indigenous perspectives are underrepresented", integrative_insights="A balanced approach combining regulations and market incentives may be most effective", ) - state.update_viewpoint_analysis(viewpoint_analysis) - # Add reflection metadata reflection_metadata = ReflectionMetadata( critique_summary=["Need more sources for economic impacts"], additional_questions_identified=[ @@ -89,43 +100,111 @@ def sample_research_state() -> ResearchState: "economic impacts of climate change", "regional climate impacts", ], - improvements_made=2, + improvements_made=2.0, ) - state.reflection_metadata = reflection_metadata - return state + analysis_data = AnalysisData( + viewpoint_analysis=viewpoint_analysis, + reflection_metadata=reflection_metadata, + ) + + # Create prompts + conclusion_prompt = Prompt( + name="conclusion_generation", + content="Generate a conclusion based on the research findings.", + ) + executive_summary_prompt = Prompt( + name="executive_summary", content="Generate an executive summary." + ) + introduction_prompt = Prompt( + name="introduction", content="Generate an introduction." + ) + + return { + "query_context": query_context, + "search_data": search_data, + "synthesis_data": synthesis_data, + "analysis_data": analysis_data, + "conclusion_generation_prompt": conclusion_prompt, + "executive_summary_prompt": executive_summary_prompt, + "introduction_prompt": introduction_prompt, + } def test_pydantic_final_report_step_returns_tuple(): - """Test that the step returns a tuple with state and HTML.""" - # Create a simple state - state = ResearchState(main_query="What is climate change?") - state.update_sub_questions(["What causes climate change?"]) + """Test that the step returns a tuple with FinalReport and HTML.""" + # Create simple artifacts + query_context = QueryContext( + main_query="What is climate change?", + sub_questions=["What causes climate change?"], + ) + search_data = SearchData() + synthesis_data = SynthesisData( + synthesized_info={ + "What causes climate change?": SynthesizedInfo( + synthesized_answer="Climate change is caused by greenhouse gases.", + confidence_level="high", + key_sources=["https://example.com/causes"], + ) + } + ) + analysis_data = AnalysisData() + + # Create prompts + conclusion_prompt = Prompt( + name="conclusion_generation", content="Generate a conclusion." + ) + executive_summary_prompt = Prompt( + name="executive_summary", content="Generate summary." + ) + introduction_prompt = Prompt( + name="introduction", content="Generate intro." 
+ ) # Run the step - result = pydantic_final_report_step(state=state) + result = pydantic_final_report_step( + query_context=query_context, + search_data=search_data, + synthesis_data=synthesis_data, + analysis_data=analysis_data, + conclusion_generation_prompt=conclusion_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + ) # Assert that result is a tuple with 2 elements assert isinstance(result, tuple) assert len(result) == 2 - # Assert first element is ResearchState - assert isinstance(result[0], ResearchState) + # Assert first element is FinalReport + assert isinstance(result[0], FinalReport) # Assert second element is HTMLString assert isinstance(result[1], HTMLString) -def test_pydantic_final_report_step_with_complex_state(sample_research_state): - """Test that the step handles a complex state properly.""" - # Run the step with a complex state - result = pydantic_final_report_step(state=sample_research_state) +def test_pydantic_final_report_step_with_complex_artifacts(sample_artifacts): + """Test that the step handles complex artifacts properly.""" + # Run the step with complex artifacts + result = pydantic_final_report_step( + query_context=sample_artifacts["query_context"], + search_data=sample_artifacts["search_data"], + synthesis_data=sample_artifacts["synthesis_data"], + analysis_data=sample_artifacts["analysis_data"], + conclusion_generation_prompt=sample_artifacts[ + "conclusion_generation_prompt" + ], + executive_summary_prompt=sample_artifacts["executive_summary_prompt"], + introduction_prompt=sample_artifacts["introduction_prompt"], + ) # Unpack the results - updated_state, html_report = result + final_report, html_report = result - # Assert state contains final report HTML - assert updated_state.final_report_html != "" + # Assert FinalReport contains expected data + assert final_report.main_query == "What are the impacts of climate change?" + assert len(final_report.sub_questions) == 2 + assert final_report.report_html != "" # Assert HTML report contains key elements html_str = str(html_report) @@ -136,32 +215,51 @@ def test_pydantic_final_report_step_with_complex_state(sample_research_state): assert "Conservative" in html_str -def test_pydantic_final_report_step_updates_state(): - """Test that the step properly updates the state.""" - # Create an initial state without a final report - state = ResearchState( +def test_pydantic_final_report_step_creates_report(): + """Test that the step properly creates a final report.""" + # Create artifacts + query_context = QueryContext( main_query="What is climate change?", sub_questions=["What causes climate change?"], + ) + search_data = SearchData() + synthesis_data = SynthesisData( synthesized_info={ "What causes climate change?": SynthesizedInfo( synthesized_answer="Climate change is caused by greenhouse gases.", confidence_level="high", + key_sources=["https://example.com/causes"], ) - }, - enhanced_info={ - "What causes climate change?": SynthesizedInfo( - synthesized_answer="Climate change is caused by greenhouse gases.", - confidence_level="high", - ) - }, + } ) + analysis_data = AnalysisData() - # Verify initial state has no report - assert state.final_report_html == "" + # Create prompts + conclusion_prompt = Prompt( + name="conclusion_generation", content="Generate a conclusion." + ) + executive_summary_prompt = Prompt( + name="executive_summary", content="Generate summary." + ) + introduction_prompt = Prompt( + name="introduction", content="Generate intro." 
+ ) # Run the step - updated_state, _ = pydantic_final_report_step(state=state) + final_report, html_report = pydantic_final_report_step( + query_context=query_context, + search_data=search_data, + synthesis_data=synthesis_data, + analysis_data=analysis_data, + conclusion_generation_prompt=conclusion_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + ) + + # Verify FinalReport was created with content + assert final_report.report_html != "" + assert "climate change" in final_report.report_html.lower() - # Verify state was updated with a report - assert updated_state.final_report_html != "" - assert "climate change" in updated_state.final_report_html.lower() + # Verify HTML report was created + assert str(html_report) != "" + assert "climate change" in str(html_report).lower() diff --git a/deep_research/tests/test_pydantic_materializer.py b/deep_research/tests/test_pydantic_materializer.py deleted file mode 100644 index 49cb17f8..00000000 --- a/deep_research/tests/test_pydantic_materializer.py +++ /dev/null @@ -1,161 +0,0 @@ -"""Tests for Pydantic-based materializer. - -This module contains tests for the Pydantic-based implementation of -ResearchStateMaterializer, verifying that it correctly serializes and -visualizes ResearchState objects. -""" - -import os -import tempfile -from typing import Dict, List - -import pytest -from materializers.pydantic_materializer import ResearchStateMaterializer -from utils.pydantic_models import ( - ResearchState, - SearchResult, - SynthesizedInfo, - ViewpointAnalysis, - ViewpointTension, -) - - -@pytest.fixture -def sample_state() -> ResearchState: - """Create a sample research state for testing.""" - # Create a basic research state - state = ResearchState(main_query="What are the impacts of climate change?") - - # Add sub-questions - state.update_sub_questions(["Economic impacts", "Environmental impacts"]) - - # Add search results - search_results: Dict[str, List[SearchResult]] = { - "Economic impacts": [ - SearchResult( - url="https://example.com/economy", - title="Economic Impacts of Climate Change", - snippet="Overview of economic impacts", - content="Detailed content about economic impacts of climate change", - ) - ] - } - state.update_search_results(search_results) - - # Add synthesized info - synthesized_info: Dict[str, SynthesizedInfo] = { - "Economic impacts": SynthesizedInfo( - synthesized_answer="Climate change will have significant economic impacts...", - key_sources=["https://example.com/economy"], - confidence_level="high", - ) - } - state.update_synthesized_info(synthesized_info) - - return state - - -def test_materializer_initialization(): - """Test that the materializer can be initialized.""" - # Create a temporary directory for artifact storage - with tempfile.TemporaryDirectory() as tmpdirname: - materializer = ResearchStateMaterializer(uri=tmpdirname) - assert materializer is not None - - -def test_materializer_save_and_load(sample_state: ResearchState): - """Test saving and loading a state using the materializer.""" - # Create a temporary directory for artifact storage - with tempfile.TemporaryDirectory() as tmpdirname: - # Initialize materializer with temporary artifact URI - materializer = ResearchStateMaterializer(uri=tmpdirname) - - # Save the state - materializer.save(sample_state) - - # Load the state - loaded_state = materializer.load(ResearchState) - - # Verify that the loaded state matches the original - assert loaded_state.main_query == sample_state.main_query - assert 
loaded_state.sub_questions == sample_state.sub_questions - assert len(loaded_state.search_results) == len( - sample_state.search_results - ) - assert ( - loaded_state.get_current_stage() - == sample_state.get_current_stage() - ) - - # Check that key fields were preserved - question = "Economic impacts" - assert ( - loaded_state.synthesized_info[question].synthesized_answer - == sample_state.synthesized_info[question].synthesized_answer - ) - assert ( - loaded_state.synthesized_info[question].confidence_level - == sample_state.synthesized_info[question].confidence_level - ) - - -def test_materializer_save_visualizations(sample_state: ResearchState): - """Test generating and saving visualizations.""" - # Create a temporary directory for artifact storage - with tempfile.TemporaryDirectory() as tmpdirname: - # Initialize materializer with temporary artifact URI - materializer = ResearchStateMaterializer(uri=tmpdirname) - - # Generate and save visualizations - viz_paths = materializer.save_visualizations(sample_state) - - # Verify visualization file exists - html_path = list(viz_paths.keys())[0] - assert os.path.exists(html_path) - - # Verify the file has content - with open(html_path, "r") as f: - content = f.read() - # Check for expected elements in the HTML - assert "Research State" in content - assert sample_state.main_query in content - assert "Economic impacts" in content - - -def test_html_generation_stages(sample_state: ResearchState): - """Test that HTML visualization reflects the correct research stage.""" - # Create the materializer - with tempfile.TemporaryDirectory() as tmpdirname: - materializer = ResearchStateMaterializer(uri=tmpdirname) - - # Generate visualization at initial state - html = materializer._generate_visualization_html(sample_state) - # Verify stage by checking for expected elements in the HTML - assert ( - "Synthesized Information" in html - ) # Should show synthesized info - - # Add viewpoint analysis - state_with_viewpoints = sample_state.model_copy(deep=True) - viewpoint_analysis = ViewpointAnalysis( - main_points_of_agreement=["There will be economic impacts"], - areas_of_tension=[ - ViewpointTension( - topic="Job impacts", - viewpoints={ - "Positive": "New green jobs", - "Negative": "Job losses", - }, - ) - ], - ) - state_with_viewpoints.update_viewpoint_analysis(viewpoint_analysis) - html = materializer._generate_visualization_html(state_with_viewpoints) - assert "Viewpoint Analysis" in html - assert "Points of Agreement" in html - - # Add final report - state_with_report = state_with_viewpoints.model_copy(deep=True) - state_with_report.set_final_report("Final report content") - html = materializer._generate_visualization_html(state_with_report) - assert "Final Report" in html diff --git a/deep_research/tests/test_pydantic_models.py b/deep_research/tests/test_pydantic_models.py index b900f8d8..21d25123 100644 --- a/deep_research/tests/test_pydantic_models.py +++ b/deep_research/tests/test_pydantic_models.py @@ -8,11 +8,9 @@ """ import json -from typing import Dict, List from utils.pydantic_models import ( ReflectionMetadata, - ResearchState, SearchResult, SynthesizedInfo, ViewpointAnalysis, @@ -199,105 +197,3 @@ def test_reflection_metadata_model(): new_metadata = ReflectionMetadata.model_validate(metadata_dict) assert new_metadata.improvements_made == metadata.improvements_made assert new_metadata.critique_summary == metadata.critique_summary - - -def test_research_state_model(): - """Test the main ResearchState model.""" - # Create with defaults - state = 
ResearchState() - assert state.main_query == "" - assert state.sub_questions == [] - assert state.search_results == {} - assert state.get_current_stage() == "empty" - - # Set main query - state.main_query = "What are the impacts of climate change?" - assert state.get_current_stage() == "initial" - - # Test update methods - state.update_sub_questions( - ["What are economic impacts?", "What are environmental impacts?"] - ) - assert len(state.sub_questions) == 2 - assert state.get_current_stage() == "after_query_decomposition" - - # Add search results - search_results: Dict[str, List[SearchResult]] = { - "What are economic impacts?": [ - SearchResult( - url="https://example.com/economy", - title="Economic Impacts", - snippet="Overview of economic impacts", - content="Detailed content about economic impacts", - ) - ] - } - state.update_search_results(search_results) - assert state.get_current_stage() == "after_search" - assert len(state.search_results["What are economic impacts?"]) == 1 - - # Add synthesized info - synthesized_info: Dict[str, SynthesizedInfo] = { - "What are economic impacts?": SynthesizedInfo( - synthesized_answer="Economic impacts include job losses and growth opportunities", - key_sources=["https://example.com/economy"], - confidence_level="high", - ) - } - state.update_synthesized_info(synthesized_info) - assert state.get_current_stage() == "after_synthesis" - - # Add viewpoint analysis - analysis = ViewpointAnalysis( - main_points_of_agreement=["Economic changes are happening"], - areas_of_tension=[ - ViewpointTension( - topic="Job impacts", - viewpoints={ - "Positive": "New green jobs", - "Negative": "Fossil fuel job losses", - }, - ) - ], - ) - state.update_viewpoint_analysis(analysis) - assert state.get_current_stage() == "after_viewpoint_analysis" - - # Add reflection results - enhanced_info = { - "What are economic impacts?": SynthesizedInfo( - synthesized_answer="Enhanced answer with more details", - key_sources=[ - "https://example.com/economy", - "https://example.com/new-source", - ], - confidence_level="high", - improvements=["Added more context", "Added more sources"], - ) - } - metadata = ReflectionMetadata( - critique_summary=["Needed more sources"], - improvements_made=2, - ) - state.update_after_reflection(enhanced_info, metadata) - assert state.get_current_stage() == "after_reflection" - - # Set final report - state.set_final_report("Final report content") - assert state.get_current_stage() == "final_report" - assert state.final_report_html == "Final report content" - - # Test serialization and deserialization - state_dict = state.model_dump() - new_state = ResearchState.model_validate(state_dict) - - # Verify key properties were preserved - assert new_state.main_query == state.main_query - assert len(new_state.sub_questions) == len(state.sub_questions) - assert new_state.get_current_stage() == state.get_current_stage() - assert new_state.viewpoint_analysis is not None - assert len(new_state.viewpoint_analysis.areas_of_tension) == 1 - assert ( - new_state.viewpoint_analysis.areas_of_tension[0].topic == "Job impacts" - ) - assert new_state.final_report_html == state.final_report_html diff --git a/deep_research/utils/approval_utils.py b/deep_research/utils/approval_utils.py index 56a8ff91..94cd5a47 100644 --- a/deep_research/utils/approval_utils.py +++ b/deep_research/utils/approval_utils.py @@ -5,28 +5,6 @@ from utils.pydantic_models import ApprovalDecision -def summarize_research_progress(state) -> Dict[str, Any]: - """Summarize the current research 
progress.""" - completed_count = len(state.synthesized_info) - confidence_levels = [ - info.confidence_level for info in state.synthesized_info.values() - ] - - # Calculate average confidence (high=1.0, medium=0.5, low=0.0) - confidence_map = {"high": 1.0, "medium": 0.5, "low": 0.0} - avg_confidence = sum( - confidence_map.get(c, 0.5) for c in confidence_levels - ) / max(len(confidence_levels), 1) - - low_confidence_count = sum(1 for c in confidence_levels if c == "low") - - return { - "completed_count": completed_count, - "avg_confidence": round(avg_confidence, 2), - "low_confidence_count": low_confidence_count, - } - - def format_critique_summary(critique_points: List[Dict[str, Any]]) -> str: """Format critique points for display.""" if not critique_points: diff --git a/deep_research/utils/pydantic_models.py b/deep_research/utils/pydantic_models.py index 02ecc3c0..822afe99 100644 --- a/deep_research/utils/pydantic_models.py +++ b/deep_research/utils/pydantic_models.py @@ -247,21 +247,6 @@ def set_final_report(self, html: str) -> None: self.final_report_html = html -class ReflectionOutput(BaseModel): - """Output from the reflection generation step.""" - - state: ResearchState - recommended_queries: List[str] = Field(default_factory=list) - critique_summary: List[Dict[str, Any]] = Field(default_factory=list) - additional_questions: List[str] = Field(default_factory=list) - - model_config = { - "extra": "ignore", - "frozen": False, - "validate_assignment": True, - } - - class ApprovalDecision(BaseModel): """Approval decision from human reviewer.""" @@ -360,3 +345,159 @@ class TracingMetadata(BaseModel): "frozen": False, "validate_assignment": True, } + + +# ============================================================================ +# New Artifact Classes for ResearchState Refactoring +# ============================================================================ + + +class QueryContext(BaseModel): + """Immutable context containing the research query and its decomposition. + + This artifact is created once at the beginning of the pipeline and + remains unchanged throughout execution. + """ + + main_query: str = Field( + ..., description="The main research question from the user" + ) + sub_questions: List[str] = Field( + default_factory=list, + description="Decomposed sub-questions for parallel processing", + ) + decomposition_timestamp: float = Field( + default_factory=lambda: time.time(), + description="When the query was decomposed", + ) + + model_config = { + "extra": "ignore", + "frozen": True, # Make immutable after creation + "validate_assignment": True, + } + + +class SearchCostDetail(BaseModel): + """Detailed information about a single search operation.""" + + provider: str + query: str + cost: float + timestamp: float + step: str + sub_question: Optional[str] = None + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class SearchData(BaseModel): + """Accumulates search results and cost tracking throughout the pipeline. + + This artifact grows as searches are performed and can be merged + when parallel searches complete. 
+ """ + + search_results: Dict[str, List[SearchResult]] = Field( + default_factory=dict, + description="Map of sub-question to search results", + ) + search_costs: Dict[str, float] = Field( + default_factory=dict, description="Total costs by provider" + ) + search_cost_details: List[SearchCostDetail] = Field( + default_factory=list, description="Detailed log of each search" + ) + total_searches: int = Field( + default=0, description="Total number of searches performed" + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def merge(self, other: "SearchData") -> "SearchData": + """Merge another SearchData instance into this one.""" + # Merge search results + for sub_q, results in other.search_results.items(): + if sub_q in self.search_results: + self.search_results[sub_q].extend(results) + else: + self.search_results[sub_q] = results + + # Merge costs + for provider, cost in other.search_costs.items(): + self.search_costs[provider] = ( + self.search_costs.get(provider, 0.0) + cost + ) + + # Merge cost details + self.search_cost_details.extend(other.search_cost_details) + + # Update total searches + self.total_searches += other.total_searches + + return self + + +class SynthesisData(BaseModel): + """Contains synthesized information for all sub-questions.""" + + synthesized_info: Dict[str, SynthesizedInfo] = Field( + default_factory=dict, + description="Synthesized answers for each sub-question", + ) + enhanced_info: Dict[str, SynthesizedInfo] = Field( + default_factory=dict, + description="Enhanced information after reflection", + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def merge(self, other: "SynthesisData") -> "SynthesisData": + """Merge another SynthesisData instance into this one.""" + self.synthesized_info.update(other.synthesized_info) + self.enhanced_info.update(other.enhanced_info) + return self + + +class AnalysisData(BaseModel): + """Contains viewpoint analysis and reflection metadata.""" + + viewpoint_analysis: Optional[ViewpointAnalysis] = None + reflection_metadata: Optional[ReflectionMetadata] = None + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class FinalReport(BaseModel): + """Contains the final HTML report.""" + + report_html: str = Field(default="", description="The final HTML report") + generated_at: float = Field( + default_factory=lambda: time.time(), + description="Timestamp when report was generated", + ) + main_query: str = Field( + default="", description="The original research query" + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + }