diff --git a/.gitignore b/.gitignore index 851a4f94..f27e166a 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,7 @@ nohup.out # Claude .claude/ + +# design folder +design/ +deep_research/design diff --git a/deep_research/README.md b/deep_research/README.md new file mode 100644 index 00000000..c057d85b --- /dev/null +++ b/deep_research/README.md @@ -0,0 +1,604 @@ +# 🔍 ZenML Deep Research Agent + +A production-ready MLOps pipeline for conducting deep, comprehensive research on any topic using LLMs and web search capabilities. + +
+  [Image: Research Pipeline Visualization]
+  ZenML Deep Research pipeline flow
+</div>
+ +## 🎯 Overview + +The ZenML Deep Research Agent is a scalable, modular pipeline that automates in-depth research on any topic. It: + +- Creates a structured outline based on your research query +- Researches each section through targeted web searches and LLM analysis +- Iteratively refines content through reflection cycles +- Produces a comprehensive, well-formatted research report +- Visualizes the research process and report structure in the ZenML dashboard + +This project transforms exploratory notebook-based research into a production-grade, reproducible, and transparent process using the ZenML MLOps framework. + +## 📝 Example Research Results + +The Deep Research Agent produces comprehensive, well-structured reports on any topic. Here's an example of research conducted on quantum computing: + +
+  [Image: Sample Research Report]
+  Sample report generated by the Deep Research Agent
+</div>
+ +## 🚀 Pipeline Architecture + +The pipeline uses a parallel processing architecture for efficiency and breaks down the research process into granular steps for maximum modularity and control: + +1. **Initialize Prompts**: Load and track all prompts as versioned artifacts +2. **Query Decomposition**: Break down the main query into specific sub-questions +3. **Parallel Information Gathering**: Process multiple sub-questions concurrently for faster results +4. **Merge Results**: Combine results from parallel processing into a unified state +5. **Cross-Viewpoint Analysis**: Analyze discrepancies and agreements between different perspectives +6. **Reflection Generation**: Generate recommendations for improving research quality +7. **Human Approval** (optional): Get human approval for additional searches +8. **Execute Approved Searches**: Perform approved additional searches to fill gaps +9. **Final Report Generation**: Compile all synthesized information into a coherent HTML report +10. **Collect Tracing Metadata**: Gather comprehensive metrics about token usage, costs, and performance + +This architecture enables: +- Better reproducibility and caching of intermediate results +- Parallel processing for faster research completion +- Easier debugging and monitoring of specific research stages +- More flexible reconfiguration of individual components +- Enhanced transparency into how the research is conducted +- Human oversight and control over iterative research expansions + +## 💡 Under the Hood + +- **LLM Integration**: Uses litellm for flexible access to various LLM providers +- **Web Research**: Utilizes Tavily API for targeted internet searches +- **ZenML Orchestration**: Manages pipeline flow, artifacts, and caching +- **Reproducibility**: Track every step, parameter, and output via ZenML +- **Visualizations**: Interactive visualizations of the research structure and progress +- **Report Generation**: Uses static HTML templates for consistent, high-quality reports +- **Human-in-the-Loop**: Optional approval mechanism via ZenML alerters (Discord, Slack, etc.) +- **LLM Observability**: Integrated Langfuse tracking for monitoring LLM usage, costs, and performance + +## 🛠️ Getting Started + +### Prerequisites + +- Python 3.9+ +- ZenML installed and configured +- API key for your preferred LLM provider (configured with litellm) +- Tavily API key +- Langfuse account for LLM tracking (optional but recommended) + +### Installation + +```bash +# Clone the repository +git clone +cd zenml_deep_research + +# Install dependencies +pip install -r requirements.txt + +# Set up API keys +export OPENAI_API_KEY=your_openai_key # Or another LLM provider key +export TAVILY_API_KEY=your_tavily_key # For Tavily search (default) +export EXA_API_KEY=your_exa_key # For Exa search (optional) + +# Set up Langfuse for LLM tracking (optional) +export LANGFUSE_PUBLIC_KEY=your_public_key +export LANGFUSE_SECRET_KEY=your_secret_key +export LANGFUSE_HOST=https://cloud.langfuse.com # Or your self-hosted URL + +# Initialize ZenML (if needed) +zenml init +``` + +### Setting up Langfuse for LLM Tracking + +The pipeline integrates with [Langfuse](https://langfuse.com) for comprehensive LLM observability and tracking. This allows you to monitor LLM usage, costs, and performance across all pipeline runs. + +#### 1. Create a Langfuse Account + +1. Sign up at [cloud.langfuse.com](https://cloud.langfuse.com) or set up a self-hosted instance +2. Create a new project in your Langfuse dashboard (e.g., "deep-research") +3. 
Navigate to Settings → API Keys to get your credentials + +#### 2. Configure Environment Variables + +Set the following environment variables with your Langfuse credentials: + +```bash +export LANGFUSE_PUBLIC_KEY=pk-lf-... # Your public key +export LANGFUSE_SECRET_KEY=sk-lf-... # Your secret key +export LANGFUSE_HOST=https://cloud.langfuse.com # Or your self-hosted URL +``` + +#### 3. Configure Project Name + +The Langfuse project name can be configured in any of the pipeline configuration files: + +```yaml +# configs/enhanced_research.yaml +langfuse_project_name: "deep-research" # Change to match your Langfuse project +``` + +**Note**: The project must already exist in your Langfuse dashboard before running the pipeline. + +#### What Gets Tracked + +When Langfuse is configured, the pipeline automatically tracks: + +- **All LLM calls** with their prompts, responses, and token usage +- **Pipeline trace information** including: + - `trace_name`: The ZenML pipeline run name for easy identification + - `trace_id`: The unique ZenML pipeline run ID for correlation +- **Tagged operations** such as: + - `structured_llm_output`: JSON generation calls + - `information_synthesis`: Research synthesis operations + - `find_most_relevant_string`: Relevance matching operations +- **Performance metrics**: Latency, token counts, and costs +- **Project organization**: All traces are organized under your configured project + +This integration provides full observability into your research pipeline's LLM usage, making it easy to optimize performance, track costs, and debug issues. + +### Running the Pipeline + +#### Basic Usage + +```bash +# Run with default configuration +python run.py +``` + +The default configuration and research query are defined in `configs/enhanced_research.yaml`. 
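+As a quick illustration of how the Langfuse tracking described above attaches to individual LLM calls, the sketch below wires litellm's built-in Langfuse callback and passes the ZenML run information as trace metadata. This is a minimal, illustrative sketch rather than the repository's actual `utils/llm_utils.py` implementation: the `run_llm` helper and its parameters are hypothetical, and the model name is simply taken from the sample configurations.
+
+```python
+import litellm
+
+# Route every litellm call through the built-in Langfuse callback; credentials
+# are read from LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY and LANGFUSE_HOST.
+litellm.success_callback = ["langfuse"]
+
+
+def run_llm(prompt: str, pipeline_run_name: str, pipeline_run_id: str) -> str:
+    """Hypothetical helper showing how trace metadata could be attached to a call."""
+    response = litellm.completion(
+        model="sambanova/DeepSeek-R1-Distill-Llama-70B",  # model used in the sample configs
+        messages=[{"role": "user", "content": prompt}],
+        metadata={
+            "trace_name": pipeline_run_name,    # ZenML pipeline run name
+            "trace_id": pipeline_run_id,        # ZenML pipeline run ID
+            "tags": ["structured_llm_output"],  # operation tag described above
+        },
+    )
+    return response.choices[0].message.content
+```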
+ +#### Using Research Mode Presets + +The pipeline includes three pre-configured research modes for different use cases: + +```bash +# Rapid mode - Quick overview with minimal depth +python run.py --mode rapid + +# Balanced mode - Standard research depth (default) +python run.py --mode balanced + +# Deep mode - Comprehensive analysis with maximum depth +python run.py --mode deep +``` + +**Mode Comparison:** + +| Mode | Sub-Questions | Search Results* | Additional Searches | Best For | +|------|---------------|----------------|-------------------|----------| +| **Rapid** | 5 | 2 per search | 0 | Quick overviews, time-sensitive research | +| **Balanced** | 10 | 3 per search | 2 | Most research tasks, good depth/speed ratio | +| **Deep** | 15 | 5 per search | 4 | Comprehensive analysis, academic research | + +*Can be overridden with `--num-results` + +#### Using Different Configurations + +```bash +# Run with a custom configuration file +python run.py --config configs/custom_enhanced_config.yaml + +# Override the research query from command line +python run.py --query "My research topic" + +# Specify maximum number of sub-questions to process in parallel +python run.py --max-sub-questions 15 + +# Combine mode with other options +python run.py --mode deep --query "Complex topic" --require-approval + +# Combine multiple options +python run.py --config configs/custom_enhanced_config.yaml --query "My research topic" --max-sub-questions 12 +``` + +### Advanced Options + +```bash +# Enable debug logging +python run.py --debug + +# Disable caching for a fresh run +python run.py --no-cache + +# Specify a log file +python run.py --log-file research.log + +# Enable human-in-the-loop approval for additional research +python run.py --require-approval + +# Set approval timeout (in seconds) +python run.py --require-approval --approval-timeout 7200 + +# Use a different search provider (default: tavily) +python run.py --search-provider exa # Use Exa search +python run.py --search-provider both # Use both providers +python run.py --search-provider exa --search-mode neural # Exa with neural search + +# Control the number of search results per query +python run.py --num-results 5 # Get 5 results per search +python run.py --num-results 10 --search-provider exa # 10 results with Exa +``` + +### Search Providers + +The pipeline supports multiple search providers for flexibility and comparison: + +#### Available Providers + +1. **Tavily** (Default) + - Traditional keyword-based search + - Good for factual information and current events + - Requires `TAVILY_API_KEY` environment variable + +2. **Exa** + - Neural search engine with semantic understanding + - Better for conceptual and research-oriented queries + - Supports three search modes: + - `auto` (default): Automatically chooses between neural and keyword + - `neural`: Semantic search for conceptual understanding + - `keyword`: Traditional keyword matching + - Requires `EXA_API_KEY` environment variable + +3. 
**Both** + - Runs searches on both providers + - Useful for comprehensive research or comparing results + - Requires both API keys + +#### Usage Examples + +```bash +# Use Exa with neural search +python run.py --search-provider exa --search-mode neural + +# Compare results from both providers +python run.py --search-provider both + +# Use Exa with keyword search for exact matches +python run.py --search-provider exa --search-mode keyword + +# Combine with other options +python run.py --mode deep --search-provider exa --require-approval +``` + +### Human-in-the-Loop Approval + +The pipeline supports human approval for additional research queries identified during the reflection phase: + +```bash +# Enable approval with default 1-hour timeout +python run.py --require-approval + +# Custom timeout (2 hours) +python run.py --require-approval --approval-timeout 7200 + +# Approval works with any configuration +python run.py --config configs/thorough_research.yaml --require-approval +``` + +When enabled, the pipeline will: +1. Pause after the initial research phase +2. Send an approval request via your configured ZenML alerter (Discord, Slack, etc.) +3. Present research progress, identified gaps, and proposed additional queries +4. Wait for your approval before conducting additional searches +5. Continue with approved queries or finalize the report based on your decision + +**Note**: You need a ZenML stack with an alerter configured (e.g., Discord or Slack) for approval functionality to work. + +**Tip**: When using `--mode deep`, the pipeline will suggest enabling `--require-approval` for better control over the comprehensive research process. + +## 📊 Visualizing Research Process + +The pipeline includes built-in visualizations to help you understand and monitor the research process: + +### Viewing Visualizations + +After running the pipeline, you can view the visualizations in the ZenML dashboard: + +1. Start the ZenML dashboard: + ```bash + zenml up + ``` + +2. Navigate to the "Runs" tab in the dashboard +3. Select your pipeline run +4. 
Explore visualizations for each step: + - **initialize_prompts_step**: View all prompts used in the pipeline + - **initial_query_decomposition_step**: See how the query was broken down + - **process_sub_question_step**: Track progress for each sub-question + - **cross_viewpoint_analysis_step**: View viewpoint analysis results + - **generate_reflection_step**: See reflection and recommendations + - **get_research_approval_step**: View approval decisions + - **pydantic_final_report_step**: Access the final research state + - **collect_tracing_metadata_step**: View comprehensive cost and performance metrics + +### Visualization Features + +The visualizations provide: +- An overview of the report structure +- Details of each paragraph's research status +- Search history and source information +- Progress through reflection iterations +- Professionally formatted HTML reports with static templates + +### Sample Visualization + +Here's what the report structure visualization looks like: + +``` +Report Structure: +├── Introduction +│ └── Initial understanding of the topic +├── Historical Background +│ └── Evolution and key developments +├── Current State +│ └── Latest advancements and implementations +└── Conclusion + └── Summary and future implications +``` + +## 📁 Project Structure + +``` +zenml_deep_research/ +├── configs/ # Configuration files +│ ├── __init__.py +│ └── enhanced_research.yaml # Main configuration file +├── materializers/ # Custom materializers for artifact storage +│ ├── __init__.py +│ └── pydantic_materializer.py +├── pipelines/ # ZenML pipeline definitions +│ ├── __init__.py +│ └── parallel_research_pipeline.py +├── steps/ # ZenML pipeline steps +│ ├── __init__.py +│ ├── approval_step.py # Human approval step for additional research +│ ├── cross_viewpoint_step.py +│ ├── execute_approved_searches_step.py # Execute approved searches +│ ├── generate_reflection_step.py # Generate reflection without execution +│ ├── iterative_reflection_step.py # Legacy combined reflection step +│ ├── merge_results_step.py +│ ├── process_sub_question_step.py +│ ├── pydantic_final_report_step.py +│ └── query_decomposition_step.py +├── utils/ # Utility functions and helpers +│ ├── __init__.py +│ ├── approval_utils.py # Human approval utilities +│ ├── helper_functions.py +│ ├── llm_utils.py # LLM integration utilities +│ ├── prompts.py # Contains prompt templates and HTML templates +│ ├── pydantic_models.py # Data models using Pydantic +│ └── search_utils.py # Web search functionality +├── __init__.py +├── requirements.txt # Project dependencies +├── logging_config.py # Logging configuration +├── README.md # Project documentation +└── run.py # Main script to run the pipeline +``` + +## 🔧 Customization + +The project supports two levels of customization: + +### 1. Command-Line Parameters + +You can customize the research behavior directly through command-line parameters: + +```bash +# Specify your research query +python run.py --query "Your research topic" + +# Control parallelism with max-sub-questions +python run.py --max-sub-questions 15 + +# Combine multiple options +python run.py --query "Your research topic" --max-sub-questions 12 --no-cache +``` + +These settings control how the parallel pipeline processes your research query. + +### 2. 
Pipeline Configuration + +For more detailed settings, modify the configuration file: + +```yaml +# configs/enhanced_research.yaml + +# Enhanced Deep Research Pipeline Configuration +enable_cache: true + +# Research query parameters +query: "Climate change policy debates" + +# Step configurations +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: ["scientific", "political", "economic", "social", "ethical", "historical"] + + iterative_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_additional_searches: 2 + num_results_per_search: 3 + + # Human approval configuration (when using --require-approval) + get_research_approval_step: + parameters: + timeout: 3600 # 1 hour timeout for approval + max_queries: 2 # Maximum queries to present for approval + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 +``` + +To use a custom configuration file: + +```bash +python run.py --config configs/custom_research.yaml +``` + +### Available Configurations + +**Mode-Based Configurations** (automatically selected when using `--mode`): + +| Config File | Mode | Description | +|-------------|------|-------------| +| `rapid_research.yaml` | `--mode rapid` | Quick overview with minimal depth | +| `balanced_research.yaml` | `--mode balanced` | Standard research with moderate depth | +| `deep_research.yaml` | `--mode deep` | Comprehensive analysis with maximum depth | + +**Specialized Configurations:** + +| Config File | Description | Key Parameters | +|-------------|-------------|----------------| +| `enhanced_research.yaml` | Default research configuration | Standard settings, 2 additional searches | +| `thorough_research.yaml` | In-depth analysis | 12 sub-questions, 5 results per search | +| `quick_research.yaml` | Faster results | 5 sub-questions, 2 results per search | +| `daily_trends.yaml` | Research on recent topics | 24-hour search recency, disable cache | +| `compare_viewpoints.yaml` | Focus on comparing perspectives | Extended viewpoint categories | +| `parallel_research.yaml` | Optimized for parallel execution | Configured for distributed orchestrators | + +You can create additional configuration files by copying and modifying the base configuration files above. + +## 🎯 Prompts Tracking and Management + +The pipeline includes a sophisticated prompts tracking system that allows you to track all prompts as versioned artifacts in ZenML. This provides better observability, version control, and visualization of the prompts used in your research pipeline. + +### Overview + +The prompts tracking system enables: +- **Artifact Tracking**: All prompts are tracked as versioned artifacts in ZenML +- **Beautiful Visualizations**: HTML interface in the dashboard with search, copy, and expand features +- **Version Control**: Prompts are versioned alongside your code +- **Pipeline Integration**: Prompts are passed through the pipeline as artifacts, not hardcoded imports + +### Components + +1. 
**PromptsBundle Model** (`utils/prompt_models.py`) + - Pydantic model containing all prompts used in the pipeline + - Each prompt includes metadata: name, content, description, version, and tags + +2. **PromptsBundleMaterializer** (`materializers/prompts_materializer.py`) + - Custom materializer creating HTML visualizations in the ZenML dashboard + - Features: search, copy-to-clipboard, expandable content, tag categorization + +3. **Prompt Loader** (`utils/prompt_loader.py`) + - Utility to load prompts from `prompts.py` into a PromptsBundle + +### Integration Guide + +To integrate prompts tracking into a pipeline: + +1. **Initialize prompts as the first step:** + ```python + from steps.initialize_prompts_step import initialize_prompts_step + + @pipeline + def my_pipeline(): + prompts_bundle = initialize_prompts_step(pipeline_version="1.0.0") + ``` + +2. **Update steps to receive prompts_bundle:** + ```python + @step + def my_step(state: ResearchState, prompts_bundle: PromptsBundle): + prompt = prompts_bundle.get_prompt_content("synthesis_prompt") + # Use prompt in your step logic + ``` + +3. **Pass prompts_bundle through the pipeline:** + ```python + state = synthesis_step(state=state, prompts_bundle=prompts_bundle) + ``` + +### Benefits + +- **Full Tracking**: Every pipeline run tracks which exact prompts were used +- **Version History**: See how prompts evolved across different runs +- **Debugging**: Easily identify which prompts produced specific outputs +- **A/B Testing**: Compare results using different prompt versions + +### Visualization Features + +The HTML visualization in the ZenML dashboard includes: +- Pipeline version and creation timestamp +- Statistics (total prompts, tagged prompts, custom prompts) +- Search functionality across all prompt content +- Expandable/collapsible prompt content +- One-click copy to clipboard +- Tag-based categorization with visual indicators + +## 📊 Cost and Performance Tracking + +The pipeline includes comprehensive tracking of costs and performance metrics through the `collect_tracing_metadata_step`, which runs at the end of each pipeline execution. + +### Tracked Metrics + +- **LLM Costs**: Detailed breakdown by model and prompt type +- **Search Costs**: Tracking for both Tavily and Exa search providers +- **Token Usage**: Input/output tokens per model and step +- **Performance**: Latency and execution time metrics +- **Cost Attribution**: See which steps and prompts consume the most resources + +### Viewing Metrics + +After pipeline execution, the tracing metadata is available in the ZenML dashboard: + +1. Navigate to your pipeline run +2. Find the `collect_tracing_metadata_step` +3. 
View the comprehensive cost visualization including: + - Total pipeline cost (LLM + Search) + - Cost breakdown by model + - Token usage distribution + - Performance metrics + +This helps you: +- Optimize pipeline costs by identifying expensive operations +- Monitor token usage to stay within limits +- Track performance over time +- Make informed decisions about model selection + +## 📈 Example Use Cases + +- **Academic Research**: Rapidly generate preliminary research on academic topics +- **Business Intelligence**: Stay informed on industry trends and competitive landscape +- **Content Creation**: Develop well-researched content for articles, blogs, or reports +- **Decision Support**: Gather comprehensive information for informed decision-making + +## 🔄 Integration Possibilities + +This pipeline can integrate with: + +- **Document Storage**: Save reports to database or document management systems +- **Web Applications**: Power research functionality in web interfaces +- **Alerting Systems**: Schedule research on key topics and receive regular reports +- **Other ZenML Pipelines**: Chain with downstream analysis or processing + +## 📄 License + +This project is licensed under the Apache License 2.0. diff --git a/deep_research/__init__.py b/deep_research/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/deep_research/assets/styles.css b/deep_research/assets/styles.css new file mode 100644 index 00000000..c0e5523a --- /dev/null +++ b/deep_research/assets/styles.css @@ -0,0 +1,692 @@ +/* =================================== + Deep Research Pipeline Global Styles + =================================== */ + +/* 1. CSS Variables / Custom Properties */ +:root { + /* Color Palette - ZenML Design System */ + --color-primary: #7a3ef4; + --color-primary-dark: #6b35db; + --color-primary-light: #9d6ff7; + --color-secondary: #667eea; + --color-secondary-dark: #5a63d8; + --color-accent: #764ba2; + + /* Status Colors - ZenML Semantic Colors */ + --color-success: #179f3e; + --color-success-light: #d4edda; + --color-success-dark: #155724; + --color-warning: #a65d07; + --color-warning-light: #fff3cd; + --color-warning-dark: #856404; + --color-danger: #dc3545; + --color-danger-light: #f8d7da; + --color-danger-dark: #721c24; + --color-info: #007bff; + --color-info-light: #d1ecf1; + --color-info-dark: #004085; + + /* Chart Colors - ZenML Palette */ + --color-chart-1: #7a3ef4; + --color-chart-2: #179f3e; + --color-chart-3: #007bff; + --color-chart-4: #dc3545; + --color-chart-5: #a65d07; + --color-chart-6: #6c757d; + + /* Neutrals */ + --color-text-primary: #333; + --color-text-secondary: #666; + --color-text-muted: #999; + --color-text-light: #7f8c8d; + --color-heading: #2c3e50; + --color-bg-primary: #f5f7fa; + --color-bg-secondary: #f8f9fa; + --color-bg-light: #f0f2f5; + --color-bg-white: #ffffff; + --color-border: #e9ecef; + --color-border-light: #dee2e6; + --color-border-dark: #ddd; + + /* Typography - ZenML Font Stack */ + --font-family-base: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif; + --font-family-mono: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'source-code-pro', monospace; + + /* Spacing - ZenML 8px Grid System */ + --spacing-xs: 4px; + --spacing-sm: 8px; + --spacing-md: 16px; + --spacing-lg: 24px; + --spacing-xl: 32px; + --spacing-xxl: 48px; + + /* Border Radius - ZenML Subtle Corners */ + --radius-sm: 4px; + --radius-md: 6px; + --radius-lg: 8px; + --radius-xl: 12px; + --radius-round: 50%; + + 
/* Shadows - ZenML Subtle Shadows */ + --shadow-sm: 0 1px 2px rgba(0, 0, 0, 0.05); + --shadow-md: 0 4px 12px rgba(0, 0, 0, 0.1); + --shadow-lg: 0 8px 24px rgba(0, 0, 0, 0.12); + --shadow-xl: 0 12px 48px rgba(0, 0, 0, 0.15); + --shadow-hover: 0 6px 16px rgba(0, 0, 0, 0.1); + --shadow-hover-lg: 0 8px 24px rgba(0, 0, 0, 0.15); + + /* Transitions */ + --transition-base: all 0.3s ease; + --transition-fast: all 0.2s ease; +} + +/* 2. Base Styles */ +* { + box-sizing: border-box; +} + +body { + font-family: var(--font-family-base); + font-size: 14px; + line-height: 1.6; + color: var(--color-text-primary); + background-color: var(--color-bg-primary); + margin: 0; + padding: var(--spacing-md); + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; +} + +/* 3. Layout Components */ +.dr-container { + max-width: 1200px; + margin: 0 auto; + padding: var(--spacing-md); +} + +.dr-container--wide { + max-width: 1400px; +} + +.dr-container--narrow { + max-width: 900px; +} + +/* 4. Typography */ +.dr-h1, h1 { + color: var(--color-heading); + font-size: 2em; + font-weight: 500; + margin: 0 0 var(--spacing-lg) 0; + padding-bottom: var(--spacing-sm); + border-bottom: 2px solid var(--color-primary); +} + +.dr-h1--no-border { + border-bottom: none; + padding-bottom: 0; +} + +.dr-h2, h2 { + color: var(--color-heading); + font-size: 1.4em; + font-weight: 500; + margin-top: var(--spacing-lg); + margin-bottom: var(--spacing-md); + border-bottom: 1px solid var(--color-border); + padding-bottom: var(--spacing-xs); +} + +.dr-h3, h3 { + color: var(--color-primary); + font-size: 1.2em; + font-weight: 500; + margin-top: var(--spacing-md); + margin-bottom: var(--spacing-sm); +} + +p { + margin: var(--spacing-md) 0; + line-height: 1.6; + color: var(--color-text-secondary); +} + +/* 5. Card Components */ +.dr-card { + background: var(--color-bg-white); + border-radius: var(--radius-md); + padding: var(--spacing-lg); + box-shadow: var(--shadow-md); + margin-bottom: var(--spacing-lg); + transition: var(--transition-base); +} + +.dr-card:hover { + transform: translateY(-2px); + box-shadow: var(--shadow-hover); +} + +.dr-card--bordered { + border: 1px solid var(--color-border-light); +} + +.dr-card--no-hover:hover { + transform: none; + box-shadow: var(--shadow-md); +} + +/* Header Cards */ +.dr-header-card { + background: white; + border-radius: var(--radius-md); + padding: var(--spacing-lg); + box-shadow: var(--shadow-sm); + margin-bottom: var(--spacing-lg); + border: 1px solid var(--color-border-light); +} + +/* 6. Grid System */ +.dr-grid { + display: grid; + gap: var(--spacing-md); +} + +.dr-grid--stats { + grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); +} + +.dr-grid--cards { + grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); +} + +.dr-grid--metrics { + grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); +} + +/* 7. 
Badges & Tags */ +.dr-badge { + display: inline-block; + padding: 4px 12px; + border-radius: 12px; + font-size: 12px; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.3px; + line-height: 1.5; +} + +.dr-badge--success { + background-color: var(--color-success-light); + color: var(--color-success-dark); +} + +.dr-badge--warning { + background-color: var(--color-warning-light); + color: var(--color-warning-dark); +} + +.dr-badge--danger { + background-color: var(--color-danger-light); + color: var(--color-danger-dark); +} + +.dr-badge--info { + background-color: var(--color-info-light); + color: var(--color-info-dark); +} + +.dr-badge--primary { + background-color: var(--color-primary); + color: white; +} + +/* Tag variations */ +.dr-tag { + display: inline-block; + background-color: #f0f0f0; + color: #555; + padding: 4px 10px; + border-radius: var(--radius-sm); + font-size: 12px; + font-weight: 500; + margin: 2px; +} + +.dr-tag--primary { + background-color: #e1f5fe; + color: #0277bd; +} + +/* 8. Stat Cards */ +.dr-stat-card { + background: var(--color-bg-white); + border-radius: var(--radius-md); + padding: var(--spacing-lg); + text-align: center; + transition: var(--transition-base); + border: 1px solid var(--color-border); + box-shadow: var(--shadow-sm); +} + +.dr-stat-card:hover { + transform: translateY(-2px); + box-shadow: var(--shadow-hover); +} + +.dr-stat-value { + font-size: 2rem; + font-weight: 600; + color: var(--color-primary); + margin-bottom: var(--spacing-xs); + display: block; +} + +.dr-stat-label { + color: var(--color-text-secondary); + font-size: 13px; + text-transform: uppercase; + letter-spacing: 0.3px; + display: block; + font-weight: 500; +} + +/* 9. Sections */ +.dr-section { + background: var(--color-bg-white); + border-radius: var(--radius-md); + padding: var(--spacing-lg); + margin-bottom: var(--spacing-lg); + box-shadow: var(--shadow-sm); + border: 1px solid var(--color-border-light); +} + +.dr-section--bordered { + border-left: 3px solid var(--color-primary); +} + +.dr-section--info { + background-color: #e8f4f8; + border-left: 4px solid var(--color-primary); +} + +.dr-section--warning { + background-color: var(--color-warning-light); + border-left: 4px solid var(--color-warning); +} + +.dr-section--success { + background-color: var(--color-success-light); + border-left: 4px solid var(--color-success); +} + +.dr-section--danger { + background-color: var(--color-danger-light); + border-left: 4px solid var(--color-danger); +} + +/* 10. Tables */ +.dr-table { + width: 100%; + border-collapse: collapse; + margin: var(--spacing-md) 0; + background: var(--color-bg-white); + overflow: hidden; +} + +.dr-table th { + background-color: var(--color-primary); + color: white; + padding: var(--spacing-sm); + text-align: left; + font-weight: 600; +} + +.dr-table td { + padding: var(--spacing-sm); + border-bottom: 1px solid var(--color-border); +} + +.dr-table tr:last-child td { + border-bottom: none; +} + +.dr-table tr:hover { + background-color: var(--color-bg-secondary); +} + +.dr-table--striped tr:nth-child(even) { + background-color: #f2f2f2; +} + +/* 11. 
Buttons */ +.dr-button { + background: var(--color-primary); + color: white; + border: none; + padding: 10px 20px; + border-radius: var(--radius-md); + font-size: 14px; + font-weight: 500; + cursor: pointer; + transition: var(--transition-base); + display: inline-flex; + align-items: center; + gap: var(--spacing-xs); + text-decoration: none; + position: relative; + overflow: hidden; +} + +.dr-button:hover { + background: var(--color-primary-dark); + transform: translateY(-1px); + box-shadow: var(--shadow-hover); +} + +.dr-button:active { + transform: translateY(0); + box-shadow: var(--shadow-sm); +} + +.dr-button--secondary { + background: var(--color-secondary); +} + +.dr-button--secondary:hover { + background: var(--color-secondary-dark); + box-shadow: var(--shadow-hover); +} + +.dr-button--success { + background: var(--color-success); +} + +.dr-button--small { + padding: 6px 12px; + font-size: 12px; +} + +/* 12. Confidence Indicators */ +.dr-confidence { + display: inline-flex; + align-items: center; + padding: 6px 16px; + border-radius: 20px; + font-weight: 600; + font-size: 13px; + gap: var(--spacing-xs); + box-shadow: var(--shadow-sm); +} + +.dr-confidence--high { + background: linear-gradient(to right, #d4edda, #c3e6cb); + color: var(--color-success-dark); +} + +.dr-confidence--medium { + background: linear-gradient(to right, #fff3cd, #ffeeba); + color: var(--color-warning-dark); +} + +.dr-confidence--low { + background: linear-gradient(to right, #f8d7da, #f5c6cb); + color: var(--color-danger-dark); +} + +/* 13. Chart Containers */ +.dr-chart-container { + position: relative; + height: 300px; + margin: var(--spacing-md) 0; +} + +/* 14. Code Blocks */ +.dr-code { + background-color: #f7f7f7; + border: 1px solid #e1e1e8; + border-radius: var(--radius-sm); + padding: var(--spacing-sm); + font-family: var(--font-family-mono); + overflow-x: auto; + white-space: pre-wrap; + word-wrap: break-word; +} + +/* 15. Lists */ +.dr-list { + margin: var(--spacing-sm) 0; + padding-left: 25px; +} + +.dr-list li { + margin: 8px 0; + line-height: 1.6; +} + +.dr-list--unstyled { + list-style-type: none; + padding-left: 0; +} + +/* 16. Notice Boxes */ +.dr-notice { + padding: 15px; + margin: 20px 0; + border-radius: var(--radius-sm); +} + +.dr-notice--info { + background-color: #e8f4f8; + border-left: 4px solid var(--color-primary); + color: var(--color-info-dark); +} + +.dr-notice--warning { + background-color: var(--color-warning-light); + border-left: 4px solid var(--color-warning); + color: var(--color-warning-dark); +} + +/* 17. Loading States */ +.dr-loading { + text-align: center; + padding: var(--spacing-xxl); + color: var(--color-text-secondary); + font-style: italic; +} + +/* 18. Empty States */ +.dr-empty { + text-align: center; + color: var(--color-text-muted); + font-style: italic; + padding: var(--spacing-xl); + background: var(--color-bg-white); + border-radius: var(--radius-lg); + box-shadow: var(--shadow-md); +} + +/* 19. 
Utility Classes */ +.dr-text-center { text-align: center; } +.dr-text-right { text-align: right; } +.dr-text-left { text-align: left; } +.dr-text-muted { color: var(--color-text-muted); } +.dr-text-secondary { color: var(--color-text-secondary); } +.dr-text-primary { color: var(--color-text-primary); } + +/* Margin utilities */ +.dr-mt-xs { margin-top: var(--spacing-xs); } +.dr-mt-sm { margin-top: var(--spacing-sm); } +.dr-mt-md { margin-top: var(--spacing-md); } +.dr-mt-lg { margin-top: var(--spacing-lg); } +.dr-mt-xl { margin-top: var(--spacing-xl); } + +.dr-mb-xs { margin-bottom: var(--spacing-xs); } +.dr-mb-sm { margin-bottom: var(--spacing-sm); } +.dr-mb-md { margin-bottom: var(--spacing-md); } +.dr-mb-lg { margin-bottom: var(--spacing-lg); } +.dr-mb-xl { margin-bottom: var(--spacing-xl); } + +/* Padding utilities */ +.dr-p-sm { padding: var(--spacing-sm); } +.dr-p-md { padding: var(--spacing-md); } +.dr-p-lg { padding: var(--spacing-lg); } + +/* Display utilities */ +.dr-d-none { display: none; } +.dr-d-block { display: block; } +.dr-d-flex { display: flex; } +.dr-d-grid { display: grid; } + +/* Flex utilities */ +.dr-flex-center { + display: flex; + align-items: center; + justify-content: center; +} + +.dr-flex-between { + display: flex; + align-items: center; + justify-content: space-between; +} + +/* 20. Special Components */ + +/* Mind Map Styles */ +.dr-mind-map { + position: relative; + margin: var(--spacing-xl) 0; +} + +.dr-mind-map-node { + background: linear-gradient(135deg, var(--color-primary) 0%, var(--color-primary-dark) 100%); + color: white; + padding: var(--spacing-lg); + border-radius: var(--radius-md); + text-align: center; + font-size: 1.25rem; + font-weight: 600; + box-shadow: var(--shadow-md); + margin-bottom: var(--spacing-xl); +} + +/* Result Cards */ +.dr-result-item { + background: var(--color-bg-secondary); + border-radius: var(--radius-md); + padding: 15px; + margin-bottom: 15px; + border: 1px solid var(--color-border); + transition: var(--transition-base); +} + +.dr-result-item:hover { + box-shadow: var(--shadow-hover); + transform: translateY(-1px); +} + +.dr-result-title { + font-weight: 600; + color: var(--color-heading); + margin-bottom: var(--spacing-xs); +} + +.dr-result-snippet { + color: var(--color-text-secondary); + font-size: 13px; + line-height: 1.6; + margin-bottom: var(--spacing-sm); +} + +.dr-result-link { + color: var(--color-primary); + text-decoration: none; + font-size: 13px; + font-weight: 500; +} + +.dr-result-link:hover { + text-decoration: underline; +} + +/* Timestamp */ +.dr-timestamp { + text-align: right; + color: var(--color-text-light); + font-size: 12px; + margin-top: var(--spacing-md); + padding-top: var(--spacing-md); + border-top: 1px dashed var(--color-border-dark); +} + +/* 21. Gradients */ +.dr-gradient-primary { + background: linear-gradient(135deg, var(--color-primary) 0%, var(--color-primary-dark) 100%); +} + +.dr-gradient-header { + background: linear-gradient(90deg, var(--color-primary), var(--color-success), var(--color-warning), var(--color-danger)); + height: 5px; +} + +/* 22. 
Responsive Design */ +@media (max-width: 768px) { + body { + padding: var(--spacing-sm); + } + + .dr-container { + padding: var(--spacing-sm); + } + + .dr-grid--stats, + .dr-grid--cards, + .dr-grid--metrics { + grid-template-columns: 1fr; + } + + .dr-h1, h1 { + font-size: 1.5em; + } + + .dr-h2, h2 { + font-size: 1.25em; + } + + .dr-stat-value { + font-size: 1.75rem; + } + + .dr-section, + .dr-card { + padding: var(--spacing-md); + } + + .dr-table { + font-size: 13px; + } + + .dr-table th, + .dr-table td { + padding: 8px; + } +} + +/* 23. Print Styles */ +@media print { + body { + background: white; + color: black; + } + + .dr-card, + .dr-section { + box-shadow: none; + border: 1px solid #ddd; + } + + .dr-button { + display: none; + } +} \ No newline at end of file diff --git a/deep_research/configs/balanced_research.yaml b/deep_research/configs/balanced_research.yaml new file mode 100644 index 00000000..4f8bfa23 --- /dev/null +++ b/deep_research/configs/balanced_research.yaml @@ -0,0 +1,79 @@ +# Deep Research Pipeline Configuration - Balanced Mode +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "balanced", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for balanced research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 10 # Balanced number of sub-questions + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + cap_search_length: 20000 # Standard cap for search length + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + ] # Standard viewpoints + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + timeout: 3600 # 1 hour timeout + max_queries: 2 # Moderate additional queries + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + cap_search_length: 20000 + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/deep_research.yaml b/deep_research/configs/deep_research.yaml new file mode 100644 index 00000000..61cc4c2b --- /dev/null +++ b/deep_research/configs/deep_research.yaml @@ -0,0 +1,81 @@ +# Deep Research Pipeline Configuration - Deep Comprehensive Mode +enable_cache: false # Disable cache for fresh comprehensive analysis + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "deep", + ] + use_cases: "Research on a given query." 
+ +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for deep comprehensive research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 15 # Maximum sub-questions for comprehensive analysis + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + cap_search_length: 30000 # Higher cap for more comprehensive data + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + "technological", + "philosophical", + ] # Extended viewpoints for comprehensive analysis + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + timeout: 7200 # 2 hour timeout for deep research + max_queries: 4 # Maximum additional queries for deep mode + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + cap_search_length: 30000 + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/enhanced_research.yaml b/deep_research/configs/enhanced_research.yaml new file mode 100644 index 00000000..0bfc0a79 --- /dev/null +++ b/deep_research/configs/enhanced_research.yaml @@ -0,0 +1,71 @@ +# Enhanced Deep Research Pipeline Configuration +enable_cache: false + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "enhanced", + ] + use_cases: "Research on a given query." 
+ +# Research query parameters +query: "Climate change policy debates" + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Step configurations +steps: + initial_query_decomposition_step: + parameters: + llm_model: "openrouter/google/gemini-2.0-flash-lite-001" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "openrouter/google/gemini-2.0-flash-lite-001" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + ] + + generate_reflection_step: + parameters: + llm_model: "openrouter/google/gemini-2.0-flash-lite-001" + + get_research_approval_step: + parameters: + timeout: 3600 + max_queries: 2 + + execute_approved_searches_step: + parameters: + llm_model: "openrouter/google/gemini-2.0-flash-lite-001" + + pydantic_final_report_step: + parameters: + llm_model: "openrouter/google/gemini-2.0-flash-lite-001" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/configs/enhanced_research_with_approval.yaml b/deep_research/configs/enhanced_research_with_approval.yaml new file mode 100644 index 00000000..73d6fe42 --- /dev/null +++ b/deep_research/configs/enhanced_research_with_approval.yaml @@ -0,0 +1,77 @@ +# Enhanced Deep Research Pipeline Configuration with Human Approval +enable_cache: false + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "enhanced_approval", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research query parameters +query: "Climate change policy debates" + +# Pipeline parameters +parameters: + require_approval: true # Enable human-in-the-loop approval + approval_timeout: 1800 # 30 minutes timeout for approval + max_additional_searches: 3 # Allow up to 3 additional searches + +# Step configurations +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + ] + + # New reflection steps (replacing iterative_reflection_step) + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + alerter_type: "slack" # or "email" if configured + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/quick_research.yaml b/deep_research/configs/quick_research.yaml new file mode 100644 index 00000000..b210f18f --- /dev/null +++ b/deep_research/configs/quick_research.yaml @@ -0,0 +1,59 @@ +# Deep Research Pipeline Configuration - Quick Research +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: 
"Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "quick", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for quick research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 5 # Limit to fewer sub-questions for quick research + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + auto_approve: true # Auto-approve for quick research + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/configs/rapid_research.yaml b/deep_research/configs/rapid_research.yaml new file mode 100644 index 00000000..e69982bf --- /dev/null +++ b/deep_research/configs/rapid_research.yaml @@ -0,0 +1,59 @@ +# Deep Research Pipeline Configuration - Quick Research +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "rapid", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for quick research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 5 # Limit to fewer sub-questions for quick research + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + auto_approve: true # Auto-approve for quick research + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/logging_config.py b/deep_research/logging_config.py new file mode 100644 index 00000000..2b93c3e0 --- /dev/null +++ b/deep_research/logging_config.py @@ -0,0 +1,42 @@ +import logging +import sys +from typing import Optional + + +def configure_logging( + level: int = logging.INFO, log_file: Optional[str] = None +): + """Configure logging for the application. 
+ + Args: + level: The log level (default: INFO) + log_file: Optional path to a log file + """ + # Create formatter + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + # Configure root logger + root_logger = logging.getLogger() + root_logger.setLevel(level) + + # Remove existing handlers to avoid duplicate logs + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + root_logger.addHandler(console_handler) + + # File handler if log_file is provided + if log_file: + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + root_logger.addHandler(file_handler) + + # Reduce verbosity for noisy third-party libraries + logging.getLogger("LiteLLM").setLevel(logging.WARNING) + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("urllib3").setLevel(logging.WARNING) diff --git a/deep_research/materializers/__init__.py b/deep_research/materializers/__init__.py new file mode 100644 index 00000000..7009d557 --- /dev/null +++ b/deep_research/materializers/__init__.py @@ -0,0 +1,26 @@ +""" +Materializers package for the ZenML Deep Research project. + +This package contains custom ZenML materializers that handle serialization and +deserialization of complex data types used in the research pipeline. +""" + +from .analysis_data_materializer import AnalysisDataMaterializer +from .approval_decision_materializer import ApprovalDecisionMaterializer +from .final_report_materializer import FinalReportMaterializer +from .prompt_materializer import PromptMaterializer +from .query_context_materializer import QueryContextMaterializer +from .search_data_materializer import SearchDataMaterializer +from .synthesis_data_materializer import SynthesisDataMaterializer +from .tracing_metadata_materializer import TracingMetadataMaterializer + +__all__ = [ + "ApprovalDecisionMaterializer", + "PromptMaterializer", + "TracingMetadataMaterializer", + "QueryContextMaterializer", + "SearchDataMaterializer", + "SynthesisDataMaterializer", + "AnalysisDataMaterializer", + "FinalReportMaterializer", +] diff --git a/deep_research/materializers/analysis_data_materializer.py b/deep_research/materializers/analysis_data_materializer.py new file mode 100644 index 00000000..8b78aa91 --- /dev/null +++ b/deep_research/materializers/analysis_data_materializer.py @@ -0,0 +1,269 @@ +"""Materializer for AnalysisData with viewpoint tension diagrams and reflection insights.""" + +import os +from typing import Dict + +from utils.css_utils import ( + get_card_class, + get_section_class, + get_shared_css_tag, +) +from utils.pydantic_models import AnalysisData +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class AnalysisDataMaterializer(PydanticMaterializer): + """Materializer for AnalysisData with viewpoint and reflection visualization.""" + + ASSOCIATED_TYPES = (AnalysisData,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: AnalysisData + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the AnalysisData. 
+ + Args: + data: The AnalysisData to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + visualization_path = os.path.join(self.uri, "analysis_data.html") + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, data: AnalysisData) -> str: + """Generate HTML visualization for the analysis data. + + Args: + data: The AnalysisData to visualize + + Returns: + HTML string + """ + # Viewpoint analysis section + viewpoint_html = "" + if data.viewpoint_analysis: + va = data.viewpoint_analysis + + # Points of agreement + agreement_html = "" + if va.main_points_of_agreement: + agreement_html = "

Main Points of Agreement

" + + # Areas of tension + tensions_html = "" + if va.areas_of_tension: + tensions_html = ( + "

Areas of Tension

" + ) + for tension in va.areas_of_tension: + viewpoints_html = "" + for perspective, view in tension.viewpoints.items(): + viewpoints_html += f""" +
+
{perspective}
+
{view}
+
+ """ + + tensions_html += f""" +
+

{tension.topic}

+
+ {viewpoints_html} +
+
+ """ + tensions_html += "
" + + # Perspective gaps + gaps_html = "" + if va.perspective_gaps: + gaps_html = f""" +
+

Perspective Gaps

+

{va.perspective_gaps}

+
+ """ + + # Integrative insights + insights_html = "" + if va.integrative_insights: + insights_html = f""" +
+

Integrative Insights

+

{va.integrative_insights}

+
+ """ + + viewpoint_html = f""" +
+

Viewpoint Analysis

+ {agreement_html} + {tensions_html} + {gaps_html} + {insights_html} +
+ """ + + # Reflection metadata section + reflection_html = "" + if data.reflection_metadata: + rm = data.reflection_metadata + + # Critique summary + critique_html = "" + if rm.critique_summary: + critique_html = "

Critique Summary

" + + # Additional questions + questions_html = "" + if rm.additional_questions_identified: + questions_html = "

Additional Questions Identified

" + + # Searches performed + searches_html = "" + if rm.searches_performed: + searches_html = "

Searches Performed

" + + # Error handling + error_html = "" + if rm.error: + error_html = f""" +
+

Error Encountered

+

{rm.error}

+
+ """ + + reflection_html = f""" +
+

Reflection Metadata

+
+ {int(rm.improvements_made)} + Improvements Made +
+ {critique_html} + {questions_html} + {searches_html} + {error_html} +
+ """ + + # Handle empty state + if not viewpoint_html and not reflection_html: + content_html = ( + '
No analysis data available yet
' + ) + else: + content_html = viewpoint_html + reflection_html + + html = f""" + + + + Analysis Data Visualization + {get_shared_css_tag()} + + + +
+
+

Research Analysis

+
+ + {content_html} +
+ + + """ + + return html diff --git a/deep_research/materializers/approval_decision_materializer.py b/deep_research/materializers/approval_decision_materializer.py new file mode 100644 index 00000000..e17b9d9b --- /dev/null +++ b/deep_research/materializers/approval_decision_materializer.py @@ -0,0 +1,241 @@ +"""Materializer for ApprovalDecision with custom visualization.""" + +import os +from datetime import datetime +from typing import Dict + +from utils.css_utils import ( + get_card_class, + get_grid_class, + get_section_class, + get_shared_css_tag, +) +from utils.pydantic_models import ApprovalDecision +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class ApprovalDecisionMaterializer(PydanticMaterializer): + """Materializer for the ApprovalDecision class with visualizations.""" + + ASSOCIATED_TYPES = (ApprovalDecision,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: ApprovalDecision + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the ApprovalDecision. + + Args: + data: The ApprovalDecision to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "approval_decision.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, decision: ApprovalDecision) -> str: + """Generate HTML visualization for the approval decision. + + Args: + decision: The ApprovalDecision to visualize + + Returns: + HTML string + """ + # Format timestamp + decision_time = datetime.fromtimestamp(decision.timestamp).strftime( + "%Y-%m-%d %H:%M:%S" + ) + + # Determine status icon and text + if decision.approved: + status_icon = "✅" + status_text = "APPROVED" + else: + status_icon = "❌" + status_text = "NOT APPROVED" + + # Format approval method + method_display = { + "APPROVE_ALL": "Approve All Queries", + "SKIP": "Skip Additional Research", + "SELECT_SPECIFIC": "Select Specific Queries", + }.get(decision.approval_method, decision.approval_method or "Unknown") + + # Build info cards + info_cards_html = f""" +
+
+
Approval Method
+
{method_display}
+
+
+
Decision Time
+
{decision_time}
+
+
+
Queries Selected
+
{len(decision.selected_queries)}
+
+
+ """ + + html = f""" + + + + Approval Decision + {get_shared_css_tag()} + + + +
+
+

+ 🔒 Approval Decision +
+ {status_icon} + {status_text} +
+

+ + {info_cards_html} + """ + + # Add selected queries section if any + if decision.selected_queries: + html += f""" +
+

📋 Selected Queries

+
+ """ + + for i, query in enumerate(decision.selected_queries, 1): + html += f""" +
+
{i}
+
{query}
+
+ """ + + html += """ +
+
+ """ + else: + html += f""" +
+

📋 Selected Queries

+
+ No queries were selected for additional research +
+
+ """ + + # Add reviewer notes if any + if decision.reviewer_notes: + html += f""" +
+

📝 Reviewer Notes

+
+ {decision.reviewer_notes} +
+
+ """ + + # Add timestamp footer + html += f""" +
+ Decision recorded at: {decision_time} +
+
+
+ + + """ + + return html diff --git a/deep_research/materializers/final_report_materializer.py b/deep_research/materializers/final_report_materializer.py new file mode 100644 index 00000000..960e4afc --- /dev/null +++ b/deep_research/materializers/final_report_materializer.py @@ -0,0 +1,235 @@ +"""Materializer for FinalReport with enhanced interactive report visualization.""" + +import os +from datetime import datetime +from typing import Dict + +from utils.css_utils import get_shared_css_tag +from utils.pydantic_models import FinalReport +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class FinalReportMaterializer(PydanticMaterializer): + """Materializer for FinalReport with interactive report visualization.""" + + ASSOCIATED_TYPES = (FinalReport,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: FinalReport + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the FinalReport. + + Args: + data: The FinalReport to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Save the actual report + report_path = os.path.join(self.uri, "final_report.html") + with fileio.open(report_path, "w") as f: + f.write(data.report_html) + + # Save a wrapper visualization with metadata + visualization_path = os.path.join( + self.uri, "report_visualization.html" + ) + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return { + report_path: VisualizationType.HTML, + visualization_path: VisualizationType.HTML, + } + + def _generate_visualization_html(self, data: FinalReport) -> str: + """Generate HTML wrapper visualization for the final report. + + Args: + data: The FinalReport to visualize + + Returns: + HTML string + """ + # Format timestamp + timestamp = datetime.fromtimestamp(data.generated_at).strftime( + "%B %d, %Y at %I:%M %p UTC" + ) + + # Extract some statistics from the HTML report if possible + report_length = len(data.report_html) + + html = f""" + + + + Final Research Report - {data.main_query[:50]}... + {get_shared_css_tag()} + + + +
+
+

Final Research Report

+ +
+
Research Query
+
{data.main_query}
+
+ + +
+
+ +
+
+ + Open in New Tab + + +
+ +
+ Loading report... +
+ + +
+ + + + + """ + + return html diff --git a/deep_research/materializers/prompt_materializer.py b/deep_research/materializers/prompt_materializer.py new file mode 100644 index 00000000..4e306419 --- /dev/null +++ b/deep_research/materializers/prompt_materializer.py @@ -0,0 +1,252 @@ +"""Materializer for individual Prompt with custom HTML visualization. + +This module provides a materializer that creates beautiful HTML visualizations +for individual prompts in the ZenML dashboard. +""" + +import os +from typing import Dict + +from utils.css_utils import ( + create_stat_card, + get_card_class, + get_grid_class, + get_shared_css_tag, +) +from utils.pydantic_models import Prompt +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class PromptMaterializer(PydanticMaterializer): + """Materializer for Prompt with custom visualization.""" + + ASSOCIATED_TYPES = (Prompt,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: Prompt + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the Prompt. + + Args: + data: The Prompt to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "prompt.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, prompt: Prompt) -> str: + """Generate HTML visualization for a single prompt. + + Args: + prompt: The Prompt to visualize + + Returns: + HTML string + """ + # Create tags HTML + tag_html = "" + if prompt.tags: + tag_html = '
' + for tag in prompt.tags: + tag_html += ( + f'{tag}' + ) + tag_html += "
" + + # Build stats HTML + stats_html = f""" +
+ {create_stat_card(len(prompt.content.split()), "Words")} + {create_stat_card(len(prompt.content), "Characters")} + {create_stat_card(len(prompt.content.splitlines()), "Lines")} +
+ """ + + # Create HTML content + html = f""" + + + + {prompt.name} - Prompt + {get_shared_css_tag()} + + + +
+
+

+ 🎯 {prompt.name} + v{prompt.version} +

+ {f'

{prompt.description}

' if prompt.description else ""} + {tag_html} +
+ + {stats_html} + +
+

📝 Prompt Content

+
+ + {self._escape_html(prompt.content)} +
+
+
+ + + + + """ + + return html + + def _escape_html(self, text: str) -> str: + """Escape HTML special characters. + + Args: + text: Text to escape + + Returns: + Escaped text + """ + return ( + text.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'") + ) diff --git a/deep_research/materializers/query_context_materializer.py b/deep_research/materializers/query_context_materializer.py new file mode 100644 index 00000000..78bce3ee --- /dev/null +++ b/deep_research/materializers/query_context_materializer.py @@ -0,0 +1,233 @@ +"""Materializer for QueryContext with interactive mind map visualization.""" + +import os +from typing import Dict + +from utils.css_utils import ( + create_stat_card, + get_grid_class, + get_shared_css_tag, +) +from utils.pydantic_models import QueryContext +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class QueryContextMaterializer(PydanticMaterializer): + """Materializer for QueryContext with mind map visualization.""" + + ASSOCIATED_TYPES = (QueryContext,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: QueryContext + ) -> Dict[str, VisualizationType]: + """Create and save mind map visualization for the QueryContext. + + Args: + data: The QueryContext to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + visualization_path = os.path.join(self.uri, "query_context.html") + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, context: QueryContext) -> str: + """Generate HTML mind map visualization for the query context. + + Args: + context: The QueryContext to visualize + + Returns: + HTML string + """ + # Create sub-questions HTML + sub_questions_html = "" + if context.sub_questions: + for i, sub_q in enumerate(context.sub_questions, 1): + sub_questions_html += f""" +
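The entity strings in `_escape_html` above appear to have been decoded when this diff was rendered, so the replacement chain reads as if each character maps to itself. A minimal sketch of the intended escaping is shown below (entity choices are assumptions; the standard-library `html.escape` would work equally well):

```python
def _escape_html(self, text: str) -> str:
    """Escape HTML special characters so prompt content renders literally."""
    return (
        text.replace("&", "&amp;")   # replace '&' first so later entities are not double-escaped
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#x27;")
    )
```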
+
{i}
+
{sub_q}
+
+ """ + else: + sub_questions_html = ( + '
No sub-questions decomposed yet
' + ) + + # Format timestamp + from datetime import datetime + + timestamp = datetime.fromtimestamp( + context.decomposition_timestamp + ).strftime("%Y-%m-%d %H:%M:%S UTC") + + # Build stats + stats_html = f""" +
+ {create_stat_card(len(context.sub_questions), "Sub-Questions")} + {create_stat_card(len(context.main_query.split()), "Words in Query")} + {create_stat_card(sum(len(q.split()) for q in context.sub_questions), "Total Sub-Question Words")} +
+ """ + + html = f""" + + + + Query Context - {context.main_query[:50]}... + {get_shared_css_tag()} + + + +
+
+

Query Decomposition Mind Map

+
Created: {timestamp}
+
+ +
+
+ {context.main_query} +
+ +
+ {sub_questions_html} +
+
+ + {stats_html} +
+ + + """ + + return html diff --git a/deep_research/materializers/search_data_materializer.py b/deep_research/materializers/search_data_materializer.py new file mode 100644 index 00000000..d2b8f989 --- /dev/null +++ b/deep_research/materializers/search_data_materializer.py @@ -0,0 +1,213 @@ +"""Materializer for SearchData with cost breakdown charts and search results visualization.""" + +import json +import os +from typing import Dict + +from utils.css_utils import ( + create_stat_card, + get_card_class, + get_grid_class, + get_shared_css_tag, + get_table_class, +) +from utils.pydantic_models import SearchData +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class SearchDataMaterializer(PydanticMaterializer): + """Materializer for SearchData with interactive visualizations.""" + + ASSOCIATED_TYPES = (SearchData,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: SearchData + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the SearchData. + + Args: + data: The SearchData to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + visualization_path = os.path.join(self.uri, "search_data.html") + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, data: SearchData) -> str: + """Generate HTML visualization for the search data. + + Args: + data: The SearchData to visualize + + Returns: + HTML string + """ + # Prepare data for charts + cost_data = [ + {"provider": k, "cost": v} for k, v in data.search_costs.items() + ] + + # Create search results HTML + results_html = "" + for sub_q, results in data.search_results.items(): + results_html += f""" +
+

{sub_q}

+
{len(results)} results found
+
+ """ + + for i, result in enumerate(results[:5]): # Show first 5 results + results_html += f""" +
+
{result.title or "Untitled"}
+
{result.snippet or result.content[:200]}...
+ View Source +
+ """ + + if len(results) > 5: + results_html += f'
... and {len(results) - 5} more results
' + + results_html += """ +
+
+ """ + + if not results_html: + results_html = '
No search results yet
' + + # Calculate total cost + total_cost = sum(data.search_costs.values()) + + # Build stats cards + stats_html = f""" +
+ {create_stat_card(data.total_searches, "Total Searches")} + {create_stat_card(len(data.search_results), "Sub-Questions")} + {create_stat_card(sum(len(results) for results in data.search_results.values()), "Total Results")} + {create_stat_card(f"${total_cost:.4f}", "Total Cost")} +
+ """ + + # Build cost table rows + cost_table_rows = "".join( + f""" + + {provider} + ${cost:.4f} + {(cost / total_cost * 100 if total_cost > 0 else 0):.1f}% + + """ + for provider, cost in data.search_costs.items() + ) + + html = f""" + + + + Search Data Visualization + + {get_shared_css_tag()} + + +
+
+

Search Data Analysis

+ {stats_html} +
+ +
+

Cost Analysis

+ +
+ +
+ +
+

Cost Breakdown by Provider

+ + + + + + + + + + {cost_table_rows} + +
ProviderCostPercentage
+
+
+ +
+

Search Results

+ {results_html} +
+
+ + + + + """ + + return html diff --git a/deep_research/materializers/synthesis_data_materializer.py b/deep_research/materializers/synthesis_data_materializer.py new file mode 100644 index 00000000..fb6f63f3 --- /dev/null +++ b/deep_research/materializers/synthesis_data_materializer.py @@ -0,0 +1,254 @@ +"""Materializer for SynthesisData with confidence metrics and synthesis quality visualization.""" + +import os +from typing import Dict + +from utils.css_utils import ( + create_stat_card, + get_card_class, + get_confidence_class, + get_grid_class, + get_shared_css_tag, +) +from utils.pydantic_models import SynthesisData +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class SynthesisDataMaterializer(PydanticMaterializer): + """Materializer for SynthesisData with quality metrics visualization.""" + + ASSOCIATED_TYPES = (SynthesisData,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: SynthesisData + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the SynthesisData. + + Args: + data: The SynthesisData to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + visualization_path = os.path.join(self.uri, "synthesis_data.html") + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, data: SynthesisData) -> str: + """Generate HTML visualization for the synthesis data. + + Args: + data: The SynthesisData to visualize + + Returns: + HTML string + """ + # Count confidence levels + confidence_counts = {"high": 0, "medium": 0, "low": 0} + for info in data.synthesized_info.values(): + confidence_counts[info.confidence_level] += 1 + + # Create synthesis cards HTML + synthesis_html = "" + for sub_q, info in data.synthesized_info.items(): + sources_html = "" + if info.key_sources: + sources_html = "
Key Sources:
" + + gaps_html = "" + if info.information_gaps: + gaps_html = f""" +
+ Information Gaps: +

{info.information_gaps}

+
+ """ + + improvements_html = "" + if info.improvements: + improvements_html = "
Suggested Improvements:
" + + # Check if this has enhanced version + enhanced_badge = "" + enhanced_section = "" + if sub_q in data.enhanced_info: + enhanced_badge = ( + 'Enhanced' + ) + enhanced_info = data.enhanced_info[sub_q] + enhanced_section = f""" +
+

Enhanced Answer

+

{enhanced_info.synthesized_answer}

+
+ Confidence: {enhanced_info.confidence_level.upper()} +
+
+ """ + + synthesis_html += f""" +
+
+

{sub_q}

+ {enhanced_badge} +
+ +
+

Original Synthesis

+

{info.synthesized_answer}

+ +
+ Confidence: {info.confidence_level.upper()} +
+ + {sources_html} + {gaps_html} + {improvements_html} +
+ + {enhanced_section} +
+ """ + + if not synthesis_html: + synthesis_html = ( + '
No synthesis data available yet
' + ) + + # Calculate statistics + total_syntheses = len(data.synthesized_info) + total_enhanced = len(data.enhanced_info) + avg_sources = sum( + len(info.key_sources) for info in data.synthesized_info.values() + ) / max(total_syntheses, 1) + + # Create stats HTML + stats_html = f""" +
+ {create_stat_card(total_syntheses, "Total Syntheses")} + {create_stat_card(total_enhanced, "Enhanced Syntheses")} + {create_stat_card(f"{avg_sources:.1f}", "Avg Sources per Synthesis")} + {create_stat_card(confidence_counts["high"], "High Confidence")} +
+ """ + + html = f""" + + + + Synthesis Data Visualization + + {get_shared_css_tag()} + + + +
+
+

Synthesis Quality Analysis

+
+ + {stats_html} + +
+

Confidence Distribution

+
+ +
+
+ +
+

Synthesized Information

+ {synthesis_html} +
+
+ + + + + """ + + return html diff --git a/deep_research/materializers/tracing_metadata_materializer.py b/deep_research/materializers/tracing_metadata_materializer.py new file mode 100644 index 00000000..bb40e7c9 --- /dev/null +++ b/deep_research/materializers/tracing_metadata_materializer.py @@ -0,0 +1,505 @@ +"""Materializer for TracingMetadata with custom visualization.""" + +import os +from typing import Dict + +from utils.css_utils import ( + create_stat_card, + get_card_class, + get_grid_class, + get_shared_css_tag, + get_table_class, +) +from utils.pydantic_models import TracingMetadata +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class TracingMetadataMaterializer(PydanticMaterializer): + """Materializer for the TracingMetadata class with visualizations.""" + + ASSOCIATED_TYPES = (TracingMetadata,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: TracingMetadata + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the TracingMetadata. + + Args: + data: The TracingMetadata to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "tracing_metadata.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, metadata: TracingMetadata) -> str: + """Generate HTML visualization for the tracing metadata. + + Args: + metadata: The TracingMetadata to visualize + + Returns: + HTML string + """ + # Calculate some derived values + avg_cost_per_token = metadata.total_cost / max( + metadata.total_tokens, 1 + ) + + # Build stats cards HTML + stats_html = f""" +
+
+
Pipeline Run
+
{metadata.pipeline_run_name}
+
+ {create_stat_card(f"${metadata.total_cost:.4f}", "LLM Cost")} + {create_stat_card(f"{metadata.total_tokens:,}", "Total Tokens")} + {create_stat_card(metadata.formatted_latency, "Duration")} +
+ """ + + token_stats_html = f""" +
+ {create_stat_card(f"{metadata.total_input_tokens:,}", "Input Tokens")} + {create_stat_card(f"{metadata.total_output_tokens:,}", "Output Tokens")} + {create_stat_card(metadata.observation_count, "Observations")} + {create_stat_card(f"${avg_cost_per_token:.6f}", "Avg Cost per Token")} +
+ """ + + # Base structure for the HTML + html = f""" + + + + Pipeline Tracing Metadata + {get_shared_css_tag()} + + + +
+

Pipeline Tracing Metadata

+ + {stats_html} + +

Token Usage

+ {token_stats_html} + +

Model Usage Breakdown

+ + + + + + + + + + + + """ + + # Add model breakdown + for model in metadata.models_used: + tokens = metadata.model_token_breakdown.get(model, {}) + cost = metadata.cost_breakdown_by_model.get(model, 0.0) + html += f""" + + + + + + + + """ + + html += """ + +
ModelInput TokensOutput TokensTotal TokensCost
{model}{tokens.get("input_tokens", 0):,}{tokens.get("output_tokens", 0):,}{tokens.get("total_tokens", 0):,}${cost:.4f}
+ """ + + # Add prompt-level metrics visualization + if metadata.prompt_metrics: + html += f""" +

Cost Analysis by Prompt Type

+ + + + + +
+

Cost Distribution

+
+ +
+
+ + +
+

Token Usage

+
+ +
+
+ + +

Prompt Type Efficiency

+ + + + + + + + + + + + + + """ + + # Add prompt metrics rows + for metric in metadata.prompt_metrics: + # Format prompt type name nicely + prompt_type_display = metric.prompt_type.replace( + "_", " " + ).title() + html += f""" + + + + + + + + + + """ + + html += """ + +
Prompt TypeTotal CostCallsAvg $/Call% of TotalInput TokensOutput Tokens
{prompt_type_display}${metric.total_cost:.4f}{metric.call_count}${metric.avg_cost_per_call:.4f}{metric.percentage_of_total_cost:.1f}%{metric.input_tokens:,}{metric.output_tokens:,}
+ + + """ + + # Add search cost visualization if available + if metadata.search_costs and any(metadata.search_costs.values()): + total_search_cost = sum(metadata.search_costs.values()) + total_combined_cost = metadata.total_cost + total_search_cost + + # Build search provider cards + search_cards = "" + for provider, cost in metadata.search_costs.items(): + if cost > 0: + query_count = metadata.search_queries_count.get( + provider, 0 + ) + avg_cost_per_query = ( + cost / query_count if query_count > 0 else 0 + ) + search_cards += f""" +
+
{provider.upper()} Search
+
${cost:.4f}
+
+ {query_count} queries • ${avg_cost_per_query:.4f}/query +
+
+ """ + + search_cards += f""" +
+
Total Search Cost
+
${total_search_cost:.4f}
+
+ {sum(metadata.search_queries_count.values())} total queries +
+
+ """ + + html += f""" +

Search Provider Costs

+
+ {search_cards} +
+ +

Combined Cost Summary

+
+
+
LLM Cost
+
${metadata.total_cost:.4f}
+
+ {(metadata.total_cost / total_combined_cost * 100):.1f}% of total +
+
+
+
Search Cost
+
${total_search_cost:.4f}
+
+ {(total_search_cost / total_combined_cost * 100):.1f}% of total +
+
+
+
Total Pipeline Cost
+
${total_combined_cost:.4f}
+
+
+ +
+

Cost Breakdown Chart

+
+ +
+
+ + """ + + # Add trace metadata + if metadata.trace_tags or metadata.trace_metadata: + html += f""" +

Trace Information

+
+ """ + + if metadata.trace_tags: + html += """ +

Tags

+
+ """ + for tag in metadata.trace_tags: + html += ( + f'{tag}' + ) + html += """ +
+ """ + + if metadata.trace_metadata: + html += """ +

Metadata

+
+                """
+                import json
+
+                html += json.dumps(metadata.trace_metadata, indent=2)
+                html += """
+                    
+ """ + + html += """ +
+ """ + + # Add footer with collection info + from datetime import datetime + + collection_time = datetime.fromtimestamp( + metadata.collected_at + ).strftime("%Y-%m-%d %H:%M:%S") + + html += f""" +
+

Trace ID: {metadata.trace_id}

+

Pipeline Run ID: {metadata.pipeline_run_id}

+

Collected at: {collection_time}

+
+
+ + + """ + + return html diff --git a/deep_research/pipelines/__init__.py b/deep_research/pipelines/__init__.py new file mode 100644 index 00000000..7f4ea5eb --- /dev/null +++ b/deep_research/pipelines/__init__.py @@ -0,0 +1,11 @@ +""" +Pipelines package for the ZenML Deep Research project. + +This package contains the ZenML pipeline definitions for running deep research +workflows. Each pipeline orchestrates a sequence of steps for comprehensive +research on a given query topic. +""" + +from .parallel_research_pipeline import parallelized_deep_research_pipeline + +__all__ = ["parallelized_deep_research_pipeline"] diff --git a/deep_research/pipelines/parallel_research_pipeline.py b/deep_research/pipelines/parallel_research_pipeline.py new file mode 100644 index 00000000..fabdc204 --- /dev/null +++ b/deep_research/pipelines/parallel_research_pipeline.py @@ -0,0 +1,164 @@ +from steps.approval_step import get_research_approval_step +from steps.collect_tracing_metadata_step import collect_tracing_metadata_step +from steps.cross_viewpoint_step import cross_viewpoint_analysis_step +from steps.execute_approved_searches_step import execute_approved_searches_step +from steps.generate_reflection_step import generate_reflection_step +from steps.initialize_prompts_step import initialize_prompts_step +from steps.merge_results_step import merge_sub_question_results_step +from steps.process_sub_question_step import process_sub_question_step +from steps.pydantic_final_report_step import pydantic_final_report_step +from steps.query_decomposition_step import initial_query_decomposition_step +from zenml import pipeline + + +@pipeline(enable_cache=False) +def parallelized_deep_research_pipeline( + query: str = "What is ZenML?", + max_sub_questions: int = 10, + require_approval: bool = False, + approval_timeout: int = 3600, + max_additional_searches: int = 2, + search_provider: str = "tavily", + search_mode: str = "auto", + num_results_per_search: int = 3, + langfuse_project_name: str = "deep-research", +) -> None: + """Parallelized ZenML pipeline for deep research on a given query. + + This pipeline uses the fan-out/fan-in pattern for parallel processing of sub-questions, + potentially improving execution time when using distributed orchestrators. 
+ + Args: + query: The research query/topic + max_sub_questions: Maximum number of sub-questions to process in parallel + require_approval: Whether to require human approval for additional searches + approval_timeout: Timeout in seconds for human approval + max_additional_searches: Maximum number of additional searches to perform + search_provider: Search provider to use (tavily, exa, or both) + search_mode: Search mode for Exa provider (neural, keyword, or auto) + num_results_per_search: Number of search results to return per query + langfuse_project_name: Langfuse project name for LLM tracking + + Returns: + Formatted research report as HTML + """ + # Initialize individual prompts for tracking + ( + search_query_prompt, + query_decomposition_prompt, + synthesis_prompt, + viewpoint_analysis_prompt, + reflection_prompt, + additional_synthesis_prompt, + conclusion_generation_prompt, + executive_summary_prompt, + introduction_prompt, + ) = initialize_prompts_step(pipeline_version="1.0.0") + + # Step 1: Decompose the query into sub-questions, limiting to max_sub_questions + query_context = initial_query_decomposition_step( + main_query=query, + query_decomposition_prompt=query_decomposition_prompt, + max_sub_questions=max_sub_questions, + langfuse_project_name=langfuse_project_name, + ) + + # Fan out: Process each sub-question in parallel + # Collect step names to establish dependencies for the merge step + parallel_step_names = [] + for i in range(max_sub_questions): + # Process the i-th sub-question (if it exists) + step_name = f"process_question_{i + 1}" + search_data, synthesis_data = process_sub_question_step( + query_context=query_context, + search_query_prompt=search_query_prompt, + synthesis_prompt=synthesis_prompt, + question_index=i, + search_provider=search_provider, + search_mode=search_mode, + num_results_per_search=num_results_per_search, + langfuse_project_name=langfuse_project_name, + id=step_name, + after="initial_query_decomposition_step", + ) + parallel_step_names.append(step_name) + + # Fan in: Merge results from all parallel processing + # The 'after' parameter ensures this step runs after all processing steps + merged_search_data, merged_synthesis_data = ( + merge_sub_question_results_step( + step_prefix="process_question_", + after=parallel_step_names, # Wait for all parallel steps to complete + ) + ) + + # Continue with subsequent steps + analysis_data = cross_viewpoint_analysis_step( + query_context=query_context, + synthesis_data=merged_synthesis_data, + viewpoint_analysis_prompt=viewpoint_analysis_prompt, + langfuse_project_name=langfuse_project_name, + after="merge_sub_question_results_step", + ) + + # New 3-step reflection flow with optional human approval + # Step 1: Generate reflection and recommendations (no searches yet) + analysis_with_reflection, recommended_queries = generate_reflection_step( + query_context=query_context, + synthesis_data=merged_synthesis_data, + analysis_data=analysis_data, + reflection_prompt=reflection_prompt, + langfuse_project_name=langfuse_project_name, + after="cross_viewpoint_analysis_step", + ) + + # Step 2: Get approval for recommended searches + approval_decision = get_research_approval_step( + query_context=query_context, + synthesis_data=merged_synthesis_data, + analysis_data=analysis_with_reflection, + recommended_queries=recommended_queries, + require_approval=require_approval, + timeout=approval_timeout, + max_queries=max_additional_searches, + after="generate_reflection_step", + ) + + # Step 3: Execute approved searches 
(if any) + enhanced_search_data, enhanced_synthesis_data, enhanced_analysis_data = ( + execute_approved_searches_step( + query_context=query_context, + search_data=merged_search_data, + synthesis_data=merged_synthesis_data, + analysis_data=analysis_with_reflection, + recommended_queries=recommended_queries, + approval_decision=approval_decision, + additional_synthesis_prompt=additional_synthesis_prompt, + search_provider=search_provider, + search_mode=search_mode, + num_results_per_search=num_results_per_search, + langfuse_project_name=langfuse_project_name, + after="get_research_approval_step", + ) + ) + + # Use our new Pydantic-based final report step + pydantic_final_report_step( + query_context=query_context, + search_data=enhanced_search_data, + synthesis_data=enhanced_synthesis_data, + analysis_data=enhanced_analysis_data, + conclusion_generation_prompt=conclusion_generation_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + langfuse_project_name=langfuse_project_name, + after="execute_approved_searches_step", + ) + + # Collect tracing metadata for the entire pipeline run + collect_tracing_metadata_step( + query_context=query_context, + search_data=enhanced_search_data, + langfuse_project_name=langfuse_project_name, + after="pydantic_final_report_step", + ) diff --git a/deep_research/requirements.txt b/deep_research/requirements.txt new file mode 100644 index 00000000..6d2f9166 --- /dev/null +++ b/deep_research/requirements.txt @@ -0,0 +1,10 @@ +zenml>=0.82.0 +litellm>=1.70.0,<2.0.0 +tavily-python>=0.2.8 +exa-py>=1.0.0 +PyYAML>=6.0 +click>=8.0.0 +pydantic>=2.0.0 +typing_extensions>=4.0.0 +requests +langfuse>=2.0.0 diff --git a/deep_research/run.py b/deep_research/run.py new file mode 100644 index 00000000..8db45164 --- /dev/null +++ b/deep_research/run.py @@ -0,0 +1,334 @@ +import logging +import os + +import click +import yaml +from logging_config import configure_logging +from pipelines.parallel_research_pipeline import ( + parallelized_deep_research_pipeline, +) +from utils.config_utils import check_required_env_vars + +logger = logging.getLogger(__name__) + + +# Research mode presets for easy configuration +RESEARCH_MODES = { + "rapid": { + "max_sub_questions": 5, + "num_results_per_search": 2, + "max_additional_searches": 0, + "description": "Quick research with minimal depth - great for getting a fast overview", + }, + "balanced": { + "max_sub_questions": 10, + "num_results_per_search": 3, + "max_additional_searches": 2, + "description": "Balanced research with moderate depth - ideal for most use cases", + }, + "deep": { + "max_sub_questions": 15, + "num_results_per_search": 5, + "max_additional_searches": 4, + "description": "Comprehensive research with maximum depth - for thorough analysis", + "suggest_approval": True, # Suggest using approval for deep mode + }, +} + + +@click.command( + help=""" +Deep Research Agent - ZenML Pipeline for Comprehensive Research + +Run a deep research pipeline that: +1. Generates a structured report outline +2. Researches each topic with web searches and LLM analysis +3. Refines content through multiple reflection cycles +4. 
Produces a formatted, comprehensive research report + +Examples: + + \b + # Run with default configuration + python run.py + + \b + # Use a research mode preset for easy configuration + python run.py --mode rapid # Quick overview + python run.py --mode balanced # Standard research (default) + python run.py --mode deep # Comprehensive analysis + + \b + # Run with a custom pipeline configuration file + python run.py --config configs/custom_pipeline.yaml + + \b + # Override the research query + python run.py --query "My research topic" + + \b + # Combine mode with other options + python run.py --mode deep --query "Complex topic" --require-approval + + \b + # Run with a custom number of sub-questions + python run.py --max-sub-questions 15 +""" +) +@click.option( + "--mode", + type=click.Choice(["rapid", "balanced", "deep"], case_sensitive=False), + default=None, + help="Research mode preset: rapid (fast overview), balanced (standard), or deep (comprehensive)", +) +@click.option( + "--config", + type=str, + default="configs/enhanced_research.yaml", + help="Path to the pipeline configuration YAML file", +) +@click.option( + "--no-cache", + is_flag=True, + default=False, + help="Disable caching for the pipeline run", +) +@click.option( + "--log-file", + type=str, + default=None, + help="Path to log file (if not provided, logs only go to console)", +) +@click.option( + "--debug", + is_flag=True, + default=False, + help="Enable debug logging", +) +@click.option( + "--query", + type=str, + default=None, + help="Research query (overrides the query in the config file)", +) +@click.option( + "--max-sub-questions", + type=int, + default=10, + help="Maximum number of sub-questions to process in parallel", +) +@click.option( + "--require-approval", + is_flag=True, + default=False, + help="Enable human-in-the-loop approval for additional searches", +) +@click.option( + "--approval-timeout", + type=int, + default=3600, + help="Timeout in seconds for human approval (default: 3600)", +) +@click.option( + "--search-provider", + type=click.Choice(["tavily", "exa", "both"], case_sensitive=False), + default="tavily", + help="Search provider to use: tavily (default), exa, or both", +) +@click.option( + "--search-mode", + type=click.Choice(["neural", "keyword", "auto"], case_sensitive=False), + default="auto", + help="Search mode for Exa provider: neural, keyword, or auto (default: auto)", +) +@click.option( + "--num-results", + type=int, + default=3, + help="Number of search results to return per query (default: 3)", +) +def main( + mode, + config, + no_cache, + log_file, + debug, + query, + max_sub_questions, + require_approval, + approval_timeout, + search_provider, + search_mode, + num_results, +): + """Run the deep research pipeline. 
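As an aside, the same run can be kicked off without the Click CLI; a minimal sketch (the config path and query values are illustrative and mirror the defaults handled below):

```python
from pipelines.parallel_research_pipeline import (
    parallelized_deep_research_pipeline,
)

# Roughly equivalent to `python run.py` with default options:
research_pipeline = parallelized_deep_research_pipeline.with_options(
    config_path="configs/enhanced_research.yaml",  # default config assumed by the CLI
)
research_pipeline(
    query="What is ZenML?",
    max_sub_questions=10,
    search_provider="tavily",
    num_results_per_search=3,
)
```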
+ + Args: + mode: Research mode preset (rapid, balanced, or deep) + config: Path to the pipeline configuration YAML file + no_cache: Disable caching for the pipeline run + log_file: Path to log file + debug: Enable debug logging + query: Research query (overrides the query in the config file) + max_sub_questions: Maximum number of sub-questions to process in parallel + require_approval: Enable human-in-the-loop approval for additional searches + approval_timeout: Timeout in seconds for human approval + search_provider: Search provider to use (tavily, exa, or both) + search_mode: Search mode for Exa provider (neural, keyword, or auto) + num_results: Number of search results to return per query + """ + # Configure logging + log_level = logging.DEBUG if debug else logging.INFO + configure_logging(level=log_level, log_file=log_file) + + # Apply mode presets if specified + if mode: + mode_config = RESEARCH_MODES[mode.lower()] + logger.info("\n" + "=" * 80) + logger.info(f"Using research mode: {mode.upper()}") + logger.info(f"Description: {mode_config['description']}") + + # Apply mode parameters (can be overridden by explicit arguments) + if ( + max_sub_questions + == click.get_current_context().params["max_sub_questions"] + ): + # Default value - apply mode preset + max_sub_questions = mode_config["max_sub_questions"] + logger.info(f" - Max sub-questions: {max_sub_questions}") + + # Store mode config for later use + mode_max_additional_searches = mode_config["max_additional_searches"] + + # Use mode's num_results_per_search only if user didn't override with --num-results + if num_results == 3: # Default value - apply mode preset + num_results = mode_config["num_results_per_search"] + + logger.info( + f" - Max additional searches: {mode_max_additional_searches}" + ) + logger.info(f" - Results per search: {num_results}") + + # Check if a mode-specific config exists and user didn't override config + if config == "configs/enhanced_research.yaml": # Default config + mode_specific_config = f"configs/{mode.lower()}_research.yaml" + if os.path.exists(mode_specific_config): + config = mode_specific_config + logger.info(f" - Using mode-specific config: {config}") + + # Suggest approval for deep mode if not already enabled + if mode_config.get("suggest_approval") and not require_approval: + logger.info(f"\n{'!' * 60}") + logger.info( + f"! TIP: Consider using --require-approval with {mode} mode" + ) + logger.info(f"! for better control over comprehensive research") + logger.info(f"{'!' 
* 60}") + + logger.info(f"{'=' * 80}\n") + else: + # Default values if no mode specified + mode_max_additional_searches = 2 + + # Check that required environment variables are present using the helper function + required_vars = ["SAMBANOVA_API_KEY"] + + # Add provider-specific API key requirements + if search_provider in {"exa", "both"}: + required_vars.append("EXA_API_KEY") + if search_provider in {"tavily", "both", None}: # Default is tavily + required_vars.append("TAVILY_API_KEY") + + if missing_vars := check_required_env_vars(required_vars): + logger.error( + f"The following required environment variables are not set: {', '.join(missing_vars)}" + ) + logger.info("Please set them with:") + for var in missing_vars: + logger.info(f" export {var}=your_{var.lower()}_here") + return + + # Set pipeline options + pipeline_options = {"config_path": config} + + if no_cache: + pipeline_options["enable_cache"] = False + + logger.info("\n" + "=" * 80) + logger.info("Starting Deep Research") + logger.info("Using parallel pipeline for efficient execution") + + # Log search provider settings + if search_provider: + logger.info(f"Search provider: {search_provider.upper()}") + if search_provider == "exa": + logger.info(f" - Search mode: {search_mode}") + elif search_provider == "both": + logger.info(f" - Running both Tavily and Exa searches") + logger.info(f" - Exa search mode: {search_mode}") + else: + logger.info("Search provider: TAVILY (default)") + + # Log num_results if custom value or no mode preset + if num_results != 3 or not mode: + logger.info(f"Results per search: {num_results}") + + langfuse_project_name = "deep-research" # default + try: + with open(config, "r") as f: + config_data = yaml.safe_load(f) + langfuse_project_name = config_data.get( + "langfuse_project_name", "deep-research" + ) + except Exception as e: + logger.warning( + f"Could not load langfuse_project_name from config: {e}" + ) + + # Set up the pipeline with the parallelized version as default + pipeline = parallelized_deep_research_pipeline.with_options( + **pipeline_options + ) + + # Execute the pipeline + if query: + logger.info( + f"Using query: {query} with max {max_sub_questions} parallel sub-questions" + ) + if require_approval: + logger.info( + f"Human approval enabled with {approval_timeout}s timeout" + ) + pipeline( + query=query, + max_sub_questions=max_sub_questions, + require_approval=require_approval, + approval_timeout=approval_timeout, + max_additional_searches=mode_max_additional_searches, + search_provider=search_provider or "tavily", + search_mode=search_mode, + num_results_per_search=num_results, + langfuse_project_name=langfuse_project_name, + ) + else: + logger.info( + f"Using query from config file with max {max_sub_questions} parallel sub-questions" + ) + if require_approval: + logger.info( + f"Human approval enabled with {approval_timeout}s timeout" + ) + pipeline( + max_sub_questions=max_sub_questions, + require_approval=require_approval, + approval_timeout=approval_timeout, + max_additional_searches=mode_max_additional_searches, + search_provider=search_provider or "tavily", + search_mode=search_mode, + num_results_per_search=num_results, + langfuse_project_name=langfuse_project_name, + ) + + +if __name__ == "__main__": + main() diff --git a/deep_research/steps/__init__.py b/deep_research/steps/__init__.py new file mode 100644 index 00000000..1d454e49 --- /dev/null +++ b/deep_research/steps/__init__.py @@ -0,0 +1,7 @@ +""" +Steps package for the ZenML Deep Research project. 
+ +This package contains individual ZenML steps used in the research pipelines. +Each step is responsible for a specific part of the research process, such as +query decomposition, searching, synthesis, and report generation. +""" diff --git a/deep_research/steps/approval_step.py b/deep_research/steps/approval_step.py new file mode 100644 index 00000000..93566049 --- /dev/null +++ b/deep_research/steps/approval_step.py @@ -0,0 +1,360 @@ +import logging +import time +from typing import Annotated, List + +from materializers.approval_decision_materializer import ( + ApprovalDecisionMaterializer, +) +from utils.approval_utils import ( + format_approval_request, +) +from utils.pydantic_models import ( + AnalysisData, + ApprovalDecision, + QueryContext, + SynthesisData, +) +from zenml import log_metadata, step +from zenml.client import Client + +logger = logging.getLogger(__name__) + + +def summarize_research_progress_from_artifacts( + synthesis_data: SynthesisData, analysis_data: AnalysisData +) -> dict: + """Summarize research progress from the new artifact structure.""" + completed_count = len(synthesis_data.synthesized_info) + + # Calculate confidence levels from synthesis data + confidence_levels = [] + for info in synthesis_data.synthesized_info.values(): + confidence_levels.append(info.confidence_level) + + # Calculate average confidence (high=1.0, medium=0.5, low=0.0) + confidence_map = {"high": 1.0, "medium": 0.5, "low": 0.0} + avg_confidence = sum( + confidence_map.get(c.lower(), 0.5) for c in confidence_levels + ) / max(len(confidence_levels), 1) + + low_confidence_count = sum( + 1 for c in confidence_levels if c.lower() == "low" + ) + + return { + "completed_count": completed_count, + "avg_confidence": round(avg_confidence, 2), + "low_confidence_count": low_confidence_count, + } + + +@step( + enable_cache=False, + output_materializers={"approval_decision": ApprovalDecisionMaterializer}, +) # Never cache approval decisions +def get_research_approval_step( + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, + recommended_queries: List[str], + require_approval: bool = True, + alerter_type: str = "slack", + timeout: int = 3600, + max_queries: int = 2, +) -> Annotated[ApprovalDecision, "approval_decision"]: + """ + Get human approval for additional research queries. + + Always returns an ApprovalDecision object. If require_approval is False, + automatically approves all queries. + + Args: + query_context: Context containing the main query and sub-questions + synthesis_data: Synthesized information from research + analysis_data: Analysis including viewpoints and critique + recommended_queries: List of recommended additional queries + require_approval: Whether to require human approval + alerter_type: Type of alerter to use (slack, email, etc.) 
+ timeout: Timeout in seconds for approval response + max_queries: Maximum number of queries to approve + + Returns: + ApprovalDecision object with approval status and selected queries + """ + start_time = time.time() + + # Limit queries to max_queries + limited_queries = recommended_queries[:max_queries] + + # If approval not required, auto-approve all + if not require_approval: + logger.info( + f"Auto-approving {len(limited_queries)} recommended queries (approval not required)" + ) + + # Log metadata for auto-approval + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": False, + "approval_method": "AUTO_APPROVED", + "num_queries_recommended": len(recommended_queries), + "num_queries_approved": len(limited_queries), + "max_queries_allowed": max_queries, + "approval_status": "approved", + "wait_time_seconds": 0, + } + } + ) + + return ApprovalDecision( + approved=True, + selected_queries=limited_queries, + approval_method="AUTO_APPROVED", + reviewer_notes="Approval not required by configuration", + ) + + # If no queries to approve, skip + if not limited_queries: + logger.info("No additional queries recommended") + + # Log metadata for no queries + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "NO_QUERIES", + "num_queries_recommended": 0, + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "skipped", + "wait_time_seconds": 0, + } + } + ) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="NO_QUERIES", + reviewer_notes="No additional queries recommended", + ) + + # Prepare approval request + progress_summary = summarize_research_progress_from_artifacts( + synthesis_data, analysis_data + ) + + # Extract critique points from analysis data + critique_points = [] + if analysis_data.critique_summary: + # Convert critique summary to list of dicts for compatibility + for i, critique in enumerate( + analysis_data.critique_summary.split("\n") + ): + if critique.strip(): + critique_points.append( + { + "issue": critique.strip(), + "importance": "high" if i < 3 else "medium", + } + ) + + message = format_approval_request( + main_query=query_context.main_query, + progress_summary=progress_summary, + critique_points=critique_points, + proposed_queries=limited_queries, + timeout=timeout, + ) + + # Log the approval request for visibility + logger.info("=" * 80) + logger.info("APPROVAL REQUEST:") + logger.info(message) + logger.info("=" * 80) + + try: + # Get the alerter from the active stack + client = Client() + alerter = client.active_stack.alerter + + if not alerter: + logger.warning("No alerter configured in stack, auto-approving") + + # Log metadata for no alerter scenario + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "NO_ALERTER_AUTO_APPROVED", + "alerter_type": "none", + "num_queries_recommended": len(recommended_queries), + "num_queries_approved": len(limited_queries), + "max_queries_allowed": max_queries, + "approval_status": "auto_approved", + "wait_time_seconds": 0, + } + } + ) + + return ApprovalDecision( + approved=True, + selected_queries=limited_queries, + approval_method="NO_ALERTER_AUTO_APPROVED", + 
reviewer_notes="No alerter configured - auto-approved", + ) + + # Use the alerter's ask method for interactive approval + try: + # Send the message to Discord and wait for response + logger.info( + f"Sending approval request to {alerter.flavor} alerter" + ) + + # Format message for Discord (Discord has message length limits) + discord_message = ( + f"**Research Approval Request**\n\n{message[:1900]}" + ) + if len(message) > 1900: + discord_message += ( + "\n\n*(Message truncated due to Discord limits)*" + ) + + # Add instructions for Discord responses + discord_message += "\n\n**How to respond:**\n" + discord_message += "✅ Type `yes`, `approve`, `ok`, or `LGTM` to approve ALL queries\n" + discord_message += "❌ Type `no`, `skip`, `reject`, or `decline` to skip additional research\n" + discord_message += f"⏱️ Response timeout: {timeout} seconds" + + # Use the ask method to get user response + logger.info("Waiting for approval response from Discord...") + wait_start_time = time.time() + approved = alerter.ask(discord_message) + wait_end_time = time.time() + wait_time = wait_end_time - wait_start_time + + logger.info( + f"Received Discord response: {'approved' if approved else 'rejected'}" + ) + + if approved: + # Log metadata for approved decision + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "DISCORD_APPROVED", + "alerter_type": alerter_type, + "num_queries_recommended": len( + recommended_queries + ), + "num_queries_approved": len(limited_queries), + "max_queries_allowed": max_queries, + "approval_status": "approved", + "wait_time_seconds": wait_time, + "timeout_configured": timeout, + } + } + ) + + return ApprovalDecision( + approved=True, + selected_queries=limited_queries, + approval_method="DISCORD_APPROVED", + reviewer_notes="Approved via Discord", + ) + else: + # Log metadata for rejected decision + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "DISCORD_REJECTED", + "alerter_type": alerter_type, + "num_queries_recommended": len( + recommended_queries + ), + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "rejected", + "wait_time_seconds": wait_time, + "timeout_configured": timeout, + } + } + ) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="DISCORD_REJECTED", + reviewer_notes="Rejected via Discord", + ) + + except Exception as e: + logger.error(f"Failed to get approval from alerter: {e}") + + # Log metadata for alerter error + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "ALERTER_ERROR", + "alerter_type": alerter_type, + "num_queries_recommended": len(recommended_queries), + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "error", + "error_message": str(e), + } + } + ) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="ALERTER_ERROR", + reviewer_notes=f"Failed to get approval: {str(e)}", + ) + + except Exception as e: + logger.error(f"Approval step failed: {e}") + + # Log metadata for general error + execution_time = time.time() - start_time + log_metadata( + metadata={ 
+ "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "ERROR", + "num_queries_recommended": len(recommended_queries), + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "error", + "error_message": str(e), + } + } + ) + + # Add tag to the approval decision artifact + # add_tags(tags=["hitl"], artifact_name="approval_decision", infer_artifact=True) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="ERROR", + reviewer_notes=f"Approval failed: {str(e)}", + ) diff --git a/deep_research/steps/collect_tracing_metadata_step.py b/deep_research/steps/collect_tracing_metadata_step.py new file mode 100644 index 00000000..bfaa70e8 --- /dev/null +++ b/deep_research/steps/collect_tracing_metadata_step.py @@ -0,0 +1,252 @@ +"""Step to collect tracing metadata from Langfuse for the pipeline run.""" + +import logging +from typing import Annotated, Dict + +from materializers.tracing_metadata_materializer import ( + TracingMetadataMaterializer, +) +from utils.pydantic_models import ( + PromptTypeMetrics, + QueryContext, + SearchData, + TracingMetadata, +) +from utils.tracing_metadata_utils import ( + get_observations_for_trace, + get_prompt_type_statistics, + get_trace_stats, + get_traces_by_name, +) +from zenml import get_step_context, step + +logger = logging.getLogger(__name__) + + +@step( + enable_cache=False, + output_materializers={ + "tracing_metadata": TracingMetadataMaterializer, + }, +) +def collect_tracing_metadata_step( + query_context: QueryContext, + search_data: SearchData, + langfuse_project_name: str, +) -> Annotated[TracingMetadata, "tracing_metadata"]: + """Collect tracing metadata from Langfuse for the current pipeline run. + + This step gathers comprehensive metrics about token usage, costs, and performance + for the entire pipeline run, providing insights into resource consumption. 
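For context, a hedged sketch of how the cost fields on the returned artifact (field names taken from the code below) might be summarized downstream; the helper itself is hypothetical:

```python
from utils.pydantic_models import TracingMetadata


def summarize_costs(metadata: TracingMetadata) -> str:
    """Combine LLM and search costs into a one-line summary."""
    search_cost = sum(metadata.search_costs.values())
    total_cost = metadata.total_cost + search_cost
    return (
        f"LLM ${metadata.total_cost:.4f} + search ${search_cost:.4f} "
        f"= ${total_cost:.4f} across {metadata.total_tokens:,} tokens"
    )
```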
+ + Args: + query_context: The query context (for reference) + search_data: The search data containing cost information + langfuse_project_name: Langfuse project name for accessing traces + + Returns: + TracingMetadata with comprehensive cost and performance metrics + """ + ctx = get_step_context() + pipeline_run_name = ctx.pipeline_run.name + pipeline_run_id = str(ctx.pipeline_run.id) + + logger.info( + f"Collecting tracing metadata for pipeline run: {pipeline_run_name} (ID: {pipeline_run_id})" + ) + + # Initialize the metadata object + metadata = TracingMetadata( + pipeline_run_name=pipeline_run_name, + pipeline_run_id=pipeline_run_id, + trace_name=pipeline_run_name, + trace_id=pipeline_run_id, + ) + + try: + # Fetch the trace for this pipeline run + # The trace_name is the pipeline run name + traces = get_traces_by_name(name=pipeline_run_name, limit=1) + + if not traces: + logger.warning( + f"No trace found for pipeline run: {pipeline_run_name}" + ) + # Still add search costs before returning + _add_search_costs_to_metadata(metadata, search_data) + return metadata + + trace = traces[0] + + # Get comprehensive trace stats + trace_stats = get_trace_stats(trace) + + # Update metadata with trace stats + metadata.trace_id = trace.id + metadata.total_cost = trace_stats["total_cost"] + metadata.total_input_tokens = trace_stats["input_tokens"] + metadata.total_output_tokens = trace_stats["output_tokens"] + metadata.total_tokens = ( + trace_stats["input_tokens"] + trace_stats["output_tokens"] + ) + metadata.total_latency_seconds = trace_stats["latency_seconds"] + metadata.formatted_latency = trace_stats["latency_formatted"] + metadata.observation_count = trace_stats["observation_count"] + metadata.models_used = trace_stats["models_used"] + metadata.trace_tags = trace_stats.get("tags", []) + metadata.trace_metadata = trace_stats.get("metadata", {}) + + # Get model-specific breakdown + observations = get_observations_for_trace(trace_id=trace.id) + model_costs = {} + model_tokens = {} + step_costs = {} + step_tokens = {} + + for obs in observations: + if obs.model: + # Track by model + if obs.model not in model_costs: + model_costs[obs.model] = 0.0 + model_tokens[obs.model] = { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + + if obs.calculated_total_cost: + model_costs[obs.model] += obs.calculated_total_cost + + if obs.usage: + input_tokens = obs.usage.input or 0 + output_tokens = obs.usage.output or 0 + model_tokens[obs.model]["input_tokens"] += input_tokens + model_tokens[obs.model]["output_tokens"] += output_tokens + model_tokens[obs.model]["total_tokens"] += ( + input_tokens + output_tokens + ) + + # Track by step (using observation name as step indicator) + if obs.name: + step_name = obs.name + + if step_name not in step_costs: + step_costs[step_name] = 0.0 + step_tokens[step_name] = { + "input_tokens": 0, + "output_tokens": 0, + } + + if obs.calculated_total_cost: + step_costs[step_name] += obs.calculated_total_cost + + if obs.usage: + input_tokens = obs.usage.input or 0 + output_tokens = obs.usage.output or 0 + step_tokens[step_name]["input_tokens"] += input_tokens + step_tokens[step_name]["output_tokens"] += output_tokens + + metadata.cost_breakdown_by_model = model_costs + metadata.model_token_breakdown = model_tokens + metadata.step_costs = step_costs + metadata.step_tokens = step_tokens + + # Collect prompt-level metrics + try: + prompt_stats = get_prompt_type_statistics(trace_id=trace.id) + + # Convert to PromptTypeMetrics objects + prompt_metrics_list = [] + for 
prompt_type, stats in prompt_stats.items(): + prompt_metrics = PromptTypeMetrics( + prompt_type=prompt_type, + total_cost=stats["cost"], + input_tokens=stats["input_tokens"], + output_tokens=stats["output_tokens"], + call_count=stats["count"], + avg_cost_per_call=stats["avg_cost_per_call"], + percentage_of_total_cost=stats["percentage_of_total_cost"], + ) + prompt_metrics_list.append(prompt_metrics) + + # Sort by total cost descending + prompt_metrics_list.sort(key=lambda x: x.total_cost, reverse=True) + metadata.prompt_metrics = prompt_metrics_list + + logger.info( + f"Collected prompt-level metrics for {len(prompt_metrics_list)} prompt types" + ) + except Exception as e: + logger.warning(f"Failed to collect prompt-level metrics: {str(e)}") + + # Add search costs from the SearchData artifact + _add_search_costs_to_metadata(metadata, search_data) + + total_search_cost = sum(metadata.search_costs.values()) + logger.info( + f"Successfully collected tracing metadata - " + f"LLM Cost: ${metadata.total_cost:.4f}, " + f"Search Cost: ${total_search_cost:.4f}, " + f"Total Cost: ${metadata.total_cost + total_search_cost:.4f}, " + f"Tokens: {metadata.total_tokens:,}, " + f"Models: {metadata.models_used}, " + f"Duration: {metadata.formatted_latency}" + ) + + except Exception as e: + logger.error( + f"Failed to collect tracing metadata for pipeline run {pipeline_run_name}: {str(e)}" + ) + # Return metadata with whatever we could collect + # Still try to get search costs even if Langfuse failed + _add_search_costs_to_metadata(metadata, search_data) + + # Add tags to the artifact + # add_tags( + # tags=["exa", "tavily", "llm", "cost", "tracing"], + # artifact_name="tracing_metadata", + # infer_artifact=True, + # ) + + return metadata + + +def _add_search_costs_to_metadata( + metadata: TracingMetadata, search_data: SearchData +) -> None: + """Add search costs from SearchData to TracingMetadata. 
+ + Args: + metadata: The TracingMetadata object to update + search_data: The SearchData containing cost information + """ + if search_data.search_costs: + metadata.search_costs = search_data.search_costs.copy() + logger.info(f"Added search costs: {metadata.search_costs}") + + if search_data.search_cost_details: + # Convert SearchCostDetail objects to dicts for backward compatibility + metadata.search_cost_details = [ + { + "provider": detail.provider, + "query": detail.query, + "cost": detail.cost, + "timestamp": detail.timestamp, + "step": detail.step, + "sub_question": detail.sub_question, + } + for detail in search_data.search_cost_details + ] + + # Count queries by provider + search_queries_count: Dict[str, int] = {} + for detail in search_data.search_cost_details: + provider = detail.provider + search_queries_count[provider] = ( + search_queries_count.get(provider, 0) + 1 + ) + metadata.search_queries_count = search_queries_count + + logger.info( + f"Added {len(metadata.search_cost_details)} search cost detail entries" + ) diff --git a/deep_research/steps/cross_viewpoint_step.py b/deep_research/steps/cross_viewpoint_step.py new file mode 100644 index 00000000..cc844bdd --- /dev/null +++ b/deep_research/steps/cross_viewpoint_step.py @@ -0,0 +1,243 @@ +import json +import logging +import time +from typing import Annotated, List + +from materializers.analysis_data_materializer import AnalysisDataMaterializer +from utils.llm_utils import run_llm_completion, safe_json_loads +from utils.pydantic_models import ( + AnalysisData, + Prompt, + QueryContext, + SynthesisData, + ViewpointAnalysis, + ViewpointTension, +) +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers={"analysis_data": AnalysisDataMaterializer}) +def cross_viewpoint_analysis_step( + query_context: QueryContext, + synthesis_data: SynthesisData, + viewpoint_analysis_prompt: Prompt, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + viewpoint_categories: List[str] = [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + ], + langfuse_project_name: str = "deep-research", +) -> Annotated[AnalysisData, "analysis_data"]: + """Analyze synthesized information across different viewpoints. 
+ + Args: + query_context: The query context with main query and sub-questions + synthesis_data: The synthesized information to analyze + viewpoint_analysis_prompt: Prompt for viewpoint analysis + llm_model: The model to use for viewpoint analysis + viewpoint_categories: Categories of viewpoints to analyze + langfuse_project_name: Project name for tracing + + Returns: + AnalysisData containing viewpoint analysis + """ + start_time = time.time() + logger.info( + f"Performing cross-viewpoint analysis on {len(synthesis_data.synthesized_info)} sub-questions" + ) + + # Initialize analysis data + analysis_data = AnalysisData() + + # Prepare input for viewpoint analysis + analysis_input = { + "main_query": query_context.main_query, + "sub_questions": query_context.sub_questions, + "synthesized_information": { + question: { + "synthesized_answer": info.synthesized_answer, + "key_sources": info.key_sources, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + } + for question, info in synthesis_data.synthesized_info.items() + }, + "viewpoint_categories": viewpoint_categories, + } + + # Perform viewpoint analysis + try: + logger.info(f"Calling {llm_model} for viewpoint analysis") + # Use the run_llm_completion function from llm_utils + content = run_llm_completion( + prompt=json.dumps(analysis_input), + system_prompt=str(viewpoint_analysis_prompt), + model=llm_model, # Model name will be prefixed in the function + max_tokens=3000, # Further increased for more comprehensive viewpoint analysis + project=langfuse_project_name, + ) + + result = safe_json_loads(content) + + if not result: + logger.warning("Failed to parse viewpoint analysis result") + # Create a default viewpoint analysis + viewpoint_analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Analysis failed to identify points of agreement." 
+ ], + perspective_gaps="Analysis failed to identify perspective gaps.", + integrative_insights="Analysis failed to provide integrative insights.", + ) + else: + # Create tension objects + tensions = [] + for tension_data in result.get("areas_of_tension", []): + tensions.append( + ViewpointTension( + topic=tension_data.get("topic", ""), + viewpoints=tension_data.get("viewpoints", {}), + ) + ) + + # Create the viewpoint analysis object + viewpoint_analysis = ViewpointAnalysis( + main_points_of_agreement=result.get( + "main_points_of_agreement", [] + ), + areas_of_tension=tensions, + perspective_gaps=result.get("perspective_gaps", ""), + integrative_insights=result.get("integrative_insights", ""), + ) + + logger.info("Completed viewpoint analysis") + + # Update the analysis data with the viewpoint analysis + analysis_data.viewpoint_analysis = viewpoint_analysis + + # Calculate execution time + execution_time = time.time() - start_time + + # Count viewpoint tensions by category + tension_categories = {} + for tension in viewpoint_analysis.areas_of_tension: + for category in tension.viewpoints.keys(): + tension_categories[category] = ( + tension_categories.get(category, 0) + 1 + ) + + # Log metadata + log_metadata( + metadata={ + "viewpoint_analysis": { + "execution_time_seconds": execution_time, + "llm_model": llm_model, + "num_sub_questions_analyzed": len( + synthesis_data.synthesized_info + ), + "viewpoint_categories_requested": viewpoint_categories, + "num_agreement_points": len( + viewpoint_analysis.main_points_of_agreement + ), + "num_tension_areas": len( + viewpoint_analysis.areas_of_tension + ), + "tension_categories_distribution": tension_categories, + "has_perspective_gaps": bool( + viewpoint_analysis.perspective_gaps + and viewpoint_analysis.perspective_gaps != "" + ), + "has_integrative_insights": bool( + viewpoint_analysis.integrative_insights + and viewpoint_analysis.integrative_insights != "" + ), + "analysis_success": not viewpoint_analysis.main_points_of_agreement[ + 0 + ].startswith("Analysis failed"), + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_scope": { + "num_tension_areas": len( + viewpoint_analysis.areas_of_tension + ), + } + }, + infer_model=True, + ) + + # Log artifact metadata + log_metadata( + metadata={ + "analysis_data_characteristics": { + "has_viewpoint_analysis": True, + "total_viewpoints_analyzed": sum( + tension_categories.values() + ), + "most_common_tension_category": max( + tension_categories, key=tension_categories.get + ) + if tension_categories + else None, + } + }, + artifact_name="analysis_data", + infer_artifact=True, + ) + + # Add tags to the artifact + # add_tags(tags=["analysis", "viewpoint"], artifact_name="analysis_data", infer_artifact=True) + + return analysis_data + + except Exception as e: + logger.error(f"Error performing viewpoint analysis: {e}") + + # Create a fallback viewpoint analysis + fallback_analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Analysis failed due to technical error." 
+ ], + perspective_gaps=f"Analysis failed: {str(e)}", + integrative_insights="No insights available due to analysis failure.", + ) + + # Update the analysis data with the fallback analysis + analysis_data.viewpoint_analysis = fallback_analysis + + # Log error metadata + execution_time = time.time() - start_time + log_metadata( + metadata={ + "viewpoint_analysis": { + "execution_time_seconds": execution_time, + "llm_model": llm_model, + "num_sub_questions_analyzed": len( + synthesis_data.synthesized_info + ), + "viewpoint_categories_requested": viewpoint_categories, + "analysis_success": False, + "error_message": str(e), + "fallback_used": True, + } + } + ) + + # Add tags to the artifact + # add_tags( + # tags=["analysis", "viewpoint", "fallback"], + # artifact_name="analysis_data", + # infer_artifact=True, + # ) + + return analysis_data diff --git a/deep_research/steps/execute_approved_searches_step.py b/deep_research/steps/execute_approved_searches_step.py new file mode 100644 index 00000000..f1eb8625 --- /dev/null +++ b/deep_research/steps/execute_approved_searches_step.py @@ -0,0 +1,450 @@ +import json +import logging +import time +from typing import Annotated, List, Tuple + +from materializers.analysis_data_materializer import AnalysisDataMaterializer +from materializers.search_data_materializer import SearchDataMaterializer +from materializers.synthesis_data_materializer import SynthesisDataMaterializer +from utils.llm_utils import ( + find_most_relevant_string, + get_structured_llm_output, + is_text_relevant, +) +from utils.pydantic_models import ( + AnalysisData, + ApprovalDecision, + Prompt, + QueryContext, + SearchCostDetail, + SearchData, + SynthesisData, + SynthesizedInfo, +) +from utils.search_utils import search_and_extract_results +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step( + output_materializers={ + "enhanced_search_data": SearchDataMaterializer, + "enhanced_synthesis_data": SynthesisDataMaterializer, + "updated_analysis_data": AnalysisDataMaterializer, + } +) +def execute_approved_searches_step( + query_context: QueryContext, + search_data: SearchData, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, + recommended_queries: List[str], + approval_decision: ApprovalDecision, + additional_synthesis_prompt: Prompt, + num_results_per_search: int = 3, + cap_search_length: int = 20000, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + search_provider: str = "tavily", + search_mode: str = "auto", + langfuse_project_name: str = "deep-research", +) -> Tuple[ + Annotated[SearchData, "enhanced_search_data"], + Annotated[SynthesisData, "enhanced_synthesis_data"], + Annotated[AnalysisData, "updated_analysis_data"], +]: + """Execute approved searches and enhance the research artifacts. + + This step receives the approval decision and only executes + searches that were approved by the human reviewer (or auto-approved). 
+ + Args: + query_context: The query context with main query and sub-questions + search_data: The existing search data + synthesis_data: The existing synthesis data + analysis_data: The analysis data with viewpoint and reflection metadata + recommended_queries: The recommended queries from reflection + approval_decision: Human approval decision + additional_synthesis_prompt: Prompt for synthesis enhancement + num_results_per_search: Number of results to fetch per search + cap_search_length: Maximum length of content to process from search results + llm_model: The model to use for synthesis enhancement + search_provider: Search provider to use + search_mode: Search mode for the provider + langfuse_project_name: Project name for tracing + + Returns: + Tuple of enhanced SearchData, SynthesisData, and updated AnalysisData + """ + start_time = time.time() + logger.info( + f"Processing approval decision: {approval_decision.approval_method}" + ) + + # Create copies of the data to enhance + enhanced_search_data = SearchData( + search_results=search_data.search_results.copy(), + search_costs=search_data.search_costs.copy(), + search_cost_details=search_data.search_cost_details.copy(), + total_searches=search_data.total_searches, + ) + + enhanced_synthesis_data = SynthesisData( + synthesized_info=synthesis_data.synthesized_info.copy(), + enhanced_info={}, # Will be populated with enhanced versions + ) + + # Track improvements count + improvements_count = 0 + + # Check if we should execute searches + if ( + not approval_decision.approved + or not approval_decision.selected_queries + ): + logger.info("No additional searches approved") + + # Update reflection metadata with no searches + if analysis_data.reflection_metadata: + analysis_data.reflection_metadata.searches_performed = [] + analysis_data.reflection_metadata.improvements_made = 0.0 + + # Log metadata for no approved searches + execution_time = time.time() - start_time + log_metadata( + metadata={ + "execute_approved_searches": { + "execution_time_seconds": execution_time, + "approval_method": approval_decision.approval_method, + "approval_status": "not_approved" + if not approval_decision.approved + else "no_queries", + "num_queries_approved": 0, + "num_searches_executed": 0, + "num_recommended": len(recommended_queries), + "improvements_made": improvements_count, + "search_provider": search_provider, + "llm_model": llm_model, + } + } + ) + + # Add tags to the artifacts + # add_tags( + # tags=["search", "not-enhanced"], artifact_name="enhanced_search_data", infer_artifact=True + # ) + # add_tags( + # tags=["synthesis", "not-enhanced"], + # artifact_name="enhanced_synthesis_data", + # infer_artifact=True, + # ) + # add_tags( + # tags=["analysis", "no-searches"], artifact_name="updated_analysis_data", infer_artifact=True + # ) + + return enhanced_search_data, enhanced_synthesis_data, analysis_data + + # Execute approved searches + logger.info( + f"Executing {len(approval_decision.selected_queries)} approved searches" + ) + + try: + search_enhancements = [] # Track search results for metadata + + for query in approval_decision.selected_queries: + logger.info(f"Performing approved search: {query}") + + # Execute search using the utility function + search_results, search_cost = search_and_extract_results( + query=query, + max_results=num_results_per_search, + cap_content_length=cap_search_length, + provider=search_provider, + search_mode=search_mode, + ) + + # Track search costs if using Exa + if ( + search_provider + and search_provider.lower() 
in ["exa", "both"] + and search_cost > 0 + ): + # Update total costs + enhanced_search_data.search_costs["exa"] = ( + enhanced_search_data.search_costs.get("exa", 0.0) + + search_cost + ) + + # Add detailed cost entry + enhanced_search_data.search_cost_details.append( + SearchCostDetail( + provider="exa", + query=query, + cost=search_cost, + timestamp=time.time(), + step="execute_approved_searches", + sub_question=None, # These are reflection queries + ) + ) + logger.info( + f"Exa search cost for approved query: ${search_cost:.4f}" + ) + + # Update total searches + enhanced_search_data.total_searches += 1 + + # Extract raw contents + raw_contents = [result.content for result in search_results] + + # Find the most relevant sub-question for this query + most_relevant_question = find_most_relevant_string( + query, + query_context.sub_questions, + llm_model, + project=langfuse_project_name, + ) + + if ( + most_relevant_question + and most_relevant_question in synthesis_data.synthesized_info + ): + # Store the search results under the relevant question + if ( + most_relevant_question + in enhanced_search_data.search_results + ): + enhanced_search_data.search_results[ + most_relevant_question + ].extend(search_results) + else: + enhanced_search_data.search_results[ + most_relevant_question + ] = search_results + + # Enhance the synthesis with new information + original_synthesis = synthesis_data.synthesized_info[ + most_relevant_question + ] + + enhancement_input = { + "original_synthesis": original_synthesis.synthesized_answer, + "new_information": raw_contents, + "critique": [ + item + for item in analysis_data.reflection_metadata.critique_summary + if is_text_relevant(item, most_relevant_question) + ] + if analysis_data.reflection_metadata + else [], + } + + # Use the utility function for enhancement + enhanced_synthesis = get_structured_llm_output( + prompt=json.dumps(enhancement_input), + system_prompt=str(additional_synthesis_prompt), + model=llm_model, + fallback_response={ + "enhanced_synthesis": original_synthesis.synthesized_answer, + "improvements_made": ["Failed to enhance synthesis"], + "remaining_limitations": "Enhancement process failed.", + }, + project=langfuse_project_name, + ) + + if ( + enhanced_synthesis + and "enhanced_synthesis" in enhanced_synthesis + ): + # Create enhanced synthesis info + enhanced_info = SynthesizedInfo( + synthesized_answer=enhanced_synthesis[ + "enhanced_synthesis" + ], + key_sources=original_synthesis.key_sources + + [r.url for r in search_results[:2]], + confidence_level="high" + if original_synthesis.confidence_level == "medium" + else original_synthesis.confidence_level, + information_gaps=enhanced_synthesis.get( + "remaining_limitations", "" + ), + improvements=original_synthesis.improvements + + enhanced_synthesis.get("improvements_made", []), + ) + + # Store in enhanced_info + enhanced_synthesis_data.enhanced_info[ + most_relevant_question + ] = enhanced_info + + improvements = enhanced_synthesis.get( + "improvements_made", [] + ) + improvements_count += len(improvements) + + # Track enhancement for metadata + search_enhancements.append( + { + "query": query, + "relevant_question": most_relevant_question, + "num_results": len(search_results), + "improvements": len(improvements), + "enhanced": True, + "search_cost": search_cost + if search_provider + and search_provider.lower() in ["exa", "both"] + else 0.0, + } + ) + + # Update reflection metadata with search info + if analysis_data.reflection_metadata: + 
analysis_data.reflection_metadata.searches_performed = ( + approval_decision.selected_queries + ) + analysis_data.reflection_metadata.improvements_made = float( + improvements_count + ) + + logger.info( + f"Completed approved searches with {improvements_count} improvements" + ) + + # Calculate metrics for metadata + execution_time = time.time() - start_time + total_results = sum( + e.get("num_results", 0) for e in search_enhancements + ) + questions_enhanced = len( + set( + e.get("relevant_question") + for e in search_enhancements + if e.get("enhanced") + ) + ) + + # Log successful execution metadata + log_metadata( + metadata={ + "execute_approved_searches": { + "execution_time_seconds": execution_time, + "approval_method": approval_decision.approval_method, + "approval_status": "approved", + "num_queries_recommended": len(recommended_queries), + "num_queries_approved": len( + approval_decision.selected_queries + ), + "num_searches_executed": len( + approval_decision.selected_queries + ), + "total_search_results": total_results, + "questions_enhanced": questions_enhanced, + "improvements_made": improvements_count, + "search_provider": search_provider, + "search_mode": search_mode, + "llm_model": llm_model, + "success": True, + "total_search_cost": enhanced_search_data.search_costs.get( + "exa", 0.0 + ), + } + } + ) + + # Log artifact metadata + log_metadata( + metadata={ + "search_data_characteristics": { + "new_searches": len(approval_decision.selected_queries), + "total_searches": enhanced_search_data.total_searches, + "additional_cost": enhanced_search_data.search_costs.get( + "exa", 0.0 + ) + - search_data.search_costs.get("exa", 0.0), + } + }, + artifact_name="enhanced_search_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + "synthesis_data_characteristics": { + "questions_enhanced": questions_enhanced, + "total_enhancements": len( + enhanced_synthesis_data.enhanced_info + ), + "improvements_made": improvements_count, + } + }, + artifact_name="enhanced_synthesis_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + "analysis_data_characteristics": { + "searches_performed": len( + approval_decision.selected_queries + ), + "approval_method": approval_decision.approval_method, + } + }, + artifact_name="updated_analysis_data", + infer_artifact=True, + ) + + # Add tags to the artifacts + # add_tags(tags=["search", "enhanced"], artifact_name="enhanced_search_data", infer_artifact=True) + # add_tags( + # tags=["synthesis", "enhanced"], artifact_name="enhanced_synthesis_data", infer_artifact=True + # ) + # add_tags( + # tags=["analysis", "with-searches"], + # artifact_name="updated_analysis_data", + # infer_artifact=True, + # ) + + return enhanced_search_data, enhanced_synthesis_data, analysis_data + + except Exception as e: + logger.error(f"Error during approved search execution: {e}") + + # Update reflection metadata with error + if analysis_data.reflection_metadata: + analysis_data.reflection_metadata.error = ( + f"Approved search execution failed: {str(e)}" + ) + analysis_data.reflection_metadata.searches_performed = [] + analysis_data.reflection_metadata.improvements_made = 0.0 + + # Log error metadata + execution_time = time.time() - start_time + log_metadata( + metadata={ + "execute_approved_searches": { + "execution_time_seconds": execution_time, + "approval_method": approval_decision.approval_method, + "approval_status": "approved", + "num_queries_approved": len( + approval_decision.selected_queries + ), + "num_searches_executed": 0, + 
"improvements_made": 0, + "search_provider": search_provider, + "llm_model": llm_model, + "success": False, + "error_message": str(e), + } + } + ) + + # Add tags to the artifacts + # add_tags(tags=["search", "error"], artifact_name="enhanced_search_data", infer_artifact=True) + # add_tags( + # tags=["synthesis", "error"], artifact_name="enhanced_synthesis_data", infer_artifact=True + # ) + # add_tags(tags=["analysis", "error"], artifact_name="updated_analysis_data", infer_artifact=True) + + return enhanced_search_data, enhanced_synthesis_data, analysis_data diff --git a/deep_research/steps/generate_reflection_step.py b/deep_research/steps/generate_reflection_step.py new file mode 100644 index 00000000..61cf4637 --- /dev/null +++ b/deep_research/steps/generate_reflection_step.py @@ -0,0 +1,197 @@ +import json +import logging +import time +from typing import Annotated, List, Tuple + +from materializers.analysis_data_materializer import AnalysisDataMaterializer +from utils.llm_utils import get_structured_llm_output +from utils.pydantic_models import ( + AnalysisData, + Prompt, + QueryContext, + ReflectionMetadata, + SynthesisData, +) +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step( + output_materializers={ + "analysis_data": AnalysisDataMaterializer, + } +) +def generate_reflection_step( + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, + reflection_prompt: Prompt, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> Tuple[ + Annotated[AnalysisData, "analysis_data"], + Annotated[List[str], "recommended_queries"], +]: + """ + Generate reflection and recommendations WITHOUT executing searches. + + This step only analyzes the current state and produces recommendations + for additional research that could improve the quality of the results. 
+ + Args: + query_context: The query context with main query and sub-questions + synthesis_data: The synthesized information + analysis_data: The analysis data with viewpoint analysis + reflection_prompt: Prompt for generating reflection + llm_model: The model to use for reflection + langfuse_project_name: Project name for tracing + + Returns: + Tuple of updated AnalysisData and recommended queries + """ + start_time = time.time() + logger.info("Generating reflection on research") + + # Prepare input for reflection + synthesized_info_dict = { + question: { + "synthesized_answer": info.synthesized_answer, + "key_sources": info.key_sources, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + } + for question, info in synthesis_data.synthesized_info.items() + } + + viewpoint_analysis_dict = None + if analysis_data.viewpoint_analysis: + # Convert the viewpoint analysis to a dict for the LLM + tension_list = [] + for tension in analysis_data.viewpoint_analysis.areas_of_tension: + tension_list.append( + {"topic": tension.topic, "viewpoints": tension.viewpoints} + ) + + viewpoint_analysis_dict = { + "main_points_of_agreement": analysis_data.viewpoint_analysis.main_points_of_agreement, + "areas_of_tension": tension_list, + "perspective_gaps": analysis_data.viewpoint_analysis.perspective_gaps, + "integrative_insights": analysis_data.viewpoint_analysis.integrative_insights, + } + + reflection_input = { + "main_query": query_context.main_query, + "sub_questions": query_context.sub_questions, + "synthesized_information": synthesized_info_dict, + } + + if viewpoint_analysis_dict: + reflection_input["viewpoint_analysis"] = viewpoint_analysis_dict + + # Get reflection critique + logger.info(f"Generating self-critique via {llm_model}") + + # Define fallback for reflection result + fallback_reflection = { + "critique": [], + "additional_questions": [], + "recommended_search_queries": [], + } + + # Use utility function to get structured output + reflection_result = get_structured_llm_output( + prompt=json.dumps(reflection_input), + system_prompt=str(reflection_prompt), + model=llm_model, + fallback_response=fallback_reflection, + project=langfuse_project_name, + ) + + # Extract results + recommended_queries = reflection_result.get( + "recommended_search_queries", [] + ) + critique_summary = reflection_result.get("critique", []) + additional_questions = reflection_result.get("additional_questions", []) + + # Update analysis data with reflection metadata + analysis_data.reflection_metadata = ReflectionMetadata( + critique_summary=[ + str(c) for c in critique_summary + ], # Convert to strings + additional_questions_identified=additional_questions, + searches_performed=[], # Will be populated by execute_approved_searches_step + improvements_made=0.0, # Will be updated later + ) + + # Calculate execution time + execution_time = time.time() - start_time + + # Count confidence levels in synthesized info + confidence_levels = [ + info.confidence_level + for info in synthesis_data.synthesized_info.values() + ] + confidence_distribution = { + "high": confidence_levels.count("high"), + "medium": confidence_levels.count("medium"), + "low": confidence_levels.count("low"), + } + + # Log step metadata + log_metadata( + metadata={ + "reflection_generation": { + "execution_time_seconds": execution_time, + "llm_model": llm_model, + "num_sub_questions_analyzed": len(query_context.sub_questions), + "num_synthesized_answers": len( + synthesis_data.synthesized_info + ), + 
"viewpoint_analysis_included": bool(viewpoint_analysis_dict), + "num_critique_points": len(critique_summary), + "num_additional_questions": len(additional_questions), + "num_recommended_queries": len(recommended_queries), + "confidence_distribution": confidence_distribution, + "has_information_gaps": any( + info.information_gaps + for info in synthesis_data.synthesized_info.values() + ), + } + } + ) + + # Log artifact metadata + log_metadata( + metadata={ + "analysis_data_characteristics": { + "has_reflection_metadata": True, + "has_viewpoint_analysis": analysis_data.viewpoint_analysis + is not None, + "num_critique_points": len(critique_summary), + "num_additional_questions": len(additional_questions), + } + }, + artifact_name="analysis_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + "recommended_queries_characteristics": { + "num_queries": len(recommended_queries), + "has_recommendations": bool(recommended_queries), + } + }, + artifact_name="recommended_queries", + infer_artifact=True, + ) + + # Add tags to the artifacts + # add_tags(tags=["analysis", "reflection"], artifact_name="analysis_data", infer_artifact=True) + # add_tags( + # tags=["recommendations", "queries"], artifact_name="recommended_queries", infer_artifact=True + # ) + + return analysis_data, recommended_queries diff --git a/deep_research/steps/initialize_prompts_step.py b/deep_research/steps/initialize_prompts_step.py new file mode 100644 index 00000000..fe47c395 --- /dev/null +++ b/deep_research/steps/initialize_prompts_step.py @@ -0,0 +1,154 @@ +"""Step to initialize and track prompts as individual artifacts. + +This step creates individual Prompt artifacts at the beginning of the pipeline, +making all prompts trackable and versioned in ZenML. +""" + +import logging +from typing import Annotated, Tuple + +from materializers.prompt_materializer import PromptMaterializer +from utils import prompts +from utils.pydantic_models import Prompt +from zenml import step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=PromptMaterializer) +def initialize_prompts_step( + pipeline_version: str = "1.1.0", +) -> Tuple[ + Annotated[Prompt, "search_query_prompt"], + Annotated[Prompt, "query_decomposition_prompt"], + Annotated[Prompt, "synthesis_prompt"], + Annotated[Prompt, "viewpoint_analysis_prompt"], + Annotated[Prompt, "reflection_prompt"], + Annotated[Prompt, "additional_synthesis_prompt"], + Annotated[Prompt, "conclusion_generation_prompt"], + Annotated[Prompt, "executive_summary_prompt"], + Annotated[Prompt, "introduction_prompt"], +]: + """Initialize individual prompts for the pipeline. + + This step loads all prompts from the prompts.py module and creates + individual Prompt artifacts that can be tracked and visualized in ZenML. 
+ + Args: + pipeline_version: Version of the pipeline using these prompts + + Returns: + Tuple of individual Prompt artifacts used in the pipeline + """ + logger.info( + f"Initializing prompts for pipeline version {pipeline_version}" + ) + + # Create individual prompt instances + search_query_prompt = Prompt( + content=prompts.DEFAULT_SEARCH_QUERY_PROMPT, + name="search_query_prompt", + description="Generates effective search queries from sub-questions", + version="1.0.0", + tags=["search", "query", "information-gathering"], + ) + + query_decomposition_prompt = Prompt( + content=prompts.QUERY_DECOMPOSITION_PROMPT, + name="query_decomposition_prompt", + description="Breaks down complex research queries into specific sub-questions", + version="1.0.0", + tags=["analysis", "decomposition", "planning"], + ) + + synthesis_prompt = Prompt( + content=prompts.SYNTHESIS_PROMPT, + name="synthesis_prompt", + description="Synthesizes search results into comprehensive answers for sub-questions", + version="1.1.0", + tags=["synthesis", "integration", "analysis"], + ) + + viewpoint_analysis_prompt = Prompt( + content=prompts.VIEWPOINT_ANALYSIS_PROMPT, + name="viewpoint_analysis_prompt", + description="Analyzes synthesized answers across different perspectives and viewpoints", + version="1.1.0", + tags=["analysis", "viewpoint", "perspective"], + ) + + reflection_prompt = Prompt( + content=prompts.REFLECTION_PROMPT, + name="reflection_prompt", + description="Evaluates research and identifies gaps, biases, and areas for improvement", + version="1.0.0", + tags=["reflection", "critique", "improvement"], + ) + + additional_synthesis_prompt = Prompt( + content=prompts.ADDITIONAL_SYNTHESIS_PROMPT, + name="additional_synthesis_prompt", + description="Enhances original synthesis with new information and addresses critique points", + version="1.1.0", + tags=["synthesis", "enhancement", "integration"], + ) + + conclusion_generation_prompt = Prompt( + content=prompts.CONCLUSION_GENERATION_PROMPT, + name="conclusion_generation_prompt", + description="Synthesizes all research findings into a comprehensive conclusion", + version="1.0.0", + tags=["report", "conclusion", "synthesis"], + ) + + executive_summary_prompt = Prompt( + content=prompts.EXECUTIVE_SUMMARY_GENERATION_PROMPT, + name="executive_summary_prompt", + description="Creates a compelling, insight-driven executive summary", + version="1.1.0", + tags=["report", "summary", "insights"], + ) + + introduction_prompt = Prompt( + content=prompts.INTRODUCTION_GENERATION_PROMPT, + name="introduction_prompt", + description="Creates a contextual, engaging introduction", + version="1.1.0", + tags=["report", "introduction", "context"], + ) + + logger.info(f"Loaded 9 individual prompts") + + # # add tags to all prompts + # add_tags(tags=["prompt", "search"], artifact_name="search_query_prompt", infer_artifact=True) + + # add_tags( + # tags=["prompt", "generation"], artifact_name="query_decomposition_prompt", infer_artifact=True + # ) + # add_tags(tags=["prompt", "generation"], artifact_name="synthesis_prompt", infer_artifact=True) + # add_tags( + # tags=["prompt", "generation"], artifact_name="viewpoint_analysis_prompt", infer_artifact=True + # ) + # add_tags(tags=["prompt", "generation"], artifact_name="reflection_prompt", infer_artifact=True) + # add_tags( + # tags=["prompt", "generation"], artifact_name="additional_synthesis_prompt", infer_artifact=True + # ) + # add_tags( + # tags=["prompt", "generation"], artifact_name="conclusion_generation_prompt", infer_artifact=True 
+ # ) + # add_tags( + # tags=["prompt", "generation"], artifact_name="executive_summary_prompt", infer_artifact=True + # ) + # add_tags(tags=["prompt", "generation"], artifact_name="introduction_prompt", infer_artifact=True) + + return ( + search_query_prompt, + query_decomposition_prompt, + synthesis_prompt, + viewpoint_analysis_prompt, + reflection_prompt, + additional_synthesis_prompt, + conclusion_generation_prompt, + executive_summary_prompt, + introduction_prompt, + ) diff --git a/deep_research/steps/merge_results_step.py b/deep_research/steps/merge_results_step.py new file mode 100644 index 00000000..4802c98f --- /dev/null +++ b/deep_research/steps/merge_results_step.py @@ -0,0 +1,231 @@ +import logging +import time +from typing import Annotated, Tuple + +from materializers.search_data_materializer import SearchDataMaterializer +from materializers.synthesis_data_materializer import SynthesisDataMaterializer +from utils.pydantic_models import SearchData, SynthesisData +from zenml import get_step_context, log_metadata, step +from zenml.client import Client + +logger = logging.getLogger(__name__) + + +@step( + output_materializers={ + "merged_search_data": SearchDataMaterializer, + "merged_synthesis_data": SynthesisDataMaterializer, + } +) +def merge_sub_question_results_step( + step_prefix: str = "process_question_", +) -> Tuple[ + Annotated[SearchData, "merged_search_data"], + Annotated[SynthesisData, "merged_synthesis_data"], +]: + """Merge results from individual sub-question processing steps. + + This step collects the results from the parallel sub-question processing steps + and combines them into single SearchData and SynthesisData artifacts. + + Args: + step_prefix: The prefix used in step IDs for the parallel processing steps + + Returns: + Tuple of merged SearchData and SynthesisData artifacts + + Note: + This step is typically configured with the 'after' parameter in the pipeline + definition to ensure it runs after all parallel sub-question processing steps + have completed. 
+ """ + start_time = time.time() + + # Initialize merged artifacts + merged_search_data = SearchData() + merged_synthesis_data = SynthesisData() + + # Get pipeline run information to access outputs + try: + ctx = get_step_context() + if not ctx or not ctx.pipeline_run: + logger.error("Could not get pipeline run context") + return merged_search_data, merged_synthesis_data + + run_name = ctx.pipeline_run.name + client = Client() + run = client.get_pipeline_run(run_name) + + logger.info( + f"Merging results from parallel sub-question processing steps in run: {run_name}" + ) + + # Track which sub-questions were successfully processed + processed_questions = set() + parallel_steps_processed = 0 + + # Process each step in the run + for step_name, step_info in run.steps.items(): + # Only process steps with the specified prefix + if step_name.startswith(step_prefix): + try: + # Extract the sub-question index from the step name + if "_" in step_name: + index = int(step_name.split("_")[-1]) + logger.info( + f"Processing results from step: {step_name} (index: {index})" + ) + + # Get the search_data artifact + if "search_data" in step_info.outputs: + search_artifacts = step_info.outputs["search_data"] + if search_artifacts: + search_artifact = search_artifacts[0] + sub_search_data = search_artifact.load() + + # Merge search data + merged_search_data.merge(sub_search_data) + + # Track processed questions + for sub_q in sub_search_data.search_results: + processed_questions.add(sub_q) + logger.info( + f"Merged search results for: {sub_q}" + ) + + # Get the synthesis_data artifact + if "synthesis_data" in step_info.outputs: + synthesis_artifacts = step_info.outputs[ + "synthesis_data" + ] + if synthesis_artifacts: + synthesis_artifact = synthesis_artifacts[0] + sub_synthesis_data = synthesis_artifact.load() + + # Merge synthesis data + merged_synthesis_data.merge(sub_synthesis_data) + + # Track processed questions + for ( + sub_q + ) in sub_synthesis_data.synthesized_info: + logger.info( + f"Merged synthesis info for: {sub_q}" + ) + + parallel_steps_processed += 1 + + except (ValueError, IndexError, KeyError, AttributeError) as e: + logger.warning(f"Error processing step {step_name}: {e}") + continue + + # Log summary + logger.info( + f"Merged results from {parallel_steps_processed} parallel steps" + ) + logger.info( + f"Successfully processed {len(processed_questions)} sub-questions" + ) + + # Log search cost summary + if merged_search_data.search_costs: + total_cost = sum(merged_search_data.search_costs.values()) + logger.info( + f"Total search costs merged: ${total_cost:.4f} across {len(merged_search_data.search_cost_details)} queries" + ) + for provider, cost in merged_search_data.search_costs.items(): + logger.info(f" {provider}: ${cost:.4f}") + + except Exception as e: + logger.error(f"Error during merge step: {e}") + + # Final check for empty results + if ( + not merged_search_data.search_results + or not merged_synthesis_data.synthesized_info + ): + logger.warning( + "No results were found or merged from parallel processing steps!" 
+ ) + + # Calculate execution time + execution_time = time.time() - start_time + + # Count total search results across all questions + total_search_results = sum( + len(results) for results in merged_search_data.search_results.values() + ) + + # Get confidence distribution for merged results + confidence_distribution = {"high": 0, "medium": 0, "low": 0} + for info in merged_synthesis_data.synthesized_info.values(): + level = info.confidence_level.lower() + if level in confidence_distribution: + confidence_distribution[level] += 1 + + # Log metadata + log_metadata( + metadata={ + "merge_results": { + "execution_time_seconds": execution_time, + "parallel_steps_processed": parallel_steps_processed, + "questions_successfully_merged": len(processed_questions), + "total_search_results": total_search_results, + "confidence_distribution": confidence_distribution, + "merge_success": bool( + merged_search_data.search_results + and merged_synthesis_data.synthesized_info + ), + "total_search_costs": merged_search_data.search_costs, + "total_search_queries": len( + merged_search_data.search_cost_details + ), + "total_exa_cost": merged_search_data.search_costs.get( + "exa", 0.0 + ), + } + } + ) + + # Log artifact metadata + log_metadata( + metadata={ + "search_data_characteristics": { + "total_searches": merged_search_data.total_searches, + "search_results_count": len(merged_search_data.search_results), + "total_cost": sum(merged_search_data.search_costs.values()), + } + }, + artifact_name="merged_search_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + "synthesis_data_characteristics": { + "synthesized_info_count": len( + merged_synthesis_data.synthesized_info + ), + "enhanced_info_count": len( + merged_synthesis_data.enhanced_info + ), + "confidence_distribution": confidence_distribution, + } + }, + artifact_name="merged_synthesis_data", + infer_artifact=True, + ) + + # Add tags to the artifacts + # add_tags( + # tags=["search", "merged"], + # artifact_name="merged_search_data", + # infer_artifact=True, + # ) + # add_tags( + # tags=["synthesis", "merged"], + # artifact_name="merged_synthesis_data", + # infer_artifact=True, + # ) + + return merged_search_data, merged_synthesis_data diff --git a/deep_research/steps/process_sub_question_step.py b/deep_research/steps/process_sub_question_step.py new file mode 100644 index 00000000..7ff14ab1 --- /dev/null +++ b/deep_research/steps/process_sub_question_step.py @@ -0,0 +1,315 @@ +import logging +import time +import warnings +from typing import Annotated, Tuple + +# Suppress Pydantic serialization warnings from ZenML artifact metadata +# These occur when ZenML stores timestamp metadata as floats but models expect ints +warnings.filterwarnings( + "ignore", message=".*PydanticSerializationUnexpectedValue.*" +) + +from materializers.search_data_materializer import SearchDataMaterializer +from materializers.synthesis_data_materializer import SynthesisDataMaterializer +from utils.llm_utils import synthesize_information +from utils.pydantic_models import ( + Prompt, + QueryContext, + SearchCostDetail, + SearchData, + SynthesisData, + SynthesizedInfo, +) +from utils.search_utils import ( + generate_search_query, + search_and_extract_results, +) +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step( + output_materializers={ + "search_data": SearchDataMaterializer, + "synthesis_data": SynthesisDataMaterializer, + } +) +def process_sub_question_step( + query_context: QueryContext, + search_query_prompt: Prompt, + 
synthesis_prompt: Prompt, + question_index: int, + llm_model_search: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + llm_model_synthesis: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + num_results_per_search: int = 3, + cap_search_length: int = 20000, + search_provider: str = "tavily", + search_mode: str = "auto", + langfuse_project_name: str = "deep-research", +) -> Tuple[ + Annotated[SearchData, "search_data"], + Annotated[SynthesisData, "synthesis_data"], +]: + """Process a single sub-question if it exists at the given index. + + This step combines the gathering and synthesis steps for a single sub-question. + It's designed to be run in parallel for each sub-question. + + Args: + query_context: The query context with main query and sub-questions + search_query_prompt: Prompt for generating search queries + synthesis_prompt: Prompt for synthesizing search results + question_index: The index of the sub-question to process + llm_model_search: Model to use for search query generation + llm_model_synthesis: Model to use for synthesis + num_results_per_search: Number of results to fetch per search + cap_search_length: Maximum length of content to process from search results + search_provider: Search provider to use (tavily, exa, or both) + search_mode: Search mode for Exa provider (neural, keyword, or auto) + langfuse_project_name: Project name for tracing + + Returns: + Tuple of SearchData and SynthesisData for the processed sub-question + """ + start_time = time.time() + + # Initialize empty artifacts + search_data = SearchData() + synthesis_data = SynthesisData() + + # Check if this index exists in sub-questions + if question_index >= len(query_context.sub_questions): + logger.info( + f"No sub-question at index {question_index}, skipping processing" + ) + # Log metadata for skipped processing + log_metadata( + metadata={ + "sub_question_processing": { + "question_index": question_index, + "status": "skipped", + "reason": "index_out_of_range", + "total_sub_questions": len(query_context.sub_questions), + } + } + ) + # Return empty artifacts + # add_tags( + # tags=["search", "synthesis", "skipped"], artifact_name="search_data", infer_artifact=True + # ) + # add_tags( + # tags=["search", "synthesis", "skipped"], artifact_name="synthesis_data", infer_artifact=True + # ) + return search_data, synthesis_data + + # Get the target sub-question + sub_question = query_context.sub_questions[question_index] + logger.info( + f"Processing sub-question {question_index + 1}: {sub_question}" + ) + + # === INFORMATION GATHERING === + search_phase_start = time.time() + + # Generate search query with prompt + search_query_data = generate_search_query( + sub_question=sub_question, + model=llm_model_search, + system_prompt=str(search_query_prompt), + project=langfuse_project_name, + ) + search_query = search_query_data.get( + "search_query", f"research about {sub_question}" + ) + + # Perform search + logger.info(f"Performing search with query: {search_query}") + if search_provider: + logger.info(f"Using search provider: {search_provider}") + results_list, search_cost = search_and_extract_results( + query=search_query, + max_results=num_results_per_search, + cap_content_length=cap_search_length, + provider=search_provider, + search_mode=search_mode, + ) + + # Update search data + search_data.search_results[sub_question] = results_list + search_data.total_searches = 1 + + # Track search costs if using Exa + if ( + search_provider + and search_provider.lower() in ["exa", "both"] + and search_cost > 0 + ): + 
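        # Per-query cost is recorded only for Exa ("exa" or "both" providers);
        # it accumulates on the SearchData artifact so the tracing metadata step
        # can aggregate total search spend later.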
# Update total costs + search_data.search_costs["exa"] = search_cost + + # Add detailed cost entry + search_data.search_cost_details.append( + SearchCostDetail( + provider="exa", + query=search_query, + cost=search_cost, + timestamp=time.time(), + step="process_sub_question", + sub_question=sub_question, + ) + ) + logger.info( + f"Exa search cost for sub-question {question_index}: ${search_cost:.4f}" + ) + + search_phase_time = time.time() - search_phase_start + + # === INFORMATION SYNTHESIS === + synthesis_phase_start = time.time() + + # Extract raw contents and URLs + raw_contents = [] + sources = [] + for result in results_list: + raw_contents.append(result.content) + sources.append(result.url) + + # Prepare input for synthesis + synthesis_input = { + "sub_question": sub_question, + "search_results": raw_contents, + "sources": sources, + } + + # Synthesize information with prompt + synthesis_result = synthesize_information( + synthesis_input=synthesis_input, + model=llm_model_synthesis, + system_prompt=str(synthesis_prompt), + project=langfuse_project_name, + ) + + # Create SynthesizedInfo object + synthesized_info = SynthesizedInfo( + synthesized_answer=synthesis_result.get( + "synthesized_answer", f"Synthesis for '{sub_question}' failed." + ), + key_sources=synthesis_result.get("key_sources", sources[:1]), + confidence_level=synthesis_result.get("confidence_level", "low"), + information_gaps=synthesis_result.get( + "information_gaps", + "Synthesis process encountered technical difficulties.", + ), + improvements=synthesis_result.get("improvements", []), + ) + + # Update synthesis data + synthesis_data.synthesized_info[sub_question] = synthesized_info + + synthesis_phase_time = time.time() - synthesis_phase_start + total_execution_time = time.time() - start_time + + # Calculate total content length processed + total_content_length = sum(len(content) for content in raw_contents) + + # Get unique domains from sources + unique_domains = set() + for url in sources: + try: + from urllib.parse import urlparse + + domain = urlparse(url).netloc + unique_domains.add(domain) + except: + pass + + # Log comprehensive metadata + log_metadata( + metadata={ + "sub_question_processing": { + "question_index": question_index, + "status": "completed", + "sub_question": sub_question, + "execution_time_seconds": total_execution_time, + "search_phase_time_seconds": search_phase_time, + "synthesis_phase_time_seconds": synthesis_phase_time, + "search_query": search_query, + "search_provider": search_provider, + "search_mode": search_mode, + "num_results_requested": num_results_per_search, + "num_results_retrieved": len(results_list), + "total_content_length": total_content_length, + "cap_search_length": cap_search_length, + "unique_domains": list(unique_domains), + "llm_model_search": llm_model_search, + "llm_model_synthesis": llm_model_synthesis, + "confidence_level": synthesis_result.get( + "confidence_level", "low" + ), + "information_gaps": synthesis_result.get( + "information_gaps", "" + ), + "key_sources_count": len( + synthesis_result.get("key_sources", []) + ), + "search_cost": search_cost, + "search_cost_provider": "exa" + if search_provider + and search_provider.lower() in ["exa", "both"] + else None, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "search_metrics": { + "confidence_level": synthesis_result.get( + "confidence_level", "low" + ), + "search_provider": search_provider, + } + }, + infer_model=True, + ) + + # Log artifact metadata for the 
output artifacts + log_metadata( + metadata={ + "search_data_characteristics": { + "sub_question": sub_question, + "num_results": len(results_list), + "search_provider": search_provider, + "search_cost": search_cost if search_cost > 0 else None, + } + }, + artifact_name="search_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + "synthesis_data_characteristics": { + "sub_question": sub_question, + "confidence_level": synthesis_result.get( + "confidence_level", "low" + ), + "has_information_gaps": bool( + synthesis_result.get("information_gaps") + ), + "num_key_sources": len( + synthesis_result.get("key_sources", []) + ), + } + }, + artifact_name="synthesis_data", + infer_artifact=True, + ) + + # Add tags to the artifacts + # add_tags(tags=["search", "sub-question"], artifact_name="search_data", infer_artifact=True) + # add_tags(tags=["synthesis", "sub-question"], artifact_name="synthesis_data", infer_artifact=True) + + return search_data, synthesis_data diff --git a/deep_research/steps/pydantic_final_report_step.py b/deep_research/steps/pydantic_final_report_step.py new file mode 100644 index 00000000..339d42fb --- /dev/null +++ b/deep_research/steps/pydantic_final_report_step.py @@ -0,0 +1,1104 @@ +"""Final report generation step using artifact-based approach. + +This module provides a ZenML pipeline step for generating the final HTML research report +using the new artifact-based approach. +""" + +import html +import json +import logging +import re +import time +from typing import Annotated, Tuple + +from materializers.final_report_materializer import FinalReportMaterializer +from utils.css_utils import extract_html_from_content, get_shared_css_tag +from utils.llm_utils import remove_reasoning_from_output, run_llm_completion +from utils.prompts import ( + STATIC_HTML_TEMPLATE, + SUB_QUESTION_TEMPLATE, + VIEWPOINT_ANALYSIS_TEMPLATE, +) +from utils.pydantic_models import ( + AnalysisData, + FinalReport, + Prompt, + QueryContext, + SearchData, + SynthesisData, +) +from zenml import log_metadata, step +from zenml.types import HTMLString + +logger = logging.getLogger(__name__) + + +def clean_html_output(html_content: str) -> str: + """Clean HTML output from LLM to ensure proper rendering. + + This function removes markdown code blocks, fixes common issues with LLM HTML output, + and ensures we have proper HTML structure for rendering. 
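    It also strips leftover template placeholders (for example, "[CONCLUSION TEXT]")
    and substitutes neutral default text so the report renders cleanly.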

    Args:
        html_content: Raw HTML content from LLM

    Returns:
        Cleaned HTML content ready for rendering
    """
    # Remove markdown code block markers (```html and ```)
    html_content = re.sub(r"```html\s*", "", html_content)
    html_content = re.sub(r"```\s*$", "", html_content)
    html_content = re.sub(r"```", "", html_content)

    # Remove any CSS code block markers
    html_content = re.sub(r"```css\s*", "", html_content)

    # Ensure HTML content is properly wrapped in HTML tags if not already
    if not html_content.strip().startswith(
        "<html"
    ) and not html_content.strip().startswith("<!DOCTYPE"):
        html_content = f'<html><body>{html_content}</body></html>'

    html_content = re.sub(r"\[CSS STYLESHEET GOES HERE\]", "", html_content)
    html_content = re.sub(r"\[SUB-QUESTIONS LINKS\]", "", html_content)
    html_content = re.sub(r"\[ADDITIONAL SECTIONS LINKS\]", "", html_content)
    html_content = re.sub(r"\[FOR EACH SUB-QUESTION\]:", "", html_content)
    html_content = re.sub(r"\[FOR EACH TENSION\]:", "", html_content)

    # Replace content placeholders with appropriate defaults
    html_content = re.sub(
        r"\[CONCISE SUMMARY OF KEY FINDINGS\]",
        "Summary of findings from the research query.",
        html_content,
    )
    html_content = re.sub(
        r"\[INTRODUCTION TO THE RESEARCH QUERY\]",
        "Introduction to the research topic.",
        html_content,
    )
    html_content = re.sub(
        r"\[OVERVIEW OF THE APPROACH AND SUB-QUESTIONS\]",
        "Overview of the research approach.",
        html_content,
    )
    html_content = re.sub(
        r"\[CONCLUSION TEXT\]",
        "Conclusion of the research findings.",
        html_content,
    )

    return html_content


def format_text_with_code_blocks(text: str) -> str:
    """Format text with proper handling of code blocks and markdown formatting.

    Args:
        text: The raw text to format

    Returns:
        str: HTML-formatted text
    """
    if not text:
        return ""

    # Handle code blocks
    lines = text.split("\n")
    formatted_lines = []
    in_code_block = False
    code_language = ""
    code_lines = []

    for line in lines:
        # Check for code block start
        if line.strip().startswith("```"):
            if in_code_block:
                # End of code block
                code_content = "\n".join(code_lines)
                formatted_lines.append(
                    f'<pre><code class="language-{code_language}">{html.escape(code_content)}</code></pre>'
                )
                code_lines = []
                in_code_block = False
                code_language = ""
            else:
                # Start of code block
                in_code_block = True
                # Extract language if specified
                code_language = line.strip()[3:].strip() or "plaintext"
        elif in_code_block:
            code_lines.append(line)
        else:
            # Process inline code
            line = re.sub(r"`([^`]+)`", r"<code>\1</code>", html.escape(line))
            # Process bullet points
            if line.strip().startswith("•") or line.strip().startswith("-"):
                line = re.sub(r"^(\s*)[•-]\s*", r"\1", line)
                formatted_lines.append(f"  • {line.strip()}")
            elif line.strip():
                formatted_lines.append(f"<p>{line}</p>")

    # Handle case where code block wasn't closed
    if in_code_block and code_lines:
        code_content = "\n".join(code_lines)
        formatted_lines.append(
            f'<pre><code class="language-{code_language}">{html.escape(code_content)}</code></pre>'
        )

    # Pass non-bullet lines through unchanged; group bullet lines below
    result = []
    in_list = False
    for line in formatted_lines:
        if not line.startswith("  • "):
            result.append(line)
        elif line.startswith("
  • "): + if not in_list: + result.append("") + in_list = False + result.append(line) + + if in_list: + result.append("") + + return "\n".join(result) + + +def generate_executive_summary( + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, + executive_summary_prompt: Prompt, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate an executive summary using LLM based on the complete research findings. + + Args: + query_context: The query context with main query and sub-questions + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis + executive_summary_prompt: Prompt for generating executive summary + llm_model: The model to use for generation + langfuse_project_name: Name of the Langfuse project for tracking + + Returns: + HTML formatted executive summary + """ + logger.info("Generating executive summary using LLM") + + # Prepare the context + summary_input = { + "main_query": query_context.main_query, + "sub_questions": query_context.sub_questions, + "key_findings": {}, + "viewpoint_analysis": None, + } + + # Include key findings from synthesis data + # Prefer enhanced info if available + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) + + for question in query_context.sub_questions: + if question in info_source: + info = info_source[question] + summary_input["key_findings"][question] = { + "answer": info.synthesized_answer, + "confidence": info.confidence_level, + "gaps": info.information_gaps, + } + + # Include viewpoint analysis if available + if analysis_data.viewpoint_analysis: + va = analysis_data.viewpoint_analysis + summary_input["viewpoint_analysis"] = { + "agreements": va.main_points_of_agreement, + "tensions": len(va.areas_of_tension), + "insights": va.integrative_insights, + } + + try: + # Call LLM to generate executive summary + result = run_llm_completion( + prompt=json.dumps(summary_input), + system_prompt=str(executive_summary_prompt), + model=llm_model, + temperature=0.7, + max_tokens=600, + project=langfuse_project_name, + tags=["executive_summary_generation"], + ) + + if result: + content = remove_reasoning_from_output(result) + # Clean up the HTML + content = extract_html_from_content(content) + logger.info("Successfully generated LLM-based executive summary") + return content + else: + logger.warning("Failed to generate executive summary via LLM") + return generate_fallback_executive_summary( + query_context, synthesis_data + ) + + except Exception as e: + logger.error(f"Error generating executive summary: {e}") + return generate_fallback_executive_summary( + query_context, synthesis_data + ) + + +def generate_introduction( + query_context: QueryContext, + introduction_prompt: Prompt, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate an introduction using LLM based on research query and sub-questions. 
+ + Args: + query_context: The query context with main query and sub-questions + introduction_prompt: Prompt for generating introduction + llm_model: The model to use for generation + langfuse_project_name: Name of the Langfuse project for tracking + + Returns: + HTML formatted introduction + """ + logger.info("Generating introduction using LLM") + + # Prepare the context + context = f"Main Research Query: {query_context.main_query}\n\n" + context += "Sub-questions being explored:\n" + for i, sub_question in enumerate(query_context.sub_questions, 1): + context += f"{i}. {sub_question}\n" + + try: + # Call LLM to generate introduction + result = run_llm_completion( + prompt=context, + system_prompt=str(introduction_prompt), + model=llm_model, + temperature=0.7, + max_tokens=600, + project=langfuse_project_name, + tags=["introduction_generation"], + ) + + if result: + content = remove_reasoning_from_output(result) + # Clean up the HTML + content = extract_html_from_content(content) + logger.info("Successfully generated LLM-based introduction") + return content + else: + logger.warning("Failed to generate introduction via LLM") + return generate_fallback_introduction(query_context) + + except Exception as e: + logger.error(f"Error generating introduction: {e}") + return generate_fallback_introduction(query_context) + + +def generate_fallback_executive_summary( + query_context: QueryContext, synthesis_data: SynthesisData +) -> str: + """Generate a fallback executive summary when LLM fails.""" + summary = f"

<p>This report examines the question: {html.escape(query_context.main_query)}</p>"
+    summary += f"<p>The research explored {len(query_context.sub_questions)} key dimensions of this topic, "
+    summary += "synthesizing findings from multiple sources to provide a comprehensive analysis.</p>"
+
+    # Add confidence overview
+    confidence_counts = {"high": 0, "medium": 0, "low": 0}
+    info_source = (
+        synthesis_data.enhanced_info
+        if synthesis_data.enhanced_info
+        else synthesis_data.synthesized_info
+    )
+    for info in info_source.values():
+        level = info.confidence_level.lower()
+        if level in confidence_counts:
+            confidence_counts[level] += 1
+
+    summary += f"<p>Overall confidence in findings: {confidence_counts['high']} high, "
+    summary += f"{confidence_counts['medium']} medium, {confidence_counts['low']} low.</p>"
+
+    return summary
+
+
+def generate_fallback_introduction(query_context: QueryContext) -> str:
+    """Generate a fallback introduction when LLM fails."""
+    intro = f"<p>This report addresses the research query: {html.escape(query_context.main_query)}</p>"
+    intro += f"<p>The research was conducted by breaking down the main query into {len(query_context.sub_questions)} "
+    intro += (
+        "sub-questions to explore different aspects of the topic in depth. "
+    )
+    intro += "Each sub-question was researched independently, with findings synthesized from various sources.</p>
    " + return intro + + +def generate_conclusion( + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, + conclusion_generation_prompt: Prompt, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate a comprehensive conclusion using LLM based on all research findings. + + Args: + query_context: The query context with main query and sub-questions + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis + conclusion_generation_prompt: Prompt for generating conclusion + llm_model: The model to use for conclusion generation + langfuse_project_name: Name of the Langfuse project for tracking + + Returns: + str: HTML-formatted conclusion content + """ + logger.info("Generating comprehensive conclusion using LLM") + + # Prepare input data for conclusion generation + conclusion_input = { + "main_query": query_context.main_query, + "sub_questions": query_context.sub_questions, + "enhanced_info": {}, + } + + # Include enhanced information for each sub-question + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) + + for question in query_context.sub_questions: + if question in info_source: + info = info_source[question] + conclusion_input["enhanced_info"][question] = { + "synthesized_answer": info.synthesized_answer, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + "key_sources": info.key_sources, + "improvements": getattr(info, "improvements", []), + } + + # Include viewpoint analysis + if analysis_data.viewpoint_analysis: + va = analysis_data.viewpoint_analysis + conclusion_input["viewpoint_analysis"] = { + "main_points_of_agreement": va.main_points_of_agreement, + "areas_of_tension": [ + {"topic": t.topic, "viewpoints": t.viewpoints} + for t in va.areas_of_tension + ], + "integrative_insights": va.integrative_insights, + } + + # Include reflection metadata if available + if analysis_data.reflection_metadata: + rm = analysis_data.reflection_metadata + conclusion_input["reflection_insights"] = { + "improvements_made": rm.improvements_made, + "additional_questions_identified": rm.additional_questions_identified, + } + + try: + # Call LLM to generate conclusion + result = run_llm_completion( + prompt=json.dumps(conclusion_input), + system_prompt=str(conclusion_generation_prompt), + model=llm_model, + temperature=0.7, + max_tokens=800, + project=langfuse_project_name, + tags=["conclusion_generation"], + ) + + if result: + content = remove_reasoning_from_output(result) + # Clean up the HTML + content = extract_html_from_content(content) + logger.info("Successfully generated LLM-based conclusion") + return content + else: + logger.warning("Failed to generate conclusion via LLM") + return generate_fallback_conclusion(query_context, synthesis_data) + + except Exception as e: + logger.error(f"Error generating conclusion: {e}") + return generate_fallback_conclusion(query_context, synthesis_data) + + +def generate_fallback_conclusion( + query_context: QueryContext, synthesis_data: SynthesisData +) -> str: + """Generate a fallback conclusion when LLM fails. + + Args: + query_context: The query context with main query and sub-questions + synthesis_data: The synthesis data with all synthesized information + + Returns: + str: Basic HTML-formatted conclusion + """ + conclusion = f"

<p>This research has explored the question: {html.escape(query_context.main_query)}</p>"
+    conclusion += f"<p>Through systematic investigation of {len(query_context.sub_questions)} sub-questions, "
+    conclusion += (
+        "we have gathered insights from multiple sources and perspectives.</p>"
+    )
+
+    # Add a summary of confidence levels
+    info_source = (
+        synthesis_data.enhanced_info
+        if synthesis_data.enhanced_info
+        else synthesis_data.synthesized_info
+    )
+    high_confidence = sum(
+        1
+        for info in info_source.values()
+        if info.confidence_level.lower() == "high"
+    )
+
+    if high_confidence > 0:
+        conclusion += f"<p>The research yielded {high_confidence} high-confidence findings out of "
+        conclusion += f"{len(info_source)} total areas investigated.</p>"
+
+    conclusion += "<p>Further research may be beneficial to address remaining information gaps "
+    conclusion += "and explore emerging questions identified during this investigation.</p>
    " + + return conclusion + + +def generate_report_from_template( + query_context: QueryContext, + search_data: SearchData, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, + conclusion_generation_prompt: Prompt, + executive_summary_prompt: Prompt, + introduction_prompt: Prompt, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate a final HTML report from a static template. + + Instead of using an LLM to generate HTML, this function uses predefined HTML + templates and populates them with data from the research artifacts. + + Args: + query_context: The query context with main query and sub-questions + search_data: The search data (for source information) + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis + conclusion_generation_prompt: Prompt for generating conclusion + executive_summary_prompt: Prompt for generating executive summary + introduction_prompt: Prompt for generating introduction + llm_model: The model to use for conclusion generation + langfuse_project_name: Name of the Langfuse project for tracking + + Returns: + str: The HTML content of the report + """ + logger.info( + f"Generating templated HTML report for query: {query_context.main_query}" + ) + + # Generate table of contents for sub-questions + sub_questions_toc = "" + for i, question in enumerate(query_context.sub_questions, 1): + safe_id = f"question-{i}" + sub_questions_toc += ( + f'
  • {html.escape(question)}
  • \n' + ) + + # Add viewpoint analysis to TOC if available + additional_sections_toc = "" + if analysis_data.viewpoint_analysis: + additional_sections_toc += ( + '
  • Viewpoint Analysis
  • \n' + ) + + # Generate HTML for sub-questions + sub_questions_html = "" + all_sources = set() + + # Determine which info source to use (merge original with enhanced) + # Start with the original synthesized info + info_source = synthesis_data.synthesized_info.copy() + + # Override with enhanced info where available + if synthesis_data.enhanced_info: + info_source.update(synthesis_data.enhanced_info) + + # Debug logging + logger.info( + f"Synthesis data has enhanced_info: {bool(synthesis_data.enhanced_info)}" + ) + logger.info( + f"Synthesis data has synthesized_info: {bool(synthesis_data.synthesized_info)}" + ) + logger.info(f"Info source has {len(info_source)} entries") + logger.info(f"Processing {len(query_context.sub_questions)} sub-questions") + + # Log the keys in info_source for debugging + if info_source: + logger.info( + f"Keys in info_source: {list(info_source.keys())[:3]}..." + ) # First 3 keys + logger.info( + f"Sub-questions from query_context: {query_context.sub_questions[:3]}..." + ) # First 3 + + for i, question in enumerate(query_context.sub_questions, 1): + info = info_source.get(question, None) + + # Skip if no information is available + if not info: + logger.warning( + f"No synthesis info found for question {i}: {question}" + ) + continue + + # Process confidence level + confidence = info.confidence_level.lower() + confidence_upper = info.confidence_level.upper() + + # Process key sources + key_sources_html = "" + if info.key_sources: + all_sources.update(info.key_sources) + sources_list = "\n".join( + [ + f'
  • {html.escape(source)}
  • ' + if source.startswith(("http://", "https://")) + else f"
  • {html.escape(source)}
  • " + for source in info.key_sources + ] + ) + key_sources_html = f""" +
    +

    📚 Key Sources

    + +
    + """ + + # Process information gaps + info_gaps_html = "" + if info.information_gaps: + info_gaps_html = f""" +
    +

    🧩 Information Gaps

    +

    {format_text_with_code_blocks(info.information_gaps)}

    +
    + """ + + # Determine confidence icon based on level + confidence_icon = "🔴" # Default (low) + if confidence_upper == "HIGH": + confidence_icon = "🟢" + elif confidence_upper == "MEDIUM": + confidence_icon = "🟡" + + # Format the subquestion section using the template + sub_question_html = SUB_QUESTION_TEMPLATE.format( + index=i, + question=html.escape(question), + confidence=confidence, + confidence_upper=confidence_upper, + confidence_icon=confidence_icon, + answer=format_text_with_code_blocks(info.synthesized_answer), + key_sources_html=key_sources_html, + info_gaps_html=info_gaps_html, + ) + + sub_questions_html += sub_question_html + + # Generate viewpoint analysis HTML if available + viewpoint_analysis_html = "" + if analysis_data.viewpoint_analysis: + va = analysis_data.viewpoint_analysis + # Format tensions + tensions_html = "" + for tension in va.areas_of_tension: + viewpoints_list = "\n".join( + [ + f"
  • {html.escape(viewpoint)}: {html.escape(description)}
  • " + for viewpoint, description in tension.viewpoints.items() + ] + ) + tensions_html += f""" +
    +

    {html.escape(tension.topic)}

    + +
    + """ + + # Format agreements (just the list items) + agreements_html = "" + if va.main_points_of_agreement: + agreements_html = "\n".join( + [ + f"
  • {html.escape(point)}
  • " + for point in va.main_points_of_agreement + ] + ) + + # Get perspective gaps if available + perspective_gaps = "" + if hasattr(va, "perspective_gaps") and va.perspective_gaps: + perspective_gaps = va.perspective_gaps + else: + perspective_gaps = "No significant perspective gaps identified." + + # Get integrative insights + integrative_insights = "" + if va.integrative_insights: + integrative_insights = format_text_with_code_blocks( + va.integrative_insights + ) + + viewpoint_analysis_html = VIEWPOINT_ANALYSIS_TEMPLATE.format( + agreements_html=agreements_html, + tensions_html=tensions_html, + perspective_gaps=perspective_gaps, + integrative_insights=integrative_insights, + ) + + # Generate references section + references_html = '" + + # Generate dynamic executive summary using LLM + logger.info("Generating dynamic executive summary...") + executive_summary = generate_executive_summary( + query_context, + synthesis_data, + analysis_data, + executive_summary_prompt, + llm_model, + langfuse_project_name, + ) + logger.info( + f"Executive summary generated: {len(executive_summary)} characters" + ) + + # Generate dynamic introduction using LLM + logger.info("Generating dynamic introduction...") + introduction_html = generate_introduction( + query_context, introduction_prompt, llm_model, langfuse_project_name + ) + logger.info(f"Introduction generated: {len(introduction_html)} characters") + + # Generate comprehensive conclusion using LLM + conclusion_html = generate_conclusion( + query_context, + synthesis_data, + analysis_data, + conclusion_generation_prompt, + llm_model, + langfuse_project_name, + ) + + # Generate complete HTML report + html_content = STATIC_HTML_TEMPLATE.format( + main_query=html.escape(query_context.main_query), + shared_css=get_shared_css_tag(), + sub_questions_toc=sub_questions_toc, + additional_sections_toc=additional_sections_toc, + executive_summary=executive_summary, + introduction_html=introduction_html, + num_sub_questions=len(query_context.sub_questions), + sub_questions_html=sub_questions_html, + viewpoint_analysis_html=viewpoint_analysis_html, + conclusion_html=conclusion_html, + references_html=references_html, + ) + + return html_content + + +def _generate_fallback_report( + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, +) -> str: + """Generate a minimal fallback report when the main report generation fails. + + This function creates a simplified HTML report with a consistent structure when + the main report generation process encounters an error. + + Args: + query_context: The query context with main query and sub-questions + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis + + Returns: + str: A basic HTML report with a standard research report structure + """ + # Create a simple HTML structure with embedded CSS for styling + html = f""" + + + + + + Research Report - {html.escape(query_context.main_query)} + + +
    +

    Research Report: {html.escape(query_context.main_query)}

    + +
    + Note: This is a simplified version of the report generated due to processing limitations. +
    + +
    +

    Introduction

    +

    This report investigates the research query: {html.escape(query_context.main_query)}

    +

    The investigation was structured around {len(query_context.sub_questions)} key sub-questions to provide comprehensive coverage of the topic.

    +
    + +
    +

    Research Findings

    +""" + + # Add findings for each sub-question + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) + + for i, question in enumerate(query_context.sub_questions, 1): + if question in info_source: + info = info_source[question] + confidence_class = info.confidence_level.lower() + + html += f""" +
    +

    {i}. {html.escape(question)}

    + Confidence: {info.confidence_level.upper()} +

    {html.escape(info.synthesized_answer)}

    + """ + + if info.information_gaps: + html += f"

    Information Gaps: {html.escape(info.information_gaps)}

    " + + html += "
    " + + html += """ +
    + +
    +

    Conclusion

    +

    This research has provided insights into the various aspects of the main query through systematic investigation.

    +

    The findings represent a synthesis of available information, with varying levels of confidence across different areas.

    +
    + +
    +

    References

    +

    Sources were gathered from various search providers and synthesized to create this report.

    +
    +
    + +""" + + return html + + +@step( + output_materializers={ + "final_report": FinalReportMaterializer, + } +) +def pydantic_final_report_step( + query_context: QueryContext, + search_data: SearchData, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, + conclusion_generation_prompt: Prompt, + executive_summary_prompt: Prompt, + introduction_prompt: Prompt, + use_static_template: bool = True, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> Tuple[ + Annotated[FinalReport, "final_report"], + Annotated[HTMLString, "report_html"], +]: + """Generate the final research report in HTML format using artifact-based approach. + + This step uses the individual artifacts to generate a final HTML report. + + Args: + query_context: The query context with main query and sub-questions + search_data: The search data (for source information) + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis and reflection metadata + conclusion_generation_prompt: Prompt for generating conclusions + executive_summary_prompt: Prompt for generating executive summary + introduction_prompt: Prompt for generating introduction + use_static_template: Whether to use a static template instead of LLM generation + llm_model: The model to use for report generation with provider prefix + langfuse_project_name: Name of the Langfuse project for tracking + + Returns: + A tuple containing the FinalReport artifact and the HTML report string + """ + start_time = time.time() + logger.info( + "Generating final research report using artifact-based approach" + ) + + if use_static_template: + # Use the static HTML template approach + logger.info("Using static HTML template for report generation") + html_content = generate_report_from_template( + query_context, + search_data, + synthesis_data, + analysis_data, + conclusion_generation_prompt, + executive_summary_prompt, + introduction_prompt, + llm_model, + langfuse_project_name, + ) + + # Create the FinalReport artifact + final_report = FinalReport( + report_html=html_content, + main_query=query_context.main_query, + ) + + # Calculate execution time + execution_time = time.time() - start_time + + # Calculate report metrics + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) + confidence_distribution = {"high": 0, "medium": 0, "low": 0} + for info in info_source.values(): + level = info.confidence_level.lower() + if level in confidence_distribution: + confidence_distribution[level] += 1 + + # Count various elements in the report + num_sources = len( + set( + source + for info in info_source.values() + for source in info.key_sources + ) + ) + has_viewpoint_analysis = analysis_data.viewpoint_analysis is not None + has_reflection_insights = ( + analysis_data.reflection_metadata is not None + and analysis_data.reflection_metadata.improvements_made > 0 + ) + + # Log step metadata + log_metadata( + metadata={ + "final_report_generation": { + "execution_time_seconds": execution_time, + "use_static_template": use_static_template, + "llm_model": llm_model, + "main_query_length": len(query_context.main_query), + "num_sub_questions": len(query_context.sub_questions), + "num_synthesized_answers": len(info_source), + "has_enhanced_info": bool(synthesis_data.enhanced_info), + "confidence_distribution": confidence_distribution, + "num_unique_sources": num_sources, + 
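+                # The remaining entries record whether the optional viewpoint analysis and
+                # reflection insights made it into the report, plus the final HTML size,
+                # which makes individual runs easier to compare in the ZenML dashboard.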
"has_viewpoint_analysis": has_viewpoint_analysis, + "has_reflection_insights": has_reflection_insights, + "report_length_chars": len(html_content), + "report_generation_success": True, + } + } + ) + + # Log artifact metadata + log_metadata( + metadata={ + "final_report_characteristics": { + "report_length": len(html_content), + "main_query": query_context.main_query, + "num_sections": len(query_context.sub_questions) + + (1 if has_viewpoint_analysis else 0), + "has_executive_summary": True, + "has_introduction": True, + "has_conclusion": True, + } + }, + artifact_name="final_report", + infer_artifact=True, + ) + + # Add tags to the artifact + # add_tags(tags=["report", "final", "html"], artifact_name="final_report", infer_artifact=True) + + logger.info( + f"Successfully generated final report ({len(html_content)} characters)" + ) + return final_report, HTMLString(html_content) + + else: + # Handle non-static template case (future implementation) + logger.warning( + "Non-static template generation not yet implemented, falling back to static template" + ) + return pydantic_final_report_step( + query_context=query_context, + search_data=search_data, + synthesis_data=synthesis_data, + analysis_data=analysis_data, + conclusion_generation_prompt=conclusion_generation_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + use_static_template=True, + llm_model=llm_model, + langfuse_project_name=langfuse_project_name, + ) diff --git a/deep_research/steps/query_decomposition_step.py b/deep_research/steps/query_decomposition_step.py new file mode 100644 index 00000000..053a8fd4 --- /dev/null +++ b/deep_research/steps/query_decomposition_step.py @@ -0,0 +1,186 @@ +import logging +import time +from typing import Annotated + +from materializers.query_context_materializer import QueryContextMaterializer +from utils.llm_utils import get_structured_llm_output +from utils.pydantic_models import Prompt, QueryContext +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=QueryContextMaterializer) +def initial_query_decomposition_step( + main_query: str, + query_decomposition_prompt: Prompt, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + max_sub_questions: int = 8, + langfuse_project_name: str = "deep-research", +) -> Annotated[QueryContext, "query_context"]: + """Break down a complex research query into specific sub-questions. + + Args: + main_query: The main research query to decompose + query_decomposition_prompt: Prompt for query decomposition + llm_model: The reasoning model to use with provider prefix + max_sub_questions: Maximum number of sub-questions to generate + langfuse_project_name: Project name for tracing + + Returns: + QueryContext containing the main query and decomposed sub-questions + """ + start_time = time.time() + logger.info(f"Decomposing research query: {main_query}") + + # Get the prompt content + system_prompt = str(query_decomposition_prompt) + + try: + # Call OpenAI API to decompose the query + updated_system_prompt = ( + system_prompt + + f"\nPlease generate at most {max_sub_questions} sub-questions." 
+ ) + logger.info( + f"Calling {llm_model} to decompose query into max {max_sub_questions} sub-questions" + ) + + # Define fallback questions + fallback_questions = [ + { + "sub_question": f"What is {main_query}?", + "reasoning": "Basic understanding of the topic", + }, + { + "sub_question": f"What are the key aspects of {main_query}?", + "reasoning": "Exploring important dimensions", + }, + { + "sub_question": f"What are the implications of {main_query}?", + "reasoning": "Understanding broader impact", + }, + ] + + # Use utility function to get structured output + decomposed_questions = get_structured_llm_output( + prompt=main_query, + system_prompt=updated_system_prompt, + model=llm_model, + fallback_response=fallback_questions, + project=langfuse_project_name, + ) + + # Extract just the sub-questions + sub_questions = [ + item.get("sub_question") + for item in decomposed_questions + if "sub_question" in item + ] + + # Limit to max_sub_questions + sub_questions = sub_questions[:max_sub_questions] + + logger.info(f"Generated {len(sub_questions)} sub-questions") + for i, question in enumerate(sub_questions, 1): + logger.info(f" {i}. {question}") + + # Create the QueryContext + query_context = QueryContext( + main_query=main_query, sub_questions=sub_questions + ) + + # Log step metadata + execution_time = time.time() - start_time + log_metadata( + metadata={ + "query_decomposition": { + "execution_time_seconds": execution_time, + "num_sub_questions": len(sub_questions), + "llm_model": llm_model, + "max_sub_questions_requested": max_sub_questions, + "fallback_used": False, + "main_query_length": len(main_query), + "sub_questions": sub_questions, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_scope": { + "num_sub_questions": len(sub_questions), + } + }, + infer_model=True, + ) + + # Log artifact metadata for the output query context + log_metadata( + metadata={ + "query_context_characteristics": { + "main_query": main_query, + "num_sub_questions": len(sub_questions), + "timestamp": query_context.decomposition_timestamp, + } + }, + infer_artifact=True, + ) + + # Add tags to the artifact + # add_tags(tags=["query", "decomposed"], artifact_name="query_context", infer_artifact=True) + + return query_context + + except Exception as e: + logger.error(f"Error decomposing query: {e}") + # Return fallback questions + fallback_questions = [ + f"What is {main_query}?", + f"What are the key aspects of {main_query}?", + f"What are the implications of {main_query}?", + ] + fallback_questions = fallback_questions[:max_sub_questions] + logger.info(f"Using {len(fallback_questions)} fallback questions:") + for i, question in enumerate(fallback_questions, 1): + logger.info(f" {i}. 
{question}") + + # Create QueryContext with fallback questions + query_context = QueryContext( + main_query=main_query, sub_questions=fallback_questions + ) + + # Log metadata for fallback scenario + execution_time = time.time() - start_time + log_metadata( + metadata={ + "query_decomposition": { + "execution_time_seconds": execution_time, + "num_sub_questions": len(fallback_questions), + "llm_model": llm_model, + "max_sub_questions_requested": max_sub_questions, + "fallback_used": True, + "error_message": str(e), + "main_query_length": len(main_query), + "sub_questions": fallback_questions, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_scope": { + "num_sub_questions": len(fallback_questions), + } + }, + infer_model=True, + ) + + # Add tags to the artifact + # add_tags( + # tags=["query", "decomposed", "fallback"], artifact_name="query_context", infer_artifact=True + # ) + + return query_context diff --git a/deep_research/tests/__init__.py b/deep_research/tests/__init__.py new file mode 100644 index 00000000..6206856b --- /dev/null +++ b/deep_research/tests/__init__.py @@ -0,0 +1 @@ +"""Test package for ZenML Deep Research project.""" diff --git a/deep_research/tests/conftest.py b/deep_research/tests/conftest.py new file mode 100644 index 00000000..b972a5e1 --- /dev/null +++ b/deep_research/tests/conftest.py @@ -0,0 +1,11 @@ +"""Test configuration for pytest. + +This file sets up the proper Python path for importing modules in tests. +""" + +import os +import sys + +# Add the project root directory to the Python path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, project_root) diff --git a/deep_research/tests/test_approval_utils.py b/deep_research/tests/test_approval_utils.py new file mode 100644 index 00000000..f1dd15a5 --- /dev/null +++ b/deep_research/tests/test_approval_utils.py @@ -0,0 +1,120 @@ +"""Unit tests for approval utility functions.""" + +from utils.approval_utils import ( + calculate_estimated_cost, + format_approval_request, + format_critique_summary, + format_query_list, + parse_approval_response, +) + + +def test_parse_approval_responses(): + """Test parsing different approval responses.""" + queries = ["query1", "query2", "query3"] + + # Test approve all + decision = parse_approval_response("APPROVE ALL", queries) + assert decision.approved == True + assert decision.selected_queries == queries + assert decision.approval_method == "APPROVE_ALL" + + # Test skip + decision = parse_approval_response( + "skip", queries + ) # Test case insensitive + assert decision.approved == False + assert decision.selected_queries == [] + assert decision.approval_method == "SKIP" + + # Test selection + decision = parse_approval_response("SELECT 1,3", queries) + assert decision.approved == True + assert decision.selected_queries == ["query1", "query3"] + assert decision.approval_method == "SELECT_SPECIFIC" + + # Test invalid selection + decision = parse_approval_response("SELECT invalid", queries) + assert decision.approved == False + assert decision.approval_method == "PARSE_ERROR" + + # Test out of range indices + decision = parse_approval_response("SELECT 1,5,10", queries) + assert decision.approved == True + assert decision.selected_queries == ["query1"] # Only valid indices + assert decision.approval_method == "SELECT_SPECIFIC" + + # Test unknown response + decision = parse_approval_response("maybe later", queries) + assert decision.approved == False + assert 
decision.approval_method == "UNKNOWN_RESPONSE" + + +def test_format_approval_request(): + """Test formatting of approval request messages.""" + message = format_approval_request( + main_query="Test query", + progress_summary={ + "completed_count": 5, + "avg_confidence": 0.75, + "low_confidence_count": 2, + }, + critique_points=[ + {"issue": "Missing data", "importance": "high"}, + {"issue": "Minor gap", "importance": "low"}, + ], + proposed_queries=["query1", "query2"], + ) + + assert "Test query" in message + assert "5" in message + assert "0.75" in message + assert "2 queries" in message + assert "approve" in message.lower() + assert "reject" in message.lower() + assert "Missing data" in message + + +def test_format_critique_summary(): + """Test critique summary formatting.""" + # Test with no critiques + result = format_critique_summary([]) + assert result == "No critical issues identified." + + # Test with few critiques + critiques = [{"issue": "Issue 1"}, {"issue": "Issue 2"}] + result = format_critique_summary(critiques) + assert "- Issue 1" in result + assert "- Issue 2" in result + assert "more issues" not in result + + # Test with many critiques + critiques = [{"issue": f"Issue {i}"} for i in range(5)] + result = format_critique_summary(critiques) + assert "- Issue 0" in result + assert "- Issue 1" in result + assert "- Issue 2" in result + assert "- Issue 3" not in result + assert "... and 2 more issues" in result + + +def test_format_query_list(): + """Test query list formatting.""" + # Test empty list + result = format_query_list([]) + assert result == "No queries proposed." + + # Test with queries + queries = ["Query A", "Query B", "Query C"] + result = format_query_list(queries) + assert "1. Query A" in result + assert "2. Query B" in result + assert "3. Query C" in result + + +def test_calculate_estimated_cost(): + """Test cost estimation.""" + assert calculate_estimated_cost([]) == 0.0 + assert calculate_estimated_cost(["q1"]) == 0.01 + assert calculate_estimated_cost(["q1", "q2", "q3"]) == 0.03 + assert calculate_estimated_cost(["q1"] * 10) == 0.10 diff --git a/deep_research/tests/test_artifact_models.py b/deep_research/tests/test_artifact_models.py new file mode 100644 index 00000000..415862b5 --- /dev/null +++ b/deep_research/tests/test_artifact_models.py @@ -0,0 +1,210 @@ +"""Tests for the new artifact models.""" + +import time + +import pytest +from utils.pydantic_models import ( + AnalysisData, + FinalReport, + QueryContext, + ReflectionMetadata, + SearchCostDetail, + SearchData, + SearchResult, + SynthesisData, + SynthesizedInfo, + ViewpointAnalysis, +) + + +class TestQueryContext: + """Test the QueryContext artifact.""" + + def test_query_context_creation(self): + """Test creating a QueryContext.""" + query = QueryContext( + main_query="What is quantum computing?", + sub_questions=["What are qubits?", "How do quantum gates work?"], + ) + + assert query.main_query == "What is quantum computing?" 
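+        # The remaining assertions check that both sub-questions were stored and that
+        # decomposition_timestamp was auto-populated with a positive epoch value
+        # (presumably via a default factory on the model).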
+ assert len(query.sub_questions) == 2 + assert query.decomposition_timestamp > 0 + + def test_query_context_immutable(self): + """Test that QueryContext is immutable.""" + query = QueryContext(main_query="Test query", sub_questions=[]) + + # Should raise error when trying to modify + with pytest.raises(Exception): # Pydantic will raise validation error + query.main_query = "Modified query" + + def test_query_context_defaults(self): + """Test QueryContext with defaults.""" + query = QueryContext(main_query="Test") + assert query.sub_questions == [] + assert query.decomposition_timestamp > 0 + + +class TestSearchData: + """Test the SearchData artifact.""" + + def test_search_data_creation(self): + """Test creating SearchData.""" + search_data = SearchData() + + assert search_data.search_results == {} + assert search_data.search_costs == {} + assert search_data.search_cost_details == [] + assert search_data.total_searches == 0 + + def test_search_data_with_results(self): + """Test SearchData with actual results.""" + result = SearchResult( + url="https://example.com", + content="Test content", + title="Test Title", + ) + + cost_detail = SearchCostDetail( + provider="exa", + query="test query", + cost=0.01, + timestamp=time.time(), + step="process_sub_question", + ) + + search_data = SearchData( + search_results={"Question 1": [result]}, + search_costs={"exa": 0.01}, + search_cost_details=[cost_detail], + total_searches=1, + ) + + assert len(search_data.search_results) == 1 + assert search_data.search_costs["exa"] == 0.01 + assert len(search_data.search_cost_details) == 1 + assert search_data.total_searches == 1 + + def test_search_data_merge(self): + """Test merging SearchData instances.""" + # Create first instance + data1 = SearchData( + search_results={ + "Q1": [SearchResult(url="url1", content="content1")] + }, + search_costs={"exa": 0.01}, + total_searches=1, + ) + + # Create second instance + data2 = SearchData( + search_results={ + "Q1": [SearchResult(url="url2", content="content2")], + "Q2": [SearchResult(url="url3", content="content3")], + }, + search_costs={"exa": 0.02, "tavily": 0.01}, + total_searches=2, + ) + + # Merge + data1.merge(data2) + + # Check results + assert len(data1.search_results["Q1"]) == 2 # Merged Q1 results + assert "Q2" in data1.search_results # Added Q2 + assert data1.search_costs["exa"] == 0.03 # Combined costs + assert data1.search_costs["tavily"] == 0.01 # New provider + assert data1.total_searches == 3 + + +class TestSynthesisData: + """Test the SynthesisData artifact.""" + + def test_synthesis_data_creation(self): + """Test creating SynthesisData.""" + synthesis = SynthesisData() + + assert synthesis.synthesized_info == {} + assert synthesis.enhanced_info == {} + + def test_synthesis_data_with_info(self): + """Test SynthesisData with synthesized info.""" + synth_info = SynthesizedInfo( + synthesized_answer="Test answer", + key_sources=["source1", "source2"], + confidence_level="high", + ) + + synthesis = SynthesisData(synthesized_info={"Q1": synth_info}) + + assert "Q1" in synthesis.synthesized_info + assert synthesis.synthesized_info["Q1"].confidence_level == "high" + + def test_synthesis_data_merge(self): + """Test merging SynthesisData instances.""" + info1 = SynthesizedInfo(synthesized_answer="Answer 1") + info2 = SynthesizedInfo(synthesized_answer="Answer 2") + + data1 = SynthesisData(synthesized_info={"Q1": info1}) + data2 = SynthesisData(synthesized_info={"Q2": info2}) + + data1.merge(data2) + + assert "Q1" in data1.synthesized_info + assert "Q2" 
in data1.synthesized_info + + +class TestAnalysisData: + """Test the AnalysisData artifact.""" + + def test_analysis_data_creation(self): + """Test creating AnalysisData.""" + analysis = AnalysisData() + + assert analysis.viewpoint_analysis is None + assert analysis.reflection_metadata is None + + def test_analysis_data_with_viewpoint(self): + """Test AnalysisData with viewpoint analysis.""" + viewpoint = ViewpointAnalysis( + main_points_of_agreement=["Point 1", "Point 2"], + perspective_gaps="Some gaps", + ) + + analysis = AnalysisData(viewpoint_analysis=viewpoint) + + assert analysis.viewpoint_analysis is not None + assert len(analysis.viewpoint_analysis.main_points_of_agreement) == 2 + + def test_analysis_data_with_reflection(self): + """Test AnalysisData with reflection metadata.""" + reflection = ReflectionMetadata( + critique_summary=["Critique 1"], improvements_made=3.0 + ) + + analysis = AnalysisData(reflection_metadata=reflection) + + assert analysis.reflection_metadata is not None + assert analysis.reflection_metadata.improvements_made == 3.0 + + +class TestFinalReport: + """Test the FinalReport artifact.""" + + def test_final_report_creation(self): + """Test creating FinalReport.""" + report = FinalReport() + + assert report.report_html == "" + assert report.generated_at > 0 + assert report.main_query == "" + + def test_final_report_with_content(self): + """Test FinalReport with HTML content.""" + html = "Test Report" + report = FinalReport(report_html=html, main_query="What is AI?") + + assert report.report_html == html + assert report.main_query == "What is AI?" + assert report.generated_at > 0 diff --git a/deep_research/tests/test_prompt_models.py b/deep_research/tests/test_prompt_models.py new file mode 100644 index 00000000..fcc437d4 --- /dev/null +++ b/deep_research/tests/test_prompt_models.py @@ -0,0 +1,110 @@ +"""Unit tests for prompt models and utilities.""" + +from utils.prompt_models import PromptTemplate +from utils.pydantic_models import Prompt + + +class TestPromptTemplate: + """Test cases for PromptTemplate model.""" + + def test_prompt_template_creation(self): + """Test creating a prompt template with all fields.""" + prompt = PromptTemplate( + name="test_prompt", + content="This is a test prompt", + description="A test prompt for unit testing", + version="1.0.0", + tags=["test", "unit"], + ) + + assert prompt.name == "test_prompt" + assert prompt.content == "This is a test prompt" + assert prompt.description == "A test prompt for unit testing" + assert prompt.version == "1.0.0" + assert prompt.tags == ["test", "unit"] + + def test_prompt_template_minimal(self): + """Test creating a prompt template with minimal fields.""" + prompt = PromptTemplate( + name="minimal_prompt", content="Minimal content" + ) + + assert prompt.name == "minimal_prompt" + assert prompt.content == "Minimal content" + assert prompt.description == "" + assert prompt.version == "1.0.0" + assert prompt.tags == [] + + +class TestPrompt: + """Test cases for the new Prompt model.""" + + def test_prompt_creation(self): + """Test creating a prompt with all fields.""" + prompt = Prompt( + name="test_prompt", + content="This is a test prompt", + description="A test prompt for unit testing", + version="1.0.0", + tags=["test", "unit"], + ) + + assert prompt.name == "test_prompt" + assert prompt.content == "This is a test prompt" + assert prompt.description == "A test prompt for unit testing" + assert prompt.version == "1.0.0" + assert prompt.tags == ["test", "unit"] + + def test_prompt_minimal(self): + 
"""Test creating a prompt with minimal fields.""" + prompt = Prompt(name="minimal_prompt", content="Minimal content") + + assert prompt.name == "minimal_prompt" + assert prompt.content == "Minimal content" + assert prompt.description == "" + assert prompt.version == "1.0.0" + assert prompt.tags == [] + + def test_prompt_str_conversion(self): + """Test converting prompt to string returns content.""" + prompt = Prompt( + name="test_prompt", + content="This is the prompt content", + description="Test prompt", + ) + + assert str(prompt) == "This is the prompt content" + + def test_prompt_repr(self): + """Test prompt representation.""" + prompt = Prompt(name="test_prompt", content="Content", version="2.0.0") + + assert repr(prompt) == "Prompt(name='test_prompt', version='2.0.0')" + + def test_prompt_create_factory(self): + """Test creating prompt using factory method.""" + prompt = Prompt.create( + content="Factory created prompt", + name="factory_prompt", + description="Created via factory", + version="1.1.0", + tags=["factory", "test"], + ) + + assert prompt.name == "factory_prompt" + assert prompt.content == "Factory created prompt" + assert prompt.description == "Created via factory" + assert prompt.version == "1.1.0" + assert prompt.tags == ["factory", "test"] + + def test_prompt_create_factory_minimal(self): + """Test creating prompt using factory method with minimal args.""" + prompt = Prompt.create( + content="Minimal factory prompt", name="minimal_factory" + ) + + assert prompt.name == "minimal_factory" + assert prompt.content == "Minimal factory prompt" + assert prompt.description == "" + assert prompt.version == "1.0.0" + assert prompt.tags == [] diff --git a/deep_research/tests/test_pydantic_final_report_step.py b/deep_research/tests/test_pydantic_final_report_step.py new file mode 100644 index 00000000..b4dcd956 --- /dev/null +++ b/deep_research/tests/test_pydantic_final_report_step.py @@ -0,0 +1,265 @@ +"""Tests for the Pydantic-based final report step. + +This module contains tests for the Pydantic-based implementation of +final_report_step, which uses the new Pydantic models and materializers. 
+""" + +from typing import Dict, List + +import pytest +from steps.pydantic_final_report_step import pydantic_final_report_step +from utils.pydantic_models import ( + AnalysisData, + FinalReport, + Prompt, + QueryContext, + ReflectionMetadata, + SearchData, + SearchResult, + SynthesisData, + SynthesizedInfo, + ViewpointAnalysis, + ViewpointTension, +) +from zenml.types import HTMLString + + +@pytest.fixture +def sample_artifacts(): + """Create sample artifacts for testing.""" + # Create QueryContext + query_context = QueryContext( + main_query="What are the impacts of climate change?", + sub_questions=["Economic impacts", "Environmental impacts"], + ) + + # Create SearchData + search_results: Dict[str, List[SearchResult]] = { + "Economic impacts": [ + SearchResult( + url="https://example.com/economy", + title="Economic Impacts of Climate Change", + snippet="Overview of economic impacts", + content="Detailed content about economic impacts of climate change", + ) + ], + "Environmental impacts": [ + SearchResult( + url="https://example.com/environment", + title="Environmental Impacts", + snippet="Environmental impact overview", + content="Content about environmental impacts", + ) + ], + } + search_data = SearchData(search_results=search_results) + + # Create SynthesisData + synthesized_info: Dict[str, SynthesizedInfo] = { + "Economic impacts": SynthesizedInfo( + synthesized_answer="Climate change will have significant economic impacts...", + key_sources=["https://example.com/economy"], + confidence_level="high", + ), + "Environmental impacts": SynthesizedInfo( + synthesized_answer="Environmental impacts include rising sea levels...", + key_sources=["https://example.com/environment"], + confidence_level="high", + ), + } + synthesis_data = SynthesisData( + synthesized_info=synthesized_info, + enhanced_info=synthesized_info, # Same as synthesized for this test + ) + + # Create AnalysisData + viewpoint_analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Climate change is happening", + "Action is needed", + ], + areas_of_tension=[ + ViewpointTension( + topic="Economic policy", + viewpoints={ + "Progressive": "Support carbon taxes and regulations", + "Conservative": "Prefer market-based solutions", + }, + ) + ], + perspective_gaps="Indigenous perspectives are underrepresented", + integrative_insights="A balanced approach combining regulations and market incentives may be most effective", + ) + + reflection_metadata = ReflectionMetadata( + critique_summary=["Need more sources for economic impacts"], + additional_questions_identified=[ + "How will climate change affect different regions?" + ], + searches_performed=[ + "economic impacts of climate change", + "regional climate impacts", + ], + improvements_made=2.0, + ) + + analysis_data = AnalysisData( + viewpoint_analysis=viewpoint_analysis, + reflection_metadata=reflection_metadata, + ) + + # Create prompts + conclusion_prompt = Prompt( + name="conclusion_generation", + content="Generate a conclusion based on the research findings.", + ) + executive_summary_prompt = Prompt( + name="executive_summary", content="Generate an executive summary." + ) + introduction_prompt = Prompt( + name="introduction", content="Generate an introduction." 
+ ) + + return { + "query_context": query_context, + "search_data": search_data, + "synthesis_data": synthesis_data, + "analysis_data": analysis_data, + "conclusion_generation_prompt": conclusion_prompt, + "executive_summary_prompt": executive_summary_prompt, + "introduction_prompt": introduction_prompt, + } + + +def test_pydantic_final_report_step_returns_tuple(): + """Test that the step returns a tuple with FinalReport and HTML.""" + # Create simple artifacts + query_context = QueryContext( + main_query="What is climate change?", + sub_questions=["What causes climate change?"], + ) + search_data = SearchData() + synthesis_data = SynthesisData( + synthesized_info={ + "What causes climate change?": SynthesizedInfo( + synthesized_answer="Climate change is caused by greenhouse gases.", + confidence_level="high", + key_sources=["https://example.com/causes"], + ) + } + ) + analysis_data = AnalysisData() + + # Create prompts + conclusion_prompt = Prompt( + name="conclusion_generation", content="Generate a conclusion." + ) + executive_summary_prompt = Prompt( + name="executive_summary", content="Generate summary." + ) + introduction_prompt = Prompt( + name="introduction", content="Generate intro." + ) + + # Run the step + result = pydantic_final_report_step( + query_context=query_context, + search_data=search_data, + synthesis_data=synthesis_data, + analysis_data=analysis_data, + conclusion_generation_prompt=conclusion_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + ) + + # Assert that result is a tuple with 2 elements + assert isinstance(result, tuple) + assert len(result) == 2 + + # Assert first element is FinalReport + assert isinstance(result[0], FinalReport) + + # Assert second element is HTMLString + assert isinstance(result[1], HTMLString) + + +def test_pydantic_final_report_step_with_complex_artifacts(sample_artifacts): + """Test that the step handles complex artifacts properly.""" + # Run the step with complex artifacts + result = pydantic_final_report_step( + query_context=sample_artifacts["query_context"], + search_data=sample_artifacts["search_data"], + synthesis_data=sample_artifacts["synthesis_data"], + analysis_data=sample_artifacts["analysis_data"], + conclusion_generation_prompt=sample_artifacts[ + "conclusion_generation_prompt" + ], + executive_summary_prompt=sample_artifacts["executive_summary_prompt"], + introduction_prompt=sample_artifacts["introduction_prompt"], + ) + + # Unpack the results + final_report, html_report = result + + # Assert FinalReport contains expected data + assert final_report.main_query == "What are the impacts of climate change?" 
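+    # The assertions below cover both halves of the step output: the structured
+    # FinalReport artifact (query metadata and non-empty HTML) and the rendered
+    # HTML string, which should surface each sub-question section plus the
+    # viewpoint analysis built from the sample fixtures.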
+ assert len(final_report.sub_questions) == 2 + assert final_report.report_html != "" + + # Assert HTML report contains key elements + html_str = str(html_report) + assert "Economic impacts" in html_str + assert "Environmental impacts" in html_str + assert "Viewpoint Analysis" in html_str + assert "Progressive" in html_str + assert "Conservative" in html_str + + +def test_pydantic_final_report_step_creates_report(): + """Test that the step properly creates a final report.""" + # Create artifacts + query_context = QueryContext( + main_query="What is climate change?", + sub_questions=["What causes climate change?"], + ) + search_data = SearchData() + synthesis_data = SynthesisData( + synthesized_info={ + "What causes climate change?": SynthesizedInfo( + synthesized_answer="Climate change is caused by greenhouse gases.", + confidence_level="high", + key_sources=["https://example.com/causes"], + ) + } + ) + analysis_data = AnalysisData() + + # Create prompts + conclusion_prompt = Prompt( + name="conclusion_generation", content="Generate a conclusion." + ) + executive_summary_prompt = Prompt( + name="executive_summary", content="Generate summary." + ) + introduction_prompt = Prompt( + name="introduction", content="Generate intro." + ) + + # Run the step + final_report, html_report = pydantic_final_report_step( + query_context=query_context, + search_data=search_data, + synthesis_data=synthesis_data, + analysis_data=analysis_data, + conclusion_generation_prompt=conclusion_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + ) + + # Verify FinalReport was created with content + assert final_report.report_html != "" + assert "climate change" in final_report.report_html.lower() + + # Verify HTML report was created + assert str(html_report) != "" + assert "climate change" in str(html_report).lower() diff --git a/deep_research/tests/test_pydantic_models.py b/deep_research/tests/test_pydantic_models.py new file mode 100644 index 00000000..21d25123 --- /dev/null +++ b/deep_research/tests/test_pydantic_models.py @@ -0,0 +1,199 @@ +"""Tests for Pydantic model implementations. + +This module contains tests for the Pydantic models that validate: +1. Basic model instantiation +2. Default values +3. Serialization and deserialization +4. 
Method functionality +""" + +import json + +from utils.pydantic_models import ( + ReflectionMetadata, + SearchResult, + SynthesizedInfo, + ViewpointAnalysis, + ViewpointTension, +) + + +def test_search_result_creation(): + """Test creating a SearchResult model.""" + # Create with defaults + result = SearchResult() + assert result.url == "" + assert result.content == "" + assert result.title == "" + assert result.snippet == "" + + # Create with values + result = SearchResult( + url="https://example.com", + content="Example content", + title="Example Title", + snippet="This is a snippet", + ) + assert result.url == "https://example.com" + assert result.content == "Example content" + assert result.title == "Example Title" + assert result.snippet == "This is a snippet" + + +def test_search_result_serialization(): + """Test serializing and deserializing a SearchResult.""" + result = SearchResult( + url="https://example.com", + content="Example content", + title="Example Title", + snippet="This is a snippet", + ) + + # Serialize to dict + result_dict = result.model_dump() + assert result_dict["url"] == "https://example.com" + assert result_dict["content"] == "Example content" + + # Serialize to JSON + result_json = result.model_dump_json() + result_dict_from_json = json.loads(result_json) + assert result_dict_from_json["url"] == "https://example.com" + + # Deserialize from dict + new_result = SearchResult.model_validate(result_dict) + assert new_result.url == "https://example.com" + assert new_result.content == "Example content" + + # Deserialize from JSON + new_result_from_json = SearchResult.model_validate_json(result_json) + assert new_result_from_json.url == "https://example.com" + + +def test_viewpoint_tension_model(): + """Test the ViewpointTension model.""" + # Empty model + tension = ViewpointTension() + assert tension.topic == "" + assert tension.viewpoints == {} + + # With data + tension = ViewpointTension( + topic="Climate Change Impacts", + viewpoints={ + "Economic": "Focuses on financial costs and benefits", + "Environmental": "Emphasizes ecosystem impacts", + }, + ) + assert tension.topic == "Climate Change Impacts" + assert len(tension.viewpoints) == 2 + assert "Economic" in tension.viewpoints + + # Serialization + tension_dict = tension.model_dump() + assert tension_dict["topic"] == "Climate Change Impacts" + assert len(tension_dict["viewpoints"]) == 2 + + # Deserialization + new_tension = ViewpointTension.model_validate(tension_dict) + assert new_tension.topic == tension.topic + assert new_tension.viewpoints == tension.viewpoints + + +def test_synthesized_info_model(): + """Test the SynthesizedInfo model.""" + # Default values + info = SynthesizedInfo() + assert info.synthesized_answer == "" + assert info.key_sources == [] + assert info.confidence_level == "medium" + assert info.information_gaps == "" + assert info.improvements == [] + + # With values + info = SynthesizedInfo( + synthesized_answer="This is a synthesized answer", + key_sources=["https://source1.com", "https://source2.com"], + confidence_level="high", + information_gaps="Missing some context", + improvements=["Add more detail", "Check more sources"], + ) + assert info.synthesized_answer == "This is a synthesized answer" + assert len(info.key_sources) == 2 + assert info.confidence_level == "high" + + # Serialization and deserialization + info_dict = info.model_dump() + new_info = SynthesizedInfo.model_validate(info_dict) + assert new_info.synthesized_answer == info.synthesized_answer + assert new_info.key_sources 
== info.key_sources + + +def test_viewpoint_analysis_model(): + """Test the ViewpointAnalysis model.""" + # Create tensions for the analysis + tension1 = ViewpointTension( + topic="Economic Impact", + viewpoints={ + "Positive": "Creates jobs", + "Negative": "Increases inequality", + }, + ) + tension2 = ViewpointTension( + topic="Environmental Impact", + viewpoints={ + "Positive": "Reduces emissions", + "Negative": "Land use changes", + }, + ) + + # Create the analysis + analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Need for action", + "Technological innovation", + ], + areas_of_tension=[tension1, tension2], + perspective_gaps="Missing indigenous perspectives", + integrative_insights="Combined economic and environmental approach needed", + ) + + assert len(analysis.main_points_of_agreement) == 2 + assert len(analysis.areas_of_tension) == 2 + assert analysis.areas_of_tension[0].topic == "Economic Impact" + + # Test serialization + analysis_dict = analysis.model_dump() + assert len(analysis_dict["areas_of_tension"]) == 2 + assert analysis_dict["areas_of_tension"][0]["topic"] == "Economic Impact" + + # Test deserialization + new_analysis = ViewpointAnalysis.model_validate(analysis_dict) + assert len(new_analysis.areas_of_tension) == 2 + assert new_analysis.areas_of_tension[0].topic == "Economic Impact" + assert new_analysis.perspective_gaps == analysis.perspective_gaps + + +def test_reflection_metadata_model(): + """Test the ReflectionMetadata model.""" + metadata = ReflectionMetadata( + critique_summary=["Need more sources", "Missing detailed analysis"], + additional_questions_identified=["What about future trends?"], + searches_performed=["future climate trends", "economic impacts"], + improvements_made=3, + error=None, + ) + + assert len(metadata.critique_summary) == 2 + assert len(metadata.additional_questions_identified) == 1 + assert metadata.improvements_made == 3 + assert metadata.error is None + + # Serialization + metadata_dict = metadata.model_dump() + assert len(metadata_dict["critique_summary"]) == 2 + assert metadata_dict["improvements_made"] == 3 + + # Deserialization + new_metadata = ReflectionMetadata.model_validate(metadata_dict) + assert new_metadata.improvements_made == metadata.improvements_made + assert new_metadata.critique_summary == metadata.critique_summary diff --git a/deep_research/utils/__init__.py b/deep_research/utils/__init__.py new file mode 100644 index 00000000..395e1d67 --- /dev/null +++ b/deep_research/utils/__init__.py @@ -0,0 +1,7 @@ +""" +Utilities package for the ZenML Deep Research project. + +This package contains various utility functions and helpers used throughout the project, +including data models, LLM interaction utilities, search functionality, and common helper +functions for text processing and state management. +""" diff --git a/deep_research/utils/approval_utils.py b/deep_research/utils/approval_utils.py new file mode 100644 index 00000000..94cd5a47 --- /dev/null +++ b/deep_research/utils/approval_utils.py @@ -0,0 +1,137 @@ +"""Utility functions for the human approval process.""" + +from typing import Any, Dict, List + +from utils.pydantic_models import ApprovalDecision + + +def format_critique_summary(critique_points: List[Dict[str, Any]]) -> str: + """Format critique points for display.""" + if not critique_points: + return "No critical issues identified." 
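+    # Only the first three critique points are listed; any remainder is collapsed
+    # into a single "... and N more issues" summary line (see the loop below).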
+ + formatted = [] + for point in critique_points[:3]: # Show top 3 + issue = point.get("issue", "Unknown issue") + formatted.append(f"- {issue}") + + if len(critique_points) > 3: + formatted.append(f"- ... and {len(critique_points) - 3} more issues") + + return "\n".join(formatted) + + +def format_query_list(queries: List[str]) -> str: + """Format query list for display.""" + if not queries: + return "No queries proposed." + + formatted = [] + for i, query in enumerate(queries, 1): + formatted.append(f"{i}. {query}") + + return "\n".join(formatted) + + +def calculate_estimated_cost(queries: List[str]) -> float: + """Calculate estimated cost for additional queries.""" + # Rough estimate: ~$0.01 per query (including search API + LLM costs) + return round(len(queries) * 0.01, 2) + + +def format_approval_request( + main_query: str, + progress_summary: Dict[str, Any], + critique_points: List[Dict[str, Any]], + proposed_queries: List[str], + timeout: int = 3600, +) -> str: + """Format the approval request message.""" + + # High-priority critiques + high_priority = [ + c for c in critique_points if c.get("importance") == "high" + ] + + message = f"""📊 **Research Progress Update** + +**Main Query:** {main_query} + +**Current Status:** +- Sub-questions analyzed: {progress_summary["completed_count"]} +- Average confidence: {progress_summary["avg_confidence"]} +- Low confidence areas: {progress_summary["low_confidence_count"]} + +**Key Issues Identified:** +{format_critique_summary(high_priority or critique_points)} + +**Proposed Additional Research** ({len(proposed_queries)} queries): +{format_query_list(proposed_queries)} + +**Estimated Additional Time:** ~{len(proposed_queries) * 2} minutes +**Estimated Additional Cost:** ~${calculate_estimated_cost(proposed_queries)} + +**Response Options:** +- Reply with `approve`, `yes`, `ok`, or `LGTM` to proceed with all queries +- Reply with `reject`, `no`, `skip`, or `decline` to finish with current findings + +**Timeout:** Response required within {timeout // 60} minutes""" + + return message + + +def parse_approval_response( + response: str, proposed_queries: List[str] +) -> ApprovalDecision: + """Parse the approval response from user.""" + + response_upper = response.strip().upper() + + if response_upper == "APPROVE ALL": + return ApprovalDecision( + approved=True, + selected_queries=proposed_queries, + approval_method="APPROVE_ALL", + reviewer_notes=response, + ) + + elif response_upper == "SKIP": + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="SKIP", + reviewer_notes=response, + ) + + elif response_upper.startswith("SELECT"): + # Parse selection like "SELECT 1,3,5" + try: + # Extract the part after "SELECT" + selection_part = response_upper[6:].strip() + indices = [int(x.strip()) - 1 for x in selection_part.split(",")] + selected = [ + proposed_queries[i] + for i in indices + if 0 <= i < len(proposed_queries) + ] + return ApprovalDecision( + approved=True, + selected_queries=selected, + approval_method="SELECT_SPECIFIC", + reviewer_notes=response, + ) + except Exception as e: + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="PARSE_ERROR", + reviewer_notes=f"Failed to parse: {response} - {str(e)}", + ) + + else: + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="UNKNOWN_RESPONSE", + reviewer_notes=f"Unknown response: {response}", + ) diff --git a/deep_research/utils/config_utils.py b/deep_research/utils/config_utils.py new file mode 100644 
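The two halves of `approval_utils.py` form a round trip: `format_approval_request` renders the progress summary, critique points, and proposed queries into the alerter message, and `parse_approval_response` maps the reviewer's reply (`APPROVE ALL`, `SKIP`, or `SELECT 1,3,...`) onto an `ApprovalDecision`. Below is a minimal sketch of that round trip outside the pipeline; it assumes the `deep_research` directory is on `PYTHONPATH`, and the progress numbers and queries are invented purely for illustration.

```python
# Illustrative only: exercises the approval utilities outside the pipeline.
# Assumes the deep_research package root is on PYTHONPATH; the progress
# numbers and proposed queries below are made up for the example.
from utils.approval_utils import (
    format_approval_request,
    parse_approval_response,
)

proposed = [
    "LLMOps tooling landscape 2024",
    "case studies migrating MLOps pipelines to LLMOps",
    "cost models for LLM inference in production",
]

message = format_approval_request(
    main_query="Is LLMOps a subset of MLOps?",
    progress_summary={
        "completed_count": 6,
        "avg_confidence": "medium",
        "low_confidence_count": 2,
    },
    critique_points=[
        {"issue": "Little coverage of cost trade-offs", "importance": "high"}
    ],
    proposed_queries=proposed,
    timeout=1800,
)
print(message)  # what the reviewer would see via the configured alerter

# Reviewer approves only the first and third proposed queries.
decision = parse_approval_response("SELECT 1,3", proposed)
assert decision.approved
assert decision.selected_queries == [proposed[0], proposed[2]]
print(decision.approval_method)  # "SELECT_SPECIFIC"
```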
index 00000000..3c40fcde --- /dev/null +++ b/deep_research/utils/config_utils.py @@ -0,0 +1,72 @@ +"""Configuration and environment utilities for the Deep Research Agent.""" + +import logging +import os +from typing import Any, Dict + +import yaml + +logger = logging.getLogger(__name__) + + +def load_pipeline_config(config_path: str) -> Dict[str, Any]: + """Load pipeline configuration from YAML file. + + This is used only for pipeline-level configuration, not for step parameters. + Step parameters should be defined directly in the step functions. + + Args: + config_path: Path to the configuration YAML file + + Returns: + Pipeline configuration dictionary + """ + # Get absolute path if relative + if not os.path.isabs(config_path): + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + config_path = os.path.join(base_dir, config_path) + + # Load YAML configuration + try: + with open(config_path, "r") as f: + config = yaml.safe_load(f) + return config + except Exception as e: + logger.error(f"Error loading pipeline configuration: {e}") + # Return a minimal default configuration in case of loading error + return { + "pipeline": { + "name": "deep_research_pipeline", + "enable_cache": True, + }, + "environment": { + "docker": { + "requirements": [ + "openai>=1.0.0", + "tavily-python>=0.2.8", + "PyYAML>=6.0", + "click>=8.0.0", + "pydantic>=2.0.0", + "typing_extensions>=4.0.0", + ] + } + }, + "resources": {"cpu": 1, "memory": "4Gi"}, + "timeout": 3600, + } + + +def check_required_env_vars(env_vars: list[str]) -> list[str]: + """Check if required environment variables are set. + + Args: + env_vars: List of environment variable names to check + + Returns: + List of missing environment variables + """ + missing_vars = [] + for var in env_vars: + if not os.environ.get(var): + missing_vars.append(var) + return missing_vars diff --git a/deep_research/utils/css_utils.py b/deep_research/utils/css_utils.py new file mode 100644 index 00000000..e2c5cc15 --- /dev/null +++ b/deep_research/utils/css_utils.py @@ -0,0 +1,267 @@ +"""CSS utility functions for consistent styling across materializers.""" + +import json +import os +from typing import Optional + + +def get_shared_css_path() -> str: + """Get the absolute path to the shared CSS file. + + Returns: + Absolute path to assets/styles.css + """ + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + return os.path.join(base_dir, "assets", "styles.css") + + +def get_shared_css_content() -> str: + """Read and return the content of the shared CSS file. + + Returns: + Content of the shared CSS file + """ + css_path = get_shared_css_path() + try: + with open(css_path, "r") as f: + return f.read() + except FileNotFoundError: + # Fallback to basic styles if file not found + return """ + body { + font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; + margin: 20px; + color: #333; + } + """ + + +def get_shared_css_tag() -> str: + """Get the complete style tag with shared CSS content. + + Returns: + HTML style tag with shared CSS + """ + css_content = get_shared_css_content() + return f"" + + +def get_confidence_class(level: str) -> str: + """Return appropriate CSS class for confidence level. + + Args: + level: Confidence level (high, medium, low) + + Returns: + CSS class string + """ + return f"dr-confidence dr-confidence--{level.lower()}" + + +def get_badge_class(badge_type: str) -> str: + """Return appropriate CSS class for badges. 
+ + Args: + badge_type: Badge type (success, warning, danger, info, primary) + + Returns: + CSS class string + """ + return f"dr-badge dr-badge--{badge_type.lower()}" + + +def get_status_class(status: str) -> str: + """Return appropriate CSS class for status indicators. + + Args: + status: Status type (approved, pending, rejected, etc.) + + Returns: + CSS class string + """ + status_map = { + "approved": "success", + "pending": "warning", + "rejected": "danger", + "completed": "success", + "in_progress": "info", + "failed": "danger", + } + badge_type = status_map.get(status.lower(), "primary") + return get_badge_class(badge_type) + + +def get_section_class(section_type: Optional[str] = None) -> str: + """Return appropriate CSS class for sections. + + Args: + section_type: Optional section type (info, warning, success, danger) + + Returns: + CSS class string + """ + if section_type: + return f"dr-section dr-section--{section_type.lower()}" + return "dr-section" + + +def get_card_class(hoverable: bool = True) -> str: + """Return appropriate CSS class for cards. + + Args: + hoverable: Whether the card should have hover effects + + Returns: + CSS class string + """ + classes = ["dr-card"] + if not hoverable: + classes.append("dr-card--no-hover") + return " ".join(classes) + + +def get_table_class(striped: bool = False) -> str: + """Return appropriate CSS class for tables. + + Args: + striped: Whether the table should have striped rows + + Returns: + CSS class string + """ + classes = ["dr-table"] + if striped: + classes.append("dr-table--striped") + return " ".join(classes) + + +def get_button_class( + button_type: str = "primary", size: str = "normal" +) -> str: + """Return appropriate CSS class for buttons. + + Args: + button_type: Button type (primary, secondary, success) + size: Button size (normal, small) + + Returns: + CSS class string + """ + classes = ["dr-button"] + if button_type != "primary": + classes.append(f"dr-button--{button_type}") + if size == "small": + classes.append("dr-button--small") + return " ".join(classes) + + +def get_grid_class(grid_type: str = "cards") -> str: + """Return appropriate CSS class for grid layouts. + + Args: + grid_type: Grid type (stats, cards, metrics) + + Returns: + CSS class string + """ + return f"dr-grid dr-grid--{grid_type}" + + +def wrap_with_container(content: str, wide: bool = False) -> str: + """Wrap content with container div. + + Args: + content: HTML content to wrap + wide: Whether to use wide container + + Returns: + Wrapped HTML content + """ + container_class = ( + "dr-container dr-container--wide" if wide else "dr-container" + ) + return f'
<div class="{container_class}">{content}</div>'
+
+
+def create_stat_card(value: str, label: str, format_value: bool = True) -> str:
+    """Create a stat card HTML.
+
+    Args:
+        value: The statistic value
+        label: The label for the statistic
+        format_value: Whether to wrap value in stat-value div
+
+    Returns:
+        HTML for stat card
+    """
+    value_html = (
+        f'<div class="dr-stat-value">{value}</div>' if format_value else value
+    )
+    return f"""
+    <div class="dr-stat-card">
+        {value_html}
+        <div class="dr-stat-label">{label}</div>
+    </div>
+    """
+
+
+def create_notice(content: str, notice_type: str = "info") -> str:
+    """Create a notice box HTML.
+
+    Args:
+        content: Notice content
+        notice_type: Notice type (info, warning)
+
+    Returns:
+        HTML for notice box
+    """
+    return f"""
+    <div class="dr-notice dr-notice--{notice_type}">
+        {content}
    + """ + + +def extract_html_from_content(content: str) -> str: + """Attempt to extract HTML content from a response that might be wrapped in other formats. + + Args: + content: The content to extract HTML from + + Returns: + The extracted HTML, or a basic fallback if extraction fails + """ + if not content: + return "" + + # Try to find HTML between tags + if "" in content: + start = content.find("") + 7 # Include the closing tag + return content[start:end] + + # Try to find div class="research-report" + if '
    " in content: + start = content.find('
    ") + if last_div > start: + return content[start : last_div + 6] # Include the closing tag + + # Look for code blocks + if "```html" in content and "```" in content: + start = content.find("```html") + 7 + end = content.find("```", start) + if end > start: + return content[start:end].strip() + + # Look for JSON with an "html" field + try: + parsed = json.loads(content) + if isinstance(parsed, dict) and "html" in parsed: + return parsed["html"] + except: + pass + + # If all extraction attempts fail, return the original content + return content diff --git a/deep_research/utils/llm_utils.py b/deep_research/utils/llm_utils.py new file mode 100644 index 00000000..54dd27ee --- /dev/null +++ b/deep_research/utils/llm_utils.py @@ -0,0 +1,458 @@ +import contextlib +import json +import logging +from json.decoder import JSONDecodeError +from typing import Any, Dict, List, Optional + +import litellm +from litellm import completion +from utils.prompts import SYNTHESIS_PROMPT +from zenml import get_step_context + +logger = logging.getLogger(__name__) + +# This module uses litellm for all LLM interactions +# Models are specified with a provider prefix (e.g., "sambanova/DeepSeek-R1-Distill-Llama-70B") +# ALL model names require a provider prefix (e.g., "sambanova/", "openai/", "anthropic/") + +litellm.callbacks = ["langfuse"] + + +def remove_reasoning_from_output(output: str) -> str: + """Remove the reasoning portion from LLM output. + + Args: + output: Raw output from LLM that may contain reasoning + + Returns: + Cleaned output without the reasoning section + """ + if not output: + return "" + + if "" in output: + return output.split("")[-1].strip() + return output.strip() + + +def clean_json_tags(text: str) -> str: + """Clean JSON markdown tags from text. + + Args: + text: Text with potential JSON markdown tags + + Returns: + Cleaned text without JSON markdown tags + """ + if not text: + return "" + + cleaned = text.replace("```json\n", "").replace("\n```", "") + cleaned = cleaned.replace("```json", "").replace("```", "") + return cleaned + + +def clean_markdown_tags(text: str) -> str: + """Clean Markdown tags from text. + + Args: + text: Text with potential markdown tags + + Returns: + Cleaned text without markdown tags + """ + if not text: + return "" + + cleaned = text.replace("```markdown\n", "").replace("\n```", "") + cleaned = cleaned.replace("```markdown", "").replace("```", "") + return cleaned + + +def safe_json_loads(json_str: Optional[str]) -> Dict[str, Any]: + """Safely parse JSON string. + + Args: + json_str: JSON string to parse, can be None. + + Returns: + Dict[str, Any]: Parsed JSON as dictionary or empty dict if parsing fails or input is None. 
+ """ + if json_str is None: + # Optionally, log a warning here if None input is unexpected for certain call sites + # logger.warning("safe_json_loads received None input.") + return {} + try: + return json.loads(json_str) + except ( + JSONDecodeError, + TypeError, + ): # Catch TypeError if json_str is not a valid type for json.loads + # Optionally, log the error and the problematic string (or its beginning) + # logger.warning(f"Failed to decode JSON string: '{str(json_str)[:200]}...'", exc_info=True) + return {} + + +def run_llm_completion( + prompt: str, + system_prompt: str, + model: str = "openrouter/google/gemini-2.0-flash-lite-001", + clean_output: bool = True, + max_tokens: int = 2000, # Increased default token limit + temperature: float = 0.2, + top_p: float = 0.9, + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> str: + """Run an LLM completion with standard error handling and output cleaning. + + Uses litellm for model inference. + + Args: + prompt: User prompt for the LLM + system_prompt: System prompt for the LLM + model: Model to use for completion (with provider prefix) + clean_output: Whether to clean reasoning and JSON tags from output. When True, + this removes any reasoning sections marked with tags and strips JSON + code block markers. + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling value + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. If provided, also converted to trace_metadata format. + + Returns: + str: Processed LLM output with optional cleaning applied + """ + try: + # Ensure model name has provider prefix + if not any( + model.startswith(prefix + "/") + for prefix in [ + "sambanova", + "openai", + "anthropic", + "meta", + "google", + "aws", + "openrouter", + ] + ): + # Raise an error if no provider prefix is specified + error_msg = f"Model '{model}' does not have a provider prefix. 
Please specify provider (e.g., 'sambanova/{model}')" + logger.error(error_msg) + raise ValueError(error_msg) + + # Get pipeline run name and id for trace_name and trace_id if running in a step + trace_name = None + trace_id = None + with contextlib.suppress(RuntimeError): + context = get_step_context() + trace_name = context.pipeline_run.name + trace_id = str(context.pipeline_run.id) + # Build metadata dict + metadata = {"project": project} + if tags is not None: + metadata["tags"] = tags + # Convert tags to trace_metadata format + metadata["trace_metadata"] = {tag: True for tag in tags} + if trace_name: + metadata["trace_name"] = trace_name + if trace_id: + metadata["trace_id"] = trace_id + + response = completion( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + metadata=metadata, + ) + + # Defensive access to content + content = None + if response and response.choices and len(response.choices) > 0: + choice = response.choices[0] + if choice and choice.message: + content = choice.message.content + + if content is None: + logger.warning("LLM response content is missing or empty.") + return "" + + if clean_output: + content = remove_reasoning_from_output(content) + content = clean_json_tags(content) + + return content + except Exception as e: + logger.error(f"Error in LLM completion: {e}") + return "" + + +def get_structured_llm_output( + prompt: str, + system_prompt: str, + model: str = "openrouter/google/gemini-2.0-flash-lite-001", + fallback_response: Optional[Dict[str, Any]] = None, + max_tokens: int = 2000, # Increased default token limit for structured outputs + temperature: float = 0.2, + top_p: float = 0.9, + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> Dict[str, Any]: + """Get structured JSON output from an LLM with error handling. + + Uses litellm for model inference. + + Args: + prompt: User prompt for the LLM + system_prompt: System prompt for the LLM + model: Model to use for completion (with provider prefix) + fallback_response: Fallback response if parsing fails + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling value + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. Defaults to ["structured_llm_output"] if None. + + Returns: + Parsed JSON response or fallback + """ + try: + # Use provided tags or default to ["structured_llm_output"] + if tags is None: + tags = ["structured_llm_output"] + + content = run_llm_completion( + prompt=prompt, + system_prompt=system_prompt, + model=model, + clean_output=True, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + project=project, + tags=tags, + ) + + if not content: + logger.warning("Empty content returned from LLM") + return fallback_response if fallback_response is not None else {} + + result = safe_json_loads(content) + + if not result and fallback_response is not None: + return fallback_response + + return result + except Exception as e: + logger.error(f"Error processing structured LLM output: {e}") + return fallback_response if fallback_response is not None else {} + + +def is_text_relevant(text1: str, text2: str, min_word_length: int = 4) -> bool: + """Determine if two pieces of text are relevant to each other. 
+ + Relevance is determined by checking if one text is contained within the other, + or if they share significant words (words longer than min_word_length). + This is a simple heuristic approach that checks for: + 1. Complete containment (one text string inside the other) + 2. Shared significant words (words longer than min_word_length) + + Args: + text1: First text to compare + text2: Second text to compare + min_word_length: Minimum length of words to check for shared content + + Returns: + bool: True if the texts are deemed relevant to each other based on the criteria + """ + if not text1 or not text2: + return False + + return ( + text1.lower() in text2.lower() + or text2.lower() in text1.lower() + or any( + word + for word in text1.lower().split() + if len(word) > min_word_length and word in text2.lower() + ) + ) + + +def find_most_relevant_string( + target: str, + options: List[str], + model: Optional[str] = "openrouter/google/gemini-2.0-flash-lite-001", + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> Optional[str]: + """Find the most relevant string from a list of options using simple text matching. + + If model is provided, uses litellm to determine relevance. + + Args: + target: The target string to find relevance for + options: List of string options to check against + model: Model to use for matching (with provider prefix) + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. Defaults to ["find_most_relevant_string"] if None. + + Returns: + The most relevant string, or None if no relevant options + """ + if not options: + return None + + if len(options) == 1: + return options[0] + + # If model is provided, use litellm for more accurate matching + if model: + try: + # Ensure model name has provider prefix + if not any( + model.startswith(prefix + "/") + for prefix in [ + "sambanova", + "openai", + "anthropic", + "meta", + "google", + "aws", + "openrouter", + ] + ): + # Raise an error if no provider prefix is specified + error_msg = f"Model '{model}' does not have a provider prefix. Please specify provider (e.g., 'sambanova/{model}')" + logger.error(error_msg) + raise ValueError(error_msg) + + system_prompt = "You are a research assistant." + prompt = f"""Given the text: "{target}" +Which of the following options is most relevant to this text? 
+{options} + +Respond with only the exact text of the most relevant option.""" + + # Get pipeline run name and id for trace_name and trace_id if running in a step + trace_name = None + trace_id = None + try: + context = get_step_context() + trace_name = context.pipeline_run.name + trace_id = str(context.pipeline_run.id) + except RuntimeError: + # Not running in a step context + pass + + # Use provided tags or default to ["find_most_relevant_string"] + if tags is None: + tags = ["find_most_relevant_string"] + + # Build metadata dict + metadata = {"project": project, "tags": tags} + # Convert tags to trace_metadata format + metadata["trace_metadata"] = {tag: True for tag in tags} + if trace_name: + metadata["trace_name"] = trace_name + if trace_id: + metadata["trace_id"] = trace_id + + response = completion( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + max_tokens=100, + temperature=0.2, + metadata=metadata, + ) + + answer = response.choices[0].message.content.strip() + + # Check if the answer is one of the options + if answer in options: + return answer + + # If not an exact match, find the closest one + for option in options: + if option in answer or answer in option: + return option + + except Exception as e: + logger.error(f"Error finding relevant string with LLM: {e}") + + # Simple relevance check - find exact matches first + for option in options: + if target.lower() == option.lower(): + return option + + # Then check partial matches + for option in options: + if is_text_relevant(target, option): + return option + + # Return the first option as a fallback + return options[0] + + +def synthesize_information( + synthesis_input: Dict[str, Any], + model: str = "openrouter/google/gemini-2.0-flash-lite-001", + system_prompt: Optional[str] = None, + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> Dict[str, Any]: + """Synthesize information from search results for a sub-question. + + Uses litellm for model inference. + + Args: + synthesis_input: Dictionary with sub-question, search results, and sources + model: Model to use (with provider prefix) + system_prompt: System prompt for the LLM + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. Defaults to ["information_synthesis"] if None. 
+ + Returns: + Dictionary with synthesized information + """ + if system_prompt is None: + system_prompt = SYNTHESIS_PROMPT + + sub_question_for_log = synthesis_input.get( + "sub_question", "unknown question" + ) + + # Define the fallback response + fallback_response = { + "synthesized_answer": f"Synthesis failed for '{sub_question_for_log}'.", + "key_sources": synthesis_input.get("sources", [])[:1], + "confidence_level": "low", + "information_gaps": "An error occurred during the synthesis process.", + } + + # Use provided tags or default to ["information_synthesis"] + if tags is None: + tags = ["information_synthesis"] + + # Use the utility function to get structured output + result = get_structured_llm_output( + prompt=json.dumps(synthesis_input), + system_prompt=system_prompt, + model=model, + fallback_response=fallback_response, + max_tokens=3000, # Increased for more detailed synthesis + project=project, + tags=tags, + ) + + return result diff --git a/deep_research/utils/prompt_models.py b/deep_research/utils/prompt_models.py new file mode 100644 index 00000000..00e08157 --- /dev/null +++ b/deep_research/utils/prompt_models.py @@ -0,0 +1,27 @@ +"""Pydantic models for prompt tracking and management. + +This module contains models for tracking prompts as artifacts +in the ZenML pipeline, enabling better observability and version control. +""" + +from pydantic import BaseModel, Field + + +class PromptTemplate(BaseModel): + """Represents a single prompt template with metadata.""" + + name: str = Field(..., description="Unique identifier for the prompt") + content: str = Field(..., description="The actual prompt template content") + description: str = Field( + "", description="Human-readable description of what this prompt does" + ) + version: str = Field("1.0.0", description="Version of the prompt template") + tags: list[str] = Field( + default_factory=list, description="Tags for categorizing prompts" + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } diff --git a/deep_research/utils/prompts.py b/deep_research/utils/prompts.py new file mode 100644 index 00000000..ced42f6f --- /dev/null +++ b/deep_research/utils/prompts.py @@ -0,0 +1,1435 @@ +""" +Centralized collection of prompts used throughout the deep research pipeline. + +This module contains all system prompts used by LLM calls in various steps of the +research pipeline to ensure consistency and make prompt management easier. +""" + +# Search query generation prompt +# Used to generate effective search queries from sub-questions +DEFAULT_SEARCH_QUERY_PROMPT = """ +You are a Deep Research assistant. Given a specific research sub-question, your task is to formulate an effective search +query that will help find relevant information to answer the question. + +A good search query should: +1. Extract the key concepts from the sub-question +2. Use precise, specific terminology +3. Exclude unnecessary words or context +4. Include alternative terms or synonyms when helpful +5. Be concise yet comprehensive enough to find relevant results + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "search_query": {"type": "string"}, + "reasoning": {"type": "string"} + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. 
+""" + +# Query decomposition prompt +# Used to break down complex research queries into specific sub-questions +QUERY_DECOMPOSITION_PROMPT = """ +You are a Deep Research assistant specializing in research design. You will be given a MAIN RESEARCH QUERY that needs to be explored comprehensively. Your task is to create diverse, insightful sub-questions that explore different dimensions of the topic. + +IMPORTANT: The main query should be interpreted as a single research question, not as a noun phrase. For example: +- If the query is "Is LLMOps a subset of MLOps?", create questions ABOUT LLMOps and MLOps, not questions like "What is 'Is LLMOps a subset of MLOps?'" +- Focus on the concepts, relationships, and implications within the query + +Create sub-questions that explore these DIFFERENT DIMENSIONS: + +1. **Definitional/Conceptual**: Define key terms and establish conceptual boundaries + Example: "What are the core components and characteristics of LLMOps?" + +2. **Comparative/Relational**: Compare and contrast the concepts mentioned + Example: "How do the workflows and tooling of LLMOps differ from traditional MLOps?" + +3. **Historical/Evolutionary**: Trace development and emergence + Example: "How did LLMOps emerge from MLOps practices?" + +4. **Structural/Technical**: Examine technical architecture and implementation + Example: "What specific tools and platforms are unique to LLMOps?" + +5. **Practical/Use Cases**: Explore real-world applications + Example: "What are the key use cases that require LLMOps but not traditional MLOps?" + +6. **Stakeholder/Industry**: Consider different perspectives and adoption + Example: "How are different industries adopting LLMOps vs MLOps?" + +7. **Challenges/Limitations**: Identify problems and constraints + Example: "What unique challenges does LLMOps face that MLOps doesn't?" + +8. **Future/Trends**: Look at emerging developments + Example: "How is the relationship between LLMOps and MLOps expected to evolve?" + +QUALITY GUIDELINES: +- Each sub-question must explore a DIFFERENT dimension - no repetitive variations +- Questions should be specific, concrete, and investigable +- Mix descriptive ("what/who") with analytical ("why/how") questions +- Ensure questions build toward answering the main query comprehensively +- Frame questions to elicit detailed, nuanced responses +- Consider technical, business, organizational, and strategic aspects + +Format the output in json with the following json schema definition: + + +{ + "type": "array", + "items": { + "type": "object", + "properties": { + "sub_question": {"type": "string"}, + "reasoning": {"type": "string"} + } + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Synthesis prompt for individual sub-questions +# Used to synthesize search results into comprehensive answers for sub-questions +SYNTHESIS_PROMPT = """ +You are a Deep Research assistant specializing in information synthesis. Given a sub-question and search results, your task is to synthesize the information +into a comprehensive, accurate, and well-structured answer. + +Your synthesis should: +1. Begin with a direct, concise answer to the sub-question in the first paragraph +2. Provide detailed evidence and explanation in subsequent paragraphs (at least 3-5 paragraphs total) +3. Integrate information from multiple sources, citing them within your answer +4. 
Acknowledge any conflicting information or contrasting viewpoints you encounter +5. Use data, statistics, examples, and quotations when available to strengthen your answer +6. Organize information logically with a clear flow between concepts +7. Identify key sources that provided the most valuable information (at least 2-3 sources) +8. Explicitly acknowledge information gaps where the search results were incomplete +9. Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +Confidence level criteria: +- HIGH: Multiple high-quality sources provide consistent information, comprehensive coverage of the topic, and few information gaps +- MEDIUM: Decent sources with some consistency, but notable information gaps or some conflicting information +- LOW: Limited sources, major information gaps, significant contradictions, or only tangentially relevant information + +Information gaps should specifically identify: +1. Aspects of the question that weren't addressed in the search results +2. Areas where more detailed or up-to-date information would be valuable +3. Perspectives or data sources that would complement the existing information + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "synthesized_answer": {"type": "string"}, + "key_sources": { + "type": "array", + "items": {"type": "string"} + }, + "confidence_level": {"type": "string", "enum": ["high", "medium", "low"]}, + "information_gaps": {"type": "string"}, + "improvements": { + "type": "array", + "items": {"type": "string"} + } + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Viewpoint analysis prompt for cross-perspective examination +# Used to analyze synthesized answers across different perspectives and viewpoints +VIEWPOINT_ANALYSIS_PROMPT = """ +You are a Deep Research assistant specializing in multi-perspective analysis. You will be given a set of synthesized answers +to sub-questions related to a main research query. Your task is to perform a thorough, nuanced analysis of how different +perspectives would interpret this information. + +Think deeply about the following viewpoint categories and how they would approach the information differently: +- Scientific: Evidence-based, empirical approach focused on data, research findings, and methodological rigor +- Political: Power dynamics, governance structures, policy implications, and ideological frameworks +- Economic: Resource allocation, financial impacts, market dynamics, and incentive structures +- Social: Cultural norms, community impacts, group dynamics, and public welfare +- Ethical: Moral principles, values considerations, rights and responsibilities, and normative judgments +- Historical: Long-term patterns, precedents, contextual development, and evolutionary change + +For each synthesized answer, analyze how these different perspectives would interpret the information by: + +1. Identifying 5-8 main points of agreement where multiple perspectives align (with specific examples) +2. Analyzing at least 3-5 areas of tension between perspectives with: + - A clear topic title for each tension point + - Contrasting interpretations from at least 2-3 different viewpoint categories per tension + - Specific examples or evidence showing why these perspectives differ + - The nuanced positions of each perspective, not just simplified oppositions + +3. 
Thoroughly examining perspective gaps by identifying: + - Which perspectives are underrepresented or missing in the current research + - How including these missing perspectives would enrich understanding + - Specific questions or dimensions that remain unexplored + - Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +4. Developing integrative insights that: + - Synthesize across multiple perspectives to form a more complete understanding + - Highlight how seemingly contradictory viewpoints can complement each other + - Suggest frameworks for reconciling tensions or finding middle-ground approaches + - Identify actionable takeaways that incorporate multiple perspectives + - Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "main_points_of_agreement": { + "type": "array", + "items": {"type": "string"} + }, + "areas_of_tension": { + "type": "array", + "items": { + "type": "object", + "properties": { + "topic": {"type": "string"}, + "viewpoints": { + "type": "object", + "additionalProperties": {"type": "string"} + } + } + } + }, + "perspective_gaps": {"type": "string"}, + "integrative_insights": {"type": "string"} + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Reflection prompt for self-critique and improvement +# Used to evaluate the research and identify gaps, biases, and areas for improvement +REFLECTION_PROMPT = """ +You are a Deep Research assistant with the ability to critique and improve your own research. You will be given: +1. The main research query +2. The sub-questions explored so far +3. The synthesized information for each sub-question +4. Any viewpoint analysis performed + +Your task is to critically evaluate this research and identify: +1. Areas where the research is incomplete or has gaps +2. Questions that are important but not yet answered +3. Aspects where additional evidence or depth would significantly improve the research +4. Potential biases or limitations in the current findings + +Be constructively critical and identify the most important improvements that would substantially enhance the research. + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "critique": { + "type": "array", + "items": { + "type": "object", + "properties": { + "area": {"type": "string"}, + "issue": {"type": "string"}, + "importance": {"type": "string", "enum": ["high", "medium", "low"]} + } + } + }, + "additional_questions": { + "type": "array", + "items": {"type": "string"} + }, + "recommended_search_queries": { + "type": "array", + "items": {"type": "string"} + } + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Additional synthesis prompt for incorporating new information +# Used to enhance original synthesis with new information and address critique points +ADDITIONAL_SYNTHESIS_PROMPT = """ +You are a Deep Research assistant. You will be given: +1. The original synthesized information on a research topic +2. New information from additional research +3. 
A critique of the original synthesis + +Your task is to enhance the original synthesis by incorporating the new information and addressing the critique. +The updated synthesis should: +1. Integrate new information seamlessly +2. Address gaps identified in the critique +3. Maintain a balanced, comprehensive, and accurate representation +4. Preserve the strengths of the original synthesis +5. Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "enhanced_synthesis": {"type": "string"}, + "improvements_made": { + "type": "array", + "items": {"type": "string"} + }, + "remaining_limitations": {"type": "string"} + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Final report generation prompt +# Used to compile a comprehensive HTML research report from all synthesized information +REPORT_GENERATION_PROMPT = """ +You are a Deep Research assistant responsible for compiling an in-depth, comprehensive research report. You will be given: +1. The original research query +2. The sub-questions that were explored +3. Synthesized information for each sub-question +4. Viewpoint analysis comparing different perspectives (if available) +5. Reflection metadata highlighting improvements and limitations + +Your task is to create a well-structured, coherent, professional-quality research report with the following features: + +EXECUTIVE SUMMARY (250-400 words): +- Begin with a compelling, substantive executive summary that provides genuine insight +- Highlight 3-5 key findings or insights that represent the most important discoveries +- Include brief mention of methodology and limitations +- Make the summary self-contained so it can be read independently of the full report +- End with 1-2 sentences on broader implications or applications of the research + +INTRODUCTION (200-300 words): +- Provide relevant background context on the main research query +- Explain why this topic is significant or worth investigating +- Outline the methodological approach used (sub-questions, search strategy, synthesis) +- Preview the overall structure of the report + +SUB-QUESTION SECTIONS: +- For each sub-question, create a dedicated section with: + * A descriptive section title (not just repeating the sub-question) + * A brief (1 paragraph) overview of key findings for this sub-question + * A "Key Findings" box highlighting 3-4 important discoveries for scannable reading + * The detailed, synthesized answer with appropriate paragraph breaks, lists, and formatting + * Proper citation of sources within the text (e.g., "According to [Source Name]...") + * Clear confidence indicator with appropriate styling + * Information gaps clearly identified in their own subsection + * Complete list of key sources used + +VIEWPOINT ANALYSIS SECTION (if available): +- Create a detailed section that: + * Explains the purpose and value of multi-perspective analysis + * Presents points of agreement as actionable insights, not just observations + * Structures tension areas with clear topic headings and balanced presentation of viewpoints + * Uses visual elements (different background colors, icons) to distinguish different perspectives + * Integrates perspective gaps and insights into a cohesive narrative + +CONCLUSION (300-400 words): +- Synthesize the overall findings, 
not just summarizing each section +- Connect insights from different sub-questions to form higher-level understanding +- Address the main research query directly with evidence-based conclusions +- Acknowledge remaining uncertainties and suggestions for further research +- End with implications or applications of the research findings + +OVERALL QUALITY REQUIREMENTS: +1. Create visually scannable content with clear headings, bullet points, and short paragraphs +2. Use semantic HTML (h1, h2, h3, p, blockquote, etc.) to create proper document structure +3. Include a comprehensive table of contents with anchor links to all major sections +4. Format all sources consistently in the references section with proper linking when available +5. Use tables, lists, and blockquotes to improve readability and highlight important information +6. Apply appropriate styling for different confidence levels (high, medium, low) +7. Ensure proper HTML nesting and structure throughout the document +8. Balance sufficient detail with clarity and conciseness +9. Make all text directly actionable and insight-driven, not just descriptive + +The report should be formatted in HTML with appropriate headings, paragraphs, citations, and formatting. +Use semantic HTML (h1, h2, h3, p, blockquote, etc.) to create a structured document. +Include a table of contents at the beginning with anchor links to each section. +For citations, use a consistent format and collect them in a references section at the end. + +Include this exact CSS stylesheet in your HTML to ensure consistent styling (do not modify it): + +```css + +``` + +The HTML structure should follow this pattern: + +```html + + + + + + [CSS STYLESHEET GOES HERE] + + +
+<div class="container">
+    <h1>Research Report: [Main Query]</h1>
+
+    <div class="toc" id="table-of-contents">
+        <h2>Table of Contents</h2>
+        <ul>
+        </ul>
+    </div>
+
+    <div class="section" id="executive-summary">
+        <h2>Executive Summary</h2>
+        [CONCISE SUMMARY OF KEY FINDINGS]
+    </div>
+
+    <div class="section" id="introduction">
+        <h2>Introduction</h2>
+        <p>[INTRODUCTION TO THE RESEARCH QUERY]</p>
+        <p>[OVERVIEW OF THE APPROACH AND SUB-QUESTIONS]</p>
+    </div>
+
+    [FOR EACH SUB-QUESTION]:
+    <div class="section" id="question-[INDEX]">
+        <h2>[INDEX]. [SUB-QUESTION TEXT]</h2>
+        <p class="confidence-[LEVEL]">Confidence Level: [LEVEL]</p>
+
+        <div class="key-findings">
+            <h3>Key Findings</h3>
+            <ul>
+                <li>[KEY FINDING 1]</li>
+                <li>[KEY FINDING 2]</li>
+                [...]
+            </ul>
+        </div>
+
+        [DETAILED ANSWER]
+
+        <div class="notice warning">
+            <h3>Information Gaps</h3>
+            <p>[GAPS TEXT]</p>
+        </div>
+
+        <div class="sources">
+            <h3>Key Sources</h3>
+            <ul>
+                <li>[SOURCE 1]</li>
+                <li>[SOURCE 2]</li>
+                [...]
+            </ul>
+        </div>
+    </div>
+
+    <div class="section" id="viewpoint-analysis">
+        <h2>Viewpoint Analysis</h2>
+
+        <h3>Points of Agreement</h3>
+        <ul>
+            <li>[AGREEMENT 1]</li>
+            <li>[AGREEMENT 2]</li>
+            [...]
+        </ul>
+
+        <h3>Areas of Tension</h3>
+        [FOR EACH TENSION]:
+        <div class="tension">
+            <h4>[TENSION TOPIC]</h4>
+            <div class="viewpoint">
+                [VIEWPOINT 1 TITLE]
+                <p>[VIEWPOINT 1 CONTENT]</p>
+            </div>
+            <div class="viewpoint">
+                [VIEWPOINT 2 TITLE]
+                <p>[VIEWPOINT 2 CONTENT]</p>
+            </div>
+            [...]
+        </div>
+
+        <h3>Perspective Gaps</h3>
+        <p>[PERSPECTIVE GAPS CONTENT]</p>
+
+        <h3>Integrative Insights</h3>
+        <p>[INTEGRATIVE INSIGHTS CONTENT]</p>
+    </div>
+
+    <div class="section" id="conclusion">
+        <h2>Conclusion</h2>
+        <p>[CONCLUSION TEXT]</p>
+    </div>
+
+    <div class="section" id="references">
+        <h2>References</h2>
+        <ul>
+            <li>[REFERENCE 1]</li>
+            <li>[REFERENCE 2]</li>
+            [...]
+        </ul>
+    </div>
+</div>
    + + +``` + +Special instructions: +1. For each sub-question, display the confidence level with appropriate styling (confidence-high, confidence-medium, or confidence-low) +2. Extract 2-3 key findings from each answer to create the key-findings box +3. Format all sources consistently in the references section +4. Use tables, lists, and blockquotes where appropriate to improve readability +5. Use the notice classes (info, warning) to highlight important information or limitations +6. Ensure all sections have proper ID attributes for the table of contents links + +Return only the complete HTML code for the report, with no explanations or additional text. +""" + + +# Executive Summary generation prompt +# Used to create a compelling, insight-driven executive summary +EXECUTIVE_SUMMARY_GENERATION_PROMPT = """ +You are a Deep Research assistant specializing in creating executive summaries. Given comprehensive research findings, your task is to create a compelling executive summary that captures the essence of the research and its key insights. + +Your executive summary should: + +1. **Opening Statement (1-2 sentences):** + - Start with a powerful, direct answer to the main research question + - Make it clear and definitive based on the evidence gathered + +2. **Key Findings (3-5 bullet points):** + - Extract the MOST IMPORTANT discoveries from across all sub-questions + - Focus on insights that are surprising, actionable, or paradigm-shifting + - Each finding should be specific and evidence-based, not generic + - Prioritize findings that directly address the main query + +3. **Critical Insights (2-3 sentences):** + - Synthesize patterns or themes that emerged across multiple sub-questions + - Highlight any unexpected discoveries or counter-intuitive findings + - Connect disparate findings to reveal higher-level understanding + +4. **Implications (2-3 sentences):** + - What do these findings mean for practitioners/stakeholders? + - What actions or decisions can be made based on this research? + - Why should the reader care about these findings? + +5. **Confidence and Limitations (1-2 sentences):** + - Briefly acknowledge the overall confidence level of the findings + - Note any significant gaps or areas requiring further investigation + +IMPORTANT GUIDELINES: +- Be CONCISE but INSIGHTFUL - every sentence should add value +- Use active voice and strong, definitive language where evidence supports it +- Avoid generic statements - be specific to the actual research findings +- Lead with the most important information +- Make it self-contained - reader should understand key findings without reading the full report +- Target length: 250-400 words + +Format as well-structured HTML paragraphs using

<p> tags and <ul>/<li>
    • for bullet points. +""" + +# Introduction generation prompt +# Used to create a contextual, engaging introduction +INTRODUCTION_GENERATION_PROMPT = """ +You are a Deep Research assistant specializing in creating engaging introductions. Given a research query and the sub-questions explored, your task is to create an introduction that provides context and sets up the reader's expectations. + +Your introduction should: + +1. **Context and Relevance (2-3 sentences):** + - Why is this research question important NOW? + - What makes this topic significant or worth investigating? + - Connect to current trends, debates, or challenges in the field + +2. **Scope and Approach (2-3 sentences):** + - What specific aspects of the topic does this research explore? + - Briefly mention the key dimensions covered (based on sub-questions) + - Explain the systematic approach without being too technical + +3. **What to Expect (2-3 sentences):** + - Preview the structure of the report + - Hint at some of the interesting findings or tensions discovered + - Set expectations about the depth and breadth of analysis + +IMPORTANT GUIDELINES: +- Make it engaging - hook the reader's interest from the start +- Provide real context, not generic statements +- Connect to why this matters for the reader +- Keep it concise but informative (200-300 words) +- Use active voice and clear language +- Build anticipation for the findings without giving everything away + +Format as well-structured HTML paragraphs using

      tags. Do NOT include any headings or section titles. +""" + +# Conclusion generation prompt +# Used to synthesize all research findings into a comprehensive conclusion +CONCLUSION_GENERATION_PROMPT = """ +You are a Deep Research assistant specializing in synthesizing comprehensive research conclusions. Given all the research findings from a deep research study, your task is to create a thoughtful, evidence-based conclusion that ties together the overall findings. + +Your conclusion should: + +1. **Synthesis and Integration (150-200 words):** + - Connect insights from different sub-questions to form a higher-level understanding + - Identify overarching themes and patterns that emerge from the research + - Highlight how different findings relate to and support each other + - Avoid simply summarizing each section separately + +2. **Direct Response to Main Query (100-150 words):** + - Address the original research question directly with evidence-based conclusions + - State what the research definitively established vs. what remains uncertain + - Provide a clear, actionable answer based on the synthesized evidence + +3. **Limitations and Future Directions (100-120 words):** + - Acknowledge remaining uncertainties and information gaps across all sections + - Suggest specific areas where additional research would be most valuable + - Identify what types of evidence or perspectives would strengthen the findings + +4. **Implications and Applications (80-100 words):** + - Explain the practical significance of the research findings + - Suggest how the insights might be applied or what they mean for stakeholders + - Connect findings to broader contexts or implications + +Format your output as a well-structured conclusion section in HTML format with appropriate paragraph breaks and formatting. Use

      tags for paragraphs and organize the content logically with clear transitions between the different aspects outlined above. + +IMPORTANT: Do NOT include any headings like "Conclusion",

<h1>, or <h2>
      tags - the section already has a heading. Start directly with the conclusion content in paragraph form. Just create flowing, well-structured paragraphs that cover all four aspects naturally. + +Ensure the conclusion feels cohesive and draws meaningful connections between findings rather than just listing them sequentially. +""" + +# Static HTML template for direct report generation without LLM +STATIC_HTML_TEMPLATE = """ + + + + + Research Report: {main_query} + + + + {shared_css} + + + +
+    <div class="container">
+        <h1>Research Report: {main_query}</h1>
+
+        <div class="toc">
+            <h2>Table of Contents</h2>
+            <ul>
+            </ul>
+        </div>
+
+        <div class="section" id="executive-summary">
+            <h2>Executive Summary</h2>
+            {executive_summary}
+        </div>
+
+        <div class="section" id="introduction">
+            <h2>Introduction</h2>
+            {introduction_html}
+        </div>
+
+        {sub_questions_html}
+
+        {viewpoint_analysis_html}
+
+        <div class="section" id="conclusion">
+            <h2>Conclusion</h2>
+            {conclusion_html}
+        </div>
+
+        <div class="section" id="references">
+            <h2>References</h2>
+            {references_html}
+        </div>
+    </div>
+</body>
+</html>
+"""
+
+# Template for sub-question section in the static HTML report
+SUB_QUESTION_TEMPLATE = """
+<div class="section">
+    <h2>{index}. {question}</h2>
+    <div class="confidence">
+        {confidence_icon}
+        <span>Confidence: {confidence_upper}</span>
+    </div>
+
+    <div class="answer">
+        <p>{answer}</p>
+    </div>
+
+    {info_gaps_html}
+
+    {key_sources_html}
+</div>
+"""
+
+# Template for viewpoint analysis section in the static HTML report
+VIEWPOINT_ANALYSIS_TEMPLATE = """
+<div class="section" id="viewpoint-analysis">
+    <h2>Viewpoint Analysis</h2>
+
+    <div class="subsection">
+        <h3>🤝 Points of Agreement</h3>
+        <ul>
+            {agreements_html}
+        </ul>
+    </div>
+
+    <div class="subsection">
+        <h3>⚖️ Areas of Tension</h3>
+        {tensions_html}
+    </div>
+
+    <div class="subsection">
+        <h3>🔍 Perspective Gaps</h3>
+        <p>{perspective_gaps}</p>
+    </div>
+
+    <div class="subsection">
+        <h3>💡 Integrative Insights</h3>
+        <p>{integrative_insights}</p>
+    </div>
+</div>
      +""" diff --git a/deep_research/utils/pydantic_models.py b/deep_research/utils/pydantic_models.py new file mode 100644 index 00000000..822afe99 --- /dev/null +++ b/deep_research/utils/pydantic_models.py @@ -0,0 +1,503 @@ +"""Pydantic model definitions for the research pipeline. + +This module contains all the Pydantic models that represent the state of the research +pipeline. These models replace the previous dataclasses implementation and leverage +Pydantic's validation, serialization, and integration with ZenML. +""" + +import time +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field +from typing_extensions import Literal + + +class Prompt(BaseModel): + """A single prompt with metadata for tracking and visualization. + + This class is designed to be simple and intuitive to use. You can access + the prompt content directly via the content attribute or by converting + to string. + """ + + content: str = Field(..., description="The actual prompt text") + name: str = Field(..., description="Unique identifier for the prompt") + description: str = Field( + "", description="Human-readable description of what this prompt does" + ) + version: str = Field("1.0.0", description="Version of the prompt") + tags: List[str] = Field( + default_factory=list, description="Tags for categorizing the prompt" + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def __str__(self) -> str: + """Return the prompt content as a string.""" + return self.content + + def __repr__(self) -> str: + """Return a readable representation of the prompt.""" + return f"Prompt(name='{self.name}', version='{self.version}')" + + @classmethod + def create( + cls, + content: str, + name: str, + description: str = "", + version: str = "1.0.0", + tags: Optional[List[str]] = None, + ) -> "Prompt": + """Factory method to create a Prompt instance. 
+ + Args: + content: The prompt text + name: Unique identifier for the prompt + description: Optional description of the prompt's purpose + version: Version string (defaults to "1.0.0") + tags: Optional list of tags for categorization + + Returns: + A new Prompt instance + """ + return cls( + content=content, + name=name, + description=description, + version=version, + tags=tags or [], + ) + + +class SearchResult(BaseModel): + """Represents a search result for a sub-question.""" + + url: str = "" + content: str = "" + title: str = "" + snippet: str = "" + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + + model_config = { + "extra": "ignore", # Ignore extra fields during deserialization + "frozen": False, # Allow attribute updates + "validate_assignment": True, # Validate when attributes are set + } + + +class ViewpointTension(BaseModel): + """Represents a tension between different viewpoints on a topic.""" + + topic: str = "" + viewpoints: Dict[str, str] = Field(default_factory=dict) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class SynthesizedInfo(BaseModel): + """Represents synthesized information for a sub-question.""" + + synthesized_answer: str = "" + key_sources: List[str] = Field(default_factory=list) + confidence_level: Literal["high", "medium", "low"] = "medium" + information_gaps: str = "" + improvements: List[str] = Field(default_factory=list) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class ViewpointAnalysis(BaseModel): + """Represents the analysis of different viewpoints on the research topic.""" + + main_points_of_agreement: List[str] = Field(default_factory=list) + areas_of_tension: List[ViewpointTension] = Field(default_factory=list) + perspective_gaps: str = "" + integrative_insights: str = "" + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class ReflectionMetadata(BaseModel): + """Metadata about the reflection process.""" + + critique_summary: List[str] = Field(default_factory=list) + additional_questions_identified: List[str] = Field(default_factory=list) + searches_performed: List[str] = Field(default_factory=list) + improvements_made: float = Field( + default=0 + ) # Changed from int to float to handle timestamp values + error: Optional[str] = None + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class ResearchState(BaseModel): + """Comprehensive state object for the enhanced research pipeline.""" + + # Initial query information + main_query: str = "" + sub_questions: List[str] = Field(default_factory=list) + + # Information gathering results + search_results: Dict[str, List[SearchResult]] = Field(default_factory=dict) + + # Synthesized information + synthesized_info: Dict[str, SynthesizedInfo] = Field(default_factory=dict) + + # Viewpoint analysis + viewpoint_analysis: Optional[ViewpointAnalysis] = None + + # Reflection results + enhanced_info: Dict[str, SynthesizedInfo] = Field(default_factory=dict) + reflection_metadata: Optional[ReflectionMetadata] = None + + # Final report + final_report_html: str = "" + + # Search cost tracking + search_costs: Dict[str, float] = Field( + default_factory=dict, + description="Total costs by search provider (e.g., {'exa': 0.0, 'tavily': 0.0})", + ) + search_cost_details: List[Dict[str, Any]] = Field( + default_factory=list, + description="Detailed log of each search with cost information", + ) + # 
Format: [{"provider": "exa", "query": "...", "cost": 0.0, "timestamp": ..., "step": "...", "sub_question": "..."}] + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def get_current_stage(self) -> str: + """Determine the current stage of research based on filled data.""" + if self.final_report_html: + return "final_report" + elif self.enhanced_info: + return "after_reflection" + elif self.viewpoint_analysis: + return "after_viewpoint_analysis" + elif self.synthesized_info: + return "after_synthesis" + elif self.search_results: + return "after_search" + elif self.sub_questions: + return "after_query_decomposition" + elif self.main_query: + return "initial" + else: + return "empty" + + def update_sub_questions(self, sub_questions: List[str]) -> None: + """Update the sub-questions list.""" + self.sub_questions = sub_questions + + def update_search_results( + self, search_results: Dict[str, List[SearchResult]] + ) -> None: + """Update the search results.""" + self.search_results = search_results + + def update_synthesized_info( + self, synthesized_info: Dict[str, SynthesizedInfo] + ) -> None: + """Update the synthesized information.""" + self.synthesized_info = synthesized_info + + def update_viewpoint_analysis( + self, viewpoint_analysis: ViewpointAnalysis + ) -> None: + """Update the viewpoint analysis.""" + self.viewpoint_analysis = viewpoint_analysis + + def update_after_reflection( + self, + enhanced_info: Dict[str, SynthesizedInfo], + metadata: ReflectionMetadata, + ) -> None: + """Update with reflection results.""" + self.enhanced_info = enhanced_info + self.reflection_metadata = metadata + + def set_final_report(self, html: str) -> None: + """Set the final report HTML.""" + self.final_report_html = html + + +class ApprovalDecision(BaseModel): + """Approval decision from human reviewer.""" + + approved: bool = False + selected_queries: List[str] = Field(default_factory=list) + approval_method: str = "" # "APPROVE_ALL", "SKIP", "SELECT_SPECIFIC" + reviewer_notes: str = "" + timestamp: float = Field(default_factory=lambda: time.time()) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class PromptTypeMetrics(BaseModel): + """Metrics for a specific prompt type.""" + + prompt_type: str + total_cost: float + input_tokens: int + output_tokens: int + call_count: int + avg_cost_per_call: float + percentage_of_total_cost: float + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class TracingMetadata(BaseModel): + """Metadata about token usage, costs, and performance for a pipeline run.""" + + # Pipeline information + pipeline_run_name: str = "" + pipeline_run_id: str = "" + + # Token usage + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_tokens: int = 0 + + # Cost information + total_cost: float = 0.0 + cost_breakdown_by_model: Dict[str, float] = Field(default_factory=dict) + + # Performance metrics + total_latency_seconds: float = 0.0 + formatted_latency: str = "" + observation_count: int = 0 + + # Model usage + models_used: List[str] = Field(default_factory=list) + model_token_breakdown: Dict[str, Dict[str, int]] = Field( + default_factory=dict + ) + # Format: {"model_name": {"input_tokens": X, "output_tokens": Y, "total_tokens": Z}} + + # Trace information + trace_id: str = "" + trace_name: str = "" + trace_tags: List[str] = Field(default_factory=list) + trace_metadata: Dict[str, Any] = Field(default_factory=dict) + + # 
Step-by-step breakdown + step_costs: Dict[str, float] = Field(default_factory=dict) + step_tokens: Dict[str, Dict[str, int]] = Field(default_factory=dict) + # Format: {"step_name": {"input_tokens": X, "output_tokens": Y}} + + # Prompt-level metrics + prompt_metrics: List[PromptTypeMetrics] = Field( + default_factory=list, description="Cost breakdown by prompt type" + ) + + # Search provider costs + search_costs: Dict[str, float] = Field( + default_factory=dict, description="Total costs by search provider" + ) + search_queries_count: Dict[str, int] = Field( + default_factory=dict, + description="Number of queries by search provider", + ) + search_cost_details: List[Dict[str, Any]] = Field( + default_factory=list, description="Detailed search cost information" + ) + + # Timestamp + collected_at: float = Field(default_factory=lambda: time.time()) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +# ============================================================================ +# New Artifact Classes for ResearchState Refactoring +# ============================================================================ + + +class QueryContext(BaseModel): + """Immutable context containing the research query and its decomposition. + + This artifact is created once at the beginning of the pipeline and + remains unchanged throughout execution. + """ + + main_query: str = Field( + ..., description="The main research question from the user" + ) + sub_questions: List[str] = Field( + default_factory=list, + description="Decomposed sub-questions for parallel processing", + ) + decomposition_timestamp: float = Field( + default_factory=lambda: time.time(), + description="When the query was decomposed", + ) + + model_config = { + "extra": "ignore", + "frozen": True, # Make immutable after creation + "validate_assignment": True, + } + + +class SearchCostDetail(BaseModel): + """Detailed information about a single search operation.""" + + provider: str + query: str + cost: float + timestamp: float + step: str + sub_question: Optional[str] = None + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class SearchData(BaseModel): + """Accumulates search results and cost tracking throughout the pipeline. + + This artifact grows as searches are performed and can be merged + when parallel searches complete. 
+ """ + + search_results: Dict[str, List[SearchResult]] = Field( + default_factory=dict, + description="Map of sub-question to search results", + ) + search_costs: Dict[str, float] = Field( + default_factory=dict, description="Total costs by provider" + ) + search_cost_details: List[SearchCostDetail] = Field( + default_factory=list, description="Detailed log of each search" + ) + total_searches: int = Field( + default=0, description="Total number of searches performed" + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def merge(self, other: "SearchData") -> "SearchData": + """Merge another SearchData instance into this one.""" + # Merge search results + for sub_q, results in other.search_results.items(): + if sub_q in self.search_results: + self.search_results[sub_q].extend(results) + else: + self.search_results[sub_q] = results + + # Merge costs + for provider, cost in other.search_costs.items(): + self.search_costs[provider] = ( + self.search_costs.get(provider, 0.0) + cost + ) + + # Merge cost details + self.search_cost_details.extend(other.search_cost_details) + + # Update total searches + self.total_searches += other.total_searches + + return self + + +class SynthesisData(BaseModel): + """Contains synthesized information for all sub-questions.""" + + synthesized_info: Dict[str, SynthesizedInfo] = Field( + default_factory=dict, + description="Synthesized answers for each sub-question", + ) + enhanced_info: Dict[str, SynthesizedInfo] = Field( + default_factory=dict, + description="Enhanced information after reflection", + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def merge(self, other: "SynthesisData") -> "SynthesisData": + """Merge another SynthesisData instance into this one.""" + self.synthesized_info.update(other.synthesized_info) + self.enhanced_info.update(other.enhanced_info) + return self + + +class AnalysisData(BaseModel): + """Contains viewpoint analysis and reflection metadata.""" + + viewpoint_analysis: Optional[ViewpointAnalysis] = None + reflection_metadata: Optional[ReflectionMetadata] = None + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class FinalReport(BaseModel): + """Contains the final HTML report.""" + + report_html: str = Field(default="", description="The final HTML report") + generated_at: float = Field( + default_factory=lambda: time.time(), + description="Timestamp when report was generated", + ) + main_query: str = Field( + default="", description="The original research query" + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } diff --git a/deep_research/utils/search_utils.py b/deep_research/utils/search_utils.py new file mode 100644 index 00000000..e63aa650 --- /dev/null +++ b/deep_research/utils/search_utils.py @@ -0,0 +1,721 @@ +import logging +import os +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from tavily import TavilyClient + +try: + from exa_py import Exa + + EXA_AVAILABLE = True +except ImportError: + EXA_AVAILABLE = False + Exa = None + +from utils.llm_utils import get_structured_llm_output +from utils.prompts import DEFAULT_SEARCH_QUERY_PROMPT +from utils.pydantic_models import SearchResult + +logger = logging.getLogger(__name__) + + +class SearchProvider(Enum): + TAVILY = "tavily" + EXA = "exa" + BOTH = "both" + + +class SearchEngineConfig: + """Configuration for search engines""" + + def 
__init__(self): + self.tavily_api_key = os.getenv("TAVILY_API_KEY") + self.exa_api_key = os.getenv("EXA_API_KEY") + self.default_provider = os.getenv("DEFAULT_SEARCH_PROVIDER", "tavily") + self.enable_parallel_search = ( + os.getenv("ENABLE_PARALLEL_SEARCH", "false").lower() == "true" + ) + + +def get_search_client(provider: Union[str, SearchProvider]) -> Optional[Any]: + """Get the appropriate search client based on provider.""" + if isinstance(provider, str): + provider = SearchProvider(provider.lower()) + + config = SearchEngineConfig() + + if provider == SearchProvider.TAVILY: + if not config.tavily_api_key: + raise ValueError("TAVILY_API_KEY environment variable not set") + return TavilyClient(api_key=config.tavily_api_key) + + elif provider == SearchProvider.EXA: + if not EXA_AVAILABLE: + raise ImportError( + "exa-py is not installed. Please install it with: pip install exa-py" + ) + if not config.exa_api_key: + raise ValueError("EXA_API_KEY environment variable not set") + return Exa(config.exa_api_key) + + return None + + +def tavily_search( + query: str, + include_raw_content: bool = True, + max_results: int = 3, + cap_content_length: int = 20000, +) -> Dict[str, Any]: + """Perform a search using the Tavily API. + + Args: + query: Search query + include_raw_content: Whether to include raw content in results + max_results: Maximum number of results to return + cap_content_length: Maximum length of content to return + + Returns: + Dict[str, Any]: Search results from Tavily in the following format: + { + "query": str, # The original query + "results": List[Dict], # List of search result objects + "error": str, # Error message (if an error occurred, otherwise omitted) + } + + Each result in "results" has the following structure: + { + "url": str, # URL of the search result + "raw_content": str, # Raw content of the page (if include_raw_content=True) + "title": str, # Title of the page + "snippet": str, # Snippet of the page content + } + """ + try: + tavily_client = get_search_client(SearchProvider.TAVILY) + + # First try with advanced search + results = tavily_client.search( + query=query, + include_raw_content=include_raw_content, + max_results=max_results, + search_depth="advanced", # Use advanced search for better results + include_domains=[], # No domain restrictions + exclude_domains=[], # No exclusions + include_answer=False, # We don't need the answer field + include_images=False, # We don't need images + # Note: 'include_snippets' is not a supported parameter + ) + + # Check if we got good results (with non-None and non-empty content) + if include_raw_content and "results" in results: + bad_content_count = sum( + 1 + for r in results["results"] + if "raw_content" in r + and ( + r["raw_content"] is None or r["raw_content"].strip() == "" + ) + ) + + # If more than half of results have bad content, try a different approach + if bad_content_count > len(results["results"]) / 2: + logger.warning( + f"{bad_content_count}/{len(results['results'])} results have None or empty content. " + "Trying to use 'content' field instead of 'raw_content'..." 
+ ) + + # Try to use the 'content' field which comes by default + for result in results["results"]: + if ( + "raw_content" in result + and ( + result["raw_content"] is None + or result["raw_content"].strip() == "" + ) + ) and "content" in result: + result["raw_content"] = result["content"] + logger.info( + f"Using 'content' field as 'raw_content' for URL {result.get('url', 'unknown')}" + ) + + # Re-check after our fix + bad_content_count = sum( + 1 + for r in results["results"] + if "raw_content" in r + and ( + r["raw_content"] is None + or r["raw_content"].strip() == "" + ) + ) + + if bad_content_count > 0: + logger.warning( + f"Still have {bad_content_count}/{len(results['results'])} results with bad content after fixes." + ) + + # Try alternative approach - search with 'include_answer=True' + try: + # Search with include_answer=True which may give us better content + logger.info( + "Trying alternative search with include_answer=True" + ) + alt_results = tavily_client.search( + query=query, + include_raw_content=include_raw_content, + max_results=max_results, + search_depth="advanced", + include_domains=[], + exclude_domains=[], + include_answer=True, # Include answer this time + include_images=False, + ) + + # Check if we got any improved content + if "results" in alt_results: + # Create a merged results set taking the best content + for i, result in enumerate(alt_results["results"]): + if i < len(results["results"]): + if ( + "raw_content" in result + and result["raw_content"] + and ( + results["results"][i].get( + "raw_content" + ) + is None + or results["results"][i] + .get("raw_content", "") + .strip() + == "" + ) + ): + # Replace the bad content with better content from alt_results + results["results"][i]["raw_content"] = ( + result["raw_content"] + ) + logger.info( + f"Replaced bad content with better content from alternative search for URL {result.get('url', 'unknown')}" + ) + + # If answer is available, add it as a special result + if "answer" in alt_results and alt_results["answer"]: + answer_text = alt_results["answer"] + answer_result = { + "url": "tavily-generated-answer", + "title": "Generated Answer", + "raw_content": f"Generated Answer based on search results:\n\n{answer_text}", + "content": answer_text, + } + results["results"].append(answer_result) + logger.info( + "Added Tavily generated answer as additional search result" + ) + + except Exception as alt_error: + logger.warning( + f"Failed to get better results with alternative search: {alt_error}" + ) + + # Cap content length if specified + if cap_content_length > 0 and "results" in results: + for result in results["results"]: + if "raw_content" in result and result["raw_content"]: + result["raw_content"] = result["raw_content"][ + :cap_content_length + ] + + return results + except Exception as e: + logger.error(f"Error in Tavily search: {e}") + # Return an error structure that's compatible with our expected format + return {"query": query, "results": [], "error": str(e)} + + +def exa_search( + query: str, + max_results: int = 3, + cap_content_length: int = 20000, + search_mode: str = "auto", + include_highlights: bool = False, +) -> Dict[str, Any]: + """Perform a search using the Exa API. 
+ + Args: + query: Search query + max_results: Maximum number of results to return + cap_content_length: Maximum length of content to return + search_mode: Search mode ("neural", "keyword", or "auto") + include_highlights: Whether to include highlights in results + + Returns: + Dict[str, Any]: Search results from Exa in a format compatible with Tavily + """ + try: + exa_client = get_search_client(SearchProvider.EXA) + + # Configure content options + text_options = {"max_characters": cap_content_length} + + kwargs = { + "query": query, + "num_results": max_results, + "type": search_mode, # "neural", "keyword", or "auto" + "text": text_options, + } + + if include_highlights: + kwargs["highlights"] = { + "highlights_per_url": 2, + "num_sentences": 3, + } + + response = exa_client.search_and_contents(**kwargs) + + # Extract cost information + exa_cost = 0.0 + if hasattr(response, "cost_dollars") and hasattr( + response.cost_dollars, "total" + ): + exa_cost = response.cost_dollars.total + logger.info( + f"Exa search cost for query '{query}': ${exa_cost:.4f}" + ) + + # Convert to standardized format compatible with Tavily + results = {"query": query, "results": [], "exa_cost": exa_cost} + + for r in response.results: + result_dict = { + "url": r.url, + "title": r.title or "", + "snippet": "", + "raw_content": getattr(r, "text", ""), + "content": getattr(r, "text", ""), + } + + # Add highlights as snippet if available + if hasattr(r, "highlights") and r.highlights: + result_dict["snippet"] = " ".join(r.highlights[:1]) + + # Store additional metadata + result_dict["_metadata"] = { + "provider": "exa", + "score": getattr(r, "score", None), + "published_date": getattr(r, "published_date", None), + "author": getattr(r, "author", None), + } + + results["results"].append(result_dict) + + return results + + except Exception as e: + logger.error(f"Error in Exa search: {e}") + return {"query": query, "results": [], "error": str(e)} + + +def unified_search( + query: str, + provider: Union[str, SearchProvider, None] = None, + max_results: int = 3, + cap_content_length: int = 20000, + search_mode: str = "auto", + include_highlights: bool = False, + compare_results: bool = False, + **kwargs, +) -> Union[List[SearchResult], Dict[str, List[SearchResult]]]: + """Unified search interface supporting multiple providers. 
+ + Args: + query: Search query + provider: Search provider to use (tavily, exa, both) + max_results: Maximum number of results + cap_content_length: Maximum content length + search_mode: Search mode for Exa ("neural", "keyword", "auto") + include_highlights: Include highlights for Exa results + compare_results: Return results from both providers separately + + Returns: + List[SearchResult] or Dict mapping provider to results (when compare_results=True or provider="both") + """ + # Use default provider if not specified + if provider is None: + config = SearchEngineConfig() + provider = config.default_provider + + # Convert string to enum if needed + if isinstance(provider, str): + provider = SearchProvider(provider.lower()) + + # Handle single provider case + if provider == SearchProvider.TAVILY: + results = tavily_search( + query, + max_results=max_results, + cap_content_length=cap_content_length, + ) + extracted, cost = extract_search_results(results, provider="tavily") + return extracted if not compare_results else {"tavily": extracted} + + elif provider == SearchProvider.EXA: + results = exa_search( + query=query, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + extracted, cost = extract_search_results(results, provider="exa") + return extracted if not compare_results else {"exa": extracted} + + elif provider == SearchProvider.BOTH: + # Run both searches + tavily_results = tavily_search( + query, + max_results=max_results, + cap_content_length=cap_content_length, + ) + exa_results = exa_search( + query=query, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + + # Extract results from both + tavily_extracted, tavily_cost = extract_search_results( + tavily_results, provider="tavily" + ) + exa_extracted, exa_cost = extract_search_results( + exa_results, provider="exa" + ) + + if compare_results: + return {"tavily": tavily_extracted, "exa": exa_extracted} + else: + # Merge results, interleaving them + merged = [] + max_len = max(len(tavily_extracted), len(exa_extracted)) + for i in range(max_len): + if i < len(tavily_extracted): + merged.append(tavily_extracted[i]) + if i < len(exa_extracted): + merged.append(exa_extracted[i]) + return merged[:max_results] # Limit to requested number + + else: + raise ValueError(f"Unknown provider: {provider}") + + +def extract_search_results( + search_results: Dict[str, Any], provider: str = "tavily" +) -> tuple[List[SearchResult], float]: + """Extract SearchResult objects from provider-specific API responses. + + Args: + search_results: Results from search API + provider: Which provider the results came from + + Returns: + Tuple of (List[SearchResult], float): List of converted SearchResult objects with standardized fields + and the search cost (0.0 if not available). + SearchResult is a Pydantic model defined in data_models.py that includes: + - url: The URL of the search result + - content: The raw content of the page + - title: The title of the page + - snippet: A brief snippet of the page content + """ + results_list = [] + search_cost = search_results.get( + "exa_cost", 0.0 + ) # Extract cost if present + + if "results" in search_results: + for result in search_results["results"]: + if "url" in result: + # Get fields with defaults + url = result["url"] + title = result.get("title", "") + + # Try to extract the best content available: + # 1. 
First try raw_content (if we requested it) + # 2. Then try regular content (always available) + # 3. Then try to use snippet combined with title + # 4. Last resort: use just title + + raw_content = result.get("raw_content", None) + regular_content = result.get("content", "") + snippet = result.get("snippet", "") + + # Set our final content - prioritize raw_content if available and not None + if raw_content is not None and raw_content.strip(): + content = raw_content + # Next best is the regular content field + elif regular_content and regular_content.strip(): + content = regular_content + logger.info( + f"Using 'content' field for URL {url} because raw_content was not available" + ) + # Try to create a usable content from snippet and title + elif snippet: + content = f"Title: {title}\n\nContent: {snippet}" + logger.warning( + f"Using title and snippet as content fallback for {url}" + ) + # Last resort - just use the title + elif title: + content = ( + f"Title: {title}\n\nNo content available for this URL." + ) + logger.warning( + f"Using only title as content fallback for {url}" + ) + # Nothing available + else: + content = "" + logger.warning( + f"No content available for URL {url}, using empty string" + ) + + # Create SearchResult with provider metadata + search_result = SearchResult( + url=url, + content=content, + title=title, + snippet=snippet, + ) + + # Add provider info to metadata if available + if "_metadata" in result: + search_result.metadata = result["_metadata"] + else: + search_result.metadata = {"provider": provider} + + results_list.append(search_result) + + # If we got the answer (Tavily specific), add it as a special result + if ( + provider == "tavily" + and "answer" in search_results + and search_results["answer"] + ): + answer_text = search_results["answer"] + results_list.append( + SearchResult( + url="tavily-generated-answer", + content=f"Generated Answer based on search results:\n\n{answer_text}", + title="Tavily Generated Answer", + snippet=answer_text[:100] + "..." + if len(answer_text) > 100 + else answer_text, + metadata={"provider": "tavily", "type": "generated_answer"}, + ) + ) + logger.info("Added Tavily generated answer as a search result") + + return results_list, search_cost + + +def generate_search_query( + sub_question: str, + model: str = "openrouter/google/gemini-2.0-flash-lite-001", + system_prompt: Optional[str] = None, + project: str = "deep-research", +) -> Dict[str, Any]: + """Generate an optimized search query for a sub-question. + + Uses litellm for model inference via get_structured_llm_output. 
+ + Args: + sub_question: The sub-question to generate a search query for + model: Model to use (with provider prefix) + system_prompt: System prompt for the LLM, defaults to DEFAULT_SEARCH_QUERY_PROMPT + project: Langfuse project name for LLM tracking + + Returns: + Dictionary with search query and reasoning + """ + if system_prompt is None: + system_prompt = DEFAULT_SEARCH_QUERY_PROMPT + + fallback_response = {"search_query": sub_question, "reasoning": ""} + + return get_structured_llm_output( + prompt=sub_question, + system_prompt=system_prompt, + model=model, + fallback_response=fallback_response, + project=project, + ) + + +def search_and_extract_results( + query: str, + max_results: int = 3, + cap_content_length: int = 20000, + max_retries: int = 2, + provider: Optional[Union[str, SearchProvider]] = None, + search_mode: str = "auto", + include_highlights: bool = False, +) -> tuple[List[SearchResult], float]: + """Perform a search and extract results in one step. + + Args: + query: Search query + max_results: Maximum number of results to return + cap_content_length: Maximum length of content to return + max_retries: Maximum number of retries in case of failure + provider: Search provider to use (tavily, exa, both) + search_mode: Search mode for Exa ("neural", "keyword", "auto") + include_highlights: Include highlights for Exa results + + Returns: + Tuple of (List of SearchResult objects, search cost) + """ + results = [] + total_cost = 0.0 + retry_count = 0 + + # List of alternative query formats to try if the original query fails + # to yield good results with non-None content + alternative_queries = [ + query, # Original query first + f'"{query}"', # Try exact phrase matching + f"about {query}", # Try broader context + f"research on {query}", # Try research-oriented results + query.replace(" OR ", " "), # Try without OR operator + ] + + while retry_count <= max_retries and retry_count < len( + alternative_queries + ): + try: + current_query = alternative_queries[retry_count] + logger.info( + f"Searching with query ({retry_count + 1}/{max_retries + 1}): {current_query}" + ) + + # Determine if we're using Exa to track costs + using_exa = False + if provider: + if isinstance(provider, str): + using_exa = provider.lower() in ["exa", "both"] + else: + using_exa = provider in [ + SearchProvider.EXA, + SearchProvider.BOTH, + ] + else: + config = SearchEngineConfig() + using_exa = config.default_provider.lower() in ["exa", "both"] + + # Perform search based on provider + if using_exa and provider != SearchProvider.BOTH: + # Direct Exa search + search_results = exa_search( + query=current_query, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + results, cost = extract_search_results( + search_results, provider="exa" + ) + total_cost += cost + elif provider == SearchProvider.BOTH: + # Search with both providers + tavily_results = tavily_search( + current_query, + max_results=max_results, + cap_content_length=cap_content_length, + ) + exa_results = exa_search( + query=current_query, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + + # Extract results from both + tavily_extracted, _ = extract_search_results( + tavily_results, provider="tavily" + ) + exa_extracted, exa_cost = extract_search_results( + exa_results, provider="exa" + ) + total_cost += exa_cost + + # Merge results + results = [] + max_len = 
max(len(tavily_extracted), len(exa_extracted)) + for i in range(max_len): + if i < len(tavily_extracted): + results.append(tavily_extracted[i]) + if i < len(exa_extracted): + results.append(exa_extracted[i]) + results = results[:max_results] + else: + # Tavily search or unified search + results = unified_search( + query=current_query, + provider=provider, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + + # Handle case where unified_search returns a dict + if isinstance(results, dict): + all_results = [] + for provider_results in results.values(): + all_results.extend(provider_results) + results = all_results[:max_results] + + # Check if we got results with actual content + if results: + # Count results with non-empty content + content_results = sum(1 for r in results if r.content.strip()) + + if content_results >= max(1, len(results) // 2): + logger.info( + f"Found {content_results}/{len(results)} results with content" + ) + return results, total_cost + else: + logger.warning( + f"Only found {content_results}/{len(results)} results with content. " + f"Trying alternative query..." + ) + + # If we didn't get good results but haven't hit max retries yet, try again + if retry_count < max_retries: + logger.warning( + f"Inadequate search results. Retrying with alternative query... ({retry_count + 1}/{max_retries})" + ) + retry_count += 1 + else: + # If we're out of retries, return whatever we have + logger.warning( + f"Out of retries. Returning best results found ({len(results)} results)." + ) + return results, total_cost + + except Exception as e: + if retry_count < max_retries: + logger.warning( + f"Search failed with error: {e}. Retrying... ({retry_count + 1}/{max_retries})" + ) + retry_count += 1 + else: + logger.error(f"Search failed after {max_retries} retries: {e}") + return [], 0.0 + + # If we've exhausted all retries, return the best results we have + return results, total_cost diff --git a/deep_research/utils/tracing_metadata_utils.py b/deep_research/utils/tracing_metadata_utils.py new file mode 100644 index 00000000..59c7b37e --- /dev/null +++ b/deep_research/utils/tracing_metadata_utils.py @@ -0,0 +1,745 @@ +"""Utilities for collecting and analyzing tracing metadata from Langfuse.""" + +import time +from datetime import datetime, timedelta, timezone +from functools import wraps +from typing import Any, Dict, List, Optional, Tuple + +from langfuse import Langfuse +from langfuse.api.core import ApiError +from langfuse.client import ObservationsView, TraceWithDetails +from rich import print +from rich.console import Console +from rich.table import Table + +console = Console() + +langfuse = Langfuse() + +# Prompt type identification keywords +PROMPT_IDENTIFIERS = { + "query_decomposition": [ + "MAIN RESEARCH QUERY", + "DIFFERENT DIMENSIONS", + "sub-questions", + ], + "search_query": ["Deep Research assistant", "effective search query"], + "synthesis": [ + "information synthesis", + "comprehensive answer", + "confidence level", + ], + "viewpoint_analysis": [ + "multi-perspective analysis", + "viewpoint categories", + ], + "reflection": ["critique and improve", "information gaps"], + "additional_synthesis": ["enhance the original synthesis"], + "conclusion_generation": [ + "Synthesis and Integration", + "Direct Response to Main Query", + ], + "executive_summary": [ + "executive summaries", + "Key Findings", + "250-400 words", + ], + "introduction": ["engaging introductions", "Context and 
Relevance"], +} + +# Rate limiting configuration +# Adjust these based on your Langfuse tier: +# - Hobby: 30 req/min for Other APIs -> ~2s between requests +# - Core: 100 req/min -> ~0.6s between requests +# - Pro: 1000 req/min -> ~0.06s between requests +RATE_LIMIT_DELAY = 0.1 # 100ms between requests (safe for most tiers) +MAX_RETRIES = 3 +INITIAL_BACKOFF = 1.0 # Initial backoff in seconds + +# Batch processing configuration +BATCH_DELAY = 0.5 # Additional delay between batches of requests + + +def rate_limited(func): + """Decorator to add rate limiting between API calls.""" + + @wraps(func) + def wrapper(*args, **kwargs): + time.sleep(RATE_LIMIT_DELAY) + return func(*args, **kwargs) + + return wrapper + + +def retry_with_backoff(func): + """Decorator to retry functions with exponential backoff on rate limit errors.""" + + @wraps(func) + def wrapper(*args, **kwargs): + backoff = INITIAL_BACKOFF + last_exception = None + + for attempt in range(MAX_RETRIES): + try: + return func(*args, **kwargs) + except ApiError as e: + if e.status_code == 429: # Rate limit error + last_exception = e + if attempt < MAX_RETRIES - 1: + wait_time = backoff * (2**attempt) + console.print( + f"[yellow]Rate limit hit. Retrying in {wait_time:.1f}s...[/yellow]" + ) + time.sleep(wait_time) + continue + raise + except Exception: + # For non-rate limit errors, raise immediately + raise + + # If we've exhausted all retries + if last_exception: + raise last_exception + + return wrapper + + +@rate_limited +@retry_with_backoff +def fetch_traces_safe(limit: Optional[int] = None) -> List[TraceWithDetails]: + """Safely fetch traces with rate limiting and retry logic.""" + return langfuse.fetch_traces(limit=limit).data + + +@rate_limited +@retry_with_backoff +def fetch_observations_safe(trace_id: str) -> List[ObservationsView]: + """Safely fetch observations with rate limiting and retry logic.""" + return langfuse.fetch_observations(trace_id=trace_id).data + + +def get_total_trace_cost(trace_id: str) -> float: + """Calculate the total cost for a single trace by summing all observation costs. + + Args: + trace_id: The ID of the trace to calculate cost for + + Returns: + Total cost across all observations in the trace + """ + try: + observations = fetch_observations_safe(trace_id=trace_id) + total_cost = 0.0 + + for obs in observations: + # Check multiple possible cost fields + if ( + hasattr(obs, "calculated_total_cost") + and obs.calculated_total_cost + ): + total_cost += obs.calculated_total_cost + elif hasattr(obs, "total_price") and obs.total_price: + total_cost += obs.total_price + elif hasattr(obs, "total_cost") and obs.total_cost: + total_cost += obs.total_cost + # If cost details are available, calculate from input/output costs + elif hasattr(obs, "calculated_input_cost") and hasattr( + obs, "calculated_output_cost" + ): + if obs.calculated_input_cost and obs.calculated_output_cost: + total_cost += ( + obs.calculated_input_cost + obs.calculated_output_cost + ) + + return total_cost + except Exception as e: + print(f"[red]Error calculating trace cost: {e}[/red]") + return 0.0 + + +def get_total_tokens_used(trace_id: str) -> Tuple[int, int]: + """Calculate total input and output tokens used for a trace. 
+ + Args: + trace_id: The ID of the trace to calculate tokens for + + Returns: + Tuple of (input_tokens, output_tokens) + """ + try: + observations = fetch_observations_safe(trace_id=trace_id) + total_input_tokens = 0 + total_output_tokens = 0 + + for obs in observations: + # Check for token fields in different possible locations + if hasattr(obs, "usage") and obs.usage: + if hasattr(obs.usage, "input") and obs.usage.input: + total_input_tokens += obs.usage.input + if hasattr(obs.usage, "output") and obs.usage.output: + total_output_tokens += obs.usage.output + # Also check for direct token fields + elif hasattr(obs, "promptTokens") and hasattr( + obs, "completionTokens" + ): + if obs.promptTokens: + total_input_tokens += obs.promptTokens + if obs.completionTokens: + total_output_tokens += obs.completionTokens + + return total_input_tokens, total_output_tokens + except Exception as e: + print(f"[red]Error calculating tokens: {e}[/red]") + return 0, 0 + + +def get_trace_stats(trace: TraceWithDetails) -> Dict[str, Any]: + """Get comprehensive statistics for a trace. + + Args: + trace: The trace object to analyze + + Returns: + Dictionary containing trace statistics including cost, latency, tokens, and metadata + """ + try: + # Get cost and token data + total_cost = get_total_trace_cost(trace.id) + input_tokens, output_tokens = get_total_tokens_used(trace.id) + + # Get observation count + observations = fetch_observations_safe(trace_id=trace.id) + observation_count = len(observations) + + # Extract model information from observations + models_used = set() + for obs in observations: + if hasattr(obs, "model") and obs.model: + models_used.add(obs.model) + + stats = { + "trace_id": trace.id, + "timestamp": trace.timestamp, + "total_cost": total_cost, + "latency_seconds": trace.latency + if hasattr(trace, "latency") + else 0, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "observation_count": observation_count, + "models_used": list(models_used), + "metadata": trace.metadata if hasattr(trace, "metadata") else {}, + "tags": trace.tags if hasattr(trace, "tags") else [], + "user_id": trace.user_id if hasattr(trace, "user_id") else None, + "session_id": trace.session_id + if hasattr(trace, "session_id") + else None, + } + + # Add formatted latency + if stats["latency_seconds"]: + minutes = int(stats["latency_seconds"] // 60) + seconds = stats["latency_seconds"] % 60 + stats["latency_formatted"] = f"{minutes}m {seconds:.1f}s" + else: + stats["latency_formatted"] = "0m 0.0s" + + return stats + except Exception as e: + print(f"[red]Error getting trace stats: {e}[/red]") + return {} + + +def get_traces_by_name(name: str, limit: int = 1) -> List[TraceWithDetails]: + """Get traces by name using Langfuse API. + + Args: + name: The name of the trace to search for + limit: Maximum number of traces to return (default: 1) + + Returns: + List of traces matching the name + """ + try: + # Use the Langfuse API to get traces by name + traces_response = langfuse.get_traces(name=name, limit=limit) + return traces_response.data + except Exception as e: + print(f"[red]Error fetching traces by name: {e}[/red]") + return [] + + +def get_observations_for_trace(trace_id: str) -> List[ObservationsView]: + """Get all observations for a specific trace. 
+ + Args: + trace_id: The ID of the trace + + Returns: + List of observations for the trace + """ + try: + observations_response = langfuse.get_observations(trace_id=trace_id) + return observations_response.data + except Exception as e: + print(f"[red]Error fetching observations: {e}[/red]") + return [] + + +def filter_traces_by_date_range( + start_date: datetime, end_date: datetime, limit: Optional[int] = None +) -> List[TraceWithDetails]: + """Filter traces within a specific date range. + + Args: + start_date: Start of the date range (inclusive) + end_date: End of the date range (inclusive) + limit: Maximum number of traces to return + + Returns: + List of traces within the date range + """ + try: + # Ensure dates are timezone-aware + if start_date.tzinfo is None: + start_date = start_date.replace(tzinfo=timezone.utc) + if end_date.tzinfo is None: + end_date = end_date.replace(tzinfo=timezone.utc) + + # Fetch all traces (or up to API maximum limit of 100) + all_traces = fetch_traces_safe(limit=limit or 100) + + # Filter by date range + filtered_traces = [ + trace + for trace in all_traces + if start_date <= trace.timestamp <= end_date + ] + + # Sort by timestamp (most recent first) + filtered_traces.sort(key=lambda x: x.timestamp, reverse=True) + + # Apply limit if specified + if limit: + filtered_traces = filtered_traces[:limit] + + return filtered_traces + except Exception as e: + print(f"[red]Error filtering traces by date range: {e}[/red]") + return [] + + +def get_traces_last_n_days( + days: int, limit: Optional[int] = None +) -> List[TraceWithDetails]: + """Get traces from the last N days. + + Args: + days: Number of days to look back + limit: Maximum number of traces to return + + Returns: + List of traces from the last N days + """ + end_date = datetime.now(timezone.utc) + start_date = end_date - timedelta(days=days) + + return filter_traces_by_date_range(start_date, end_date, limit) + + +def get_trace_stats_batch( + traces: List[TraceWithDetails], show_progress: bool = True +) -> List[Dict[str, Any]]: + """Get statistics for multiple traces efficiently with progress tracking. + + Args: + traces: List of traces to analyze + show_progress: Whether to show progress bar + + Returns: + List of dictionaries containing trace statistics + """ + stats_list = [] + + for i, trace in enumerate(traces): + if show_progress and i % 5 == 0: + console.print( + f"[dim]Processing trace {i + 1}/{len(traces)}...[/dim]" + ) + + stats = get_trace_stats(trace) + stats_list.append(stats) + + return stats_list + + +def get_aggregate_stats_for_traces( + traces: List[TraceWithDetails], +) -> Dict[str, Any]: + """Calculate aggregate statistics for a list of traces. 
+ + Args: + traces: List of traces to analyze + + Returns: + Dictionary containing aggregate statistics + """ + if not traces: + return { + "trace_count": 0, + "total_cost": 0.0, + "total_input_tokens": 0, + "total_output_tokens": 0, + "total_tokens": 0, + "average_cost_per_trace": 0.0, + "average_latency_seconds": 0.0, + "total_observations": 0, + } + + total_cost = 0.0 + total_input_tokens = 0 + total_output_tokens = 0 + total_latency = 0.0 + total_observations = 0 + all_models = set() + + for trace in traces: + stats = get_trace_stats(trace) + total_cost += stats.get("total_cost", 0) + total_input_tokens += stats.get("input_tokens", 0) + total_output_tokens += stats.get("output_tokens", 0) + total_latency += stats.get("latency_seconds", 0) + total_observations += stats.get("observation_count", 0) + all_models.update(stats.get("models_used", [])) + + return { + "trace_count": len(traces), + "total_cost": total_cost, + "total_input_tokens": total_input_tokens, + "total_output_tokens": total_output_tokens, + "total_tokens": total_input_tokens + total_output_tokens, + "average_cost_per_trace": total_cost / len(traces) if traces else 0, + "average_latency_seconds": total_latency / len(traces) + if traces + else 0, + "total_observations": total_observations, + "models_used": list(all_models), + } + + +def display_trace_stats_table( + traces: List[TraceWithDetails], title: str = "Trace Statistics" +): + """Display trace statistics in a formatted table. + + Args: + traces: List of traces to display + title: Title for the table + """ + table = Table(title=title, show_header=True, header_style="bold magenta") + table.add_column("Trace ID", style="cyan", no_wrap=True) + table.add_column("Timestamp", style="yellow") + table.add_column("Cost ($)", justify="right", style="green") + table.add_column("Tokens (In/Out)", justify="right") + table.add_column("Latency", justify="right") + table.add_column("Observations", justify="right") + + for trace in traces[:10]: # Limit to 10 for display + stats = get_trace_stats(trace) + table.add_row( + stats["trace_id"][:12] + "...", + stats["timestamp"].strftime("%Y-%m-%d %H:%M"), + f"${stats['total_cost']:.4f}", + f"{stats['input_tokens']:,}/{stats['output_tokens']:,}", + stats["latency_formatted"], + str(stats["observation_count"]), + ) + + console.print(table) + + +def identify_prompt_type(observation: ObservationsView) -> str: + """Identify the prompt type based on keywords in the observation's input. + + Examines the system prompt in observation.input['messages'][0]['content'] + for unique keywords that identify each prompt type. + + Args: + observation: The observation to analyze + + Returns: + str: The prompt type name, or "unknown" if not identified + """ + try: + # Access the system prompt from the messages + if hasattr(observation, "input") and observation.input: + messages = observation.input.get("messages", []) + if messages and len(messages) > 0: + system_content = messages[0].get("content", "") + + # Check each prompt type's keywords + for prompt_type, keywords in PROMPT_IDENTIFIERS.items(): + # Check if any keyword is in the system prompt + for keyword in keywords: + if keyword in system_content: + return prompt_type + + return "unknown" + except Exception as e: + console.print( + f"[yellow]Warning: Could not identify prompt type: {e}[/yellow]" + ) + return "unknown" + + +def get_costs_by_prompt_type(trace_id: str) -> Dict[str, Dict[str, float]]: + """Get cost breakdown by prompt type for a given trace. 
+ + Uses observation.usage.input/output for token counts and + observation.calculated_total_cost for costs. + + Args: + trace_id: The ID of the trace to analyze + + Returns: + Dict mapping prompt_type to { + 'cost': float, + 'input_tokens': int, + 'output_tokens': int, + 'count': int # number of calls + } + """ + try: + observations = fetch_observations_safe(trace_id=trace_id) + prompt_metrics = {} + + for obs in observations: + # Identify prompt type + prompt_type = identify_prompt_type(obs) + + # Initialize metrics for this prompt type if needed + if prompt_type not in prompt_metrics: + prompt_metrics[prompt_type] = { + "cost": 0.0, + "input_tokens": 0, + "output_tokens": 0, + "count": 0, + } + + # Add cost + cost = 0.0 + if ( + hasattr(obs, "calculated_total_cost") + and obs.calculated_total_cost + ): + cost = obs.calculated_total_cost + prompt_metrics[prompt_type]["cost"] += cost + + # Add tokens + if hasattr(obs, "usage") and obs.usage: + if hasattr(obs.usage, "input") and obs.usage.input: + prompt_metrics[prompt_type]["input_tokens"] += ( + obs.usage.input + ) + if hasattr(obs.usage, "output") and obs.usage.output: + prompt_metrics[prompt_type]["output_tokens"] += ( + obs.usage.output + ) + + # Increment count + prompt_metrics[prompt_type]["count"] += 1 + + return prompt_metrics + except Exception as e: + print(f"[red]Error getting costs by prompt type: {e}[/red]") + return {} + + +def get_prompt_type_statistics(trace_id: str) -> Dict[str, Dict[str, Any]]: + """Get detailed statistics for each prompt type. + + Args: + trace_id: The ID of the trace to analyze + + Returns: + Dict mapping prompt_type to { + 'cost': float, + 'input_tokens': int, + 'output_tokens': int, + 'count': int, + 'avg_cost_per_call': float, + 'avg_input_tokens': float, + 'avg_output_tokens': float, + 'percentage_of_total_cost': float + } + """ + try: + # Get basic metrics + prompt_metrics = get_costs_by_prompt_type(trace_id) + + # Calculate total cost for percentage calculation + total_cost = sum( + metrics["cost"] for metrics in prompt_metrics.values() + ) + + # Enhance with statistics + enhanced_metrics = {} + for prompt_type, metrics in prompt_metrics.items(): + count = metrics["count"] + enhanced_metrics[prompt_type] = { + "cost": metrics["cost"], + "input_tokens": metrics["input_tokens"], + "output_tokens": metrics["output_tokens"], + "count": count, + "avg_cost_per_call": metrics["cost"] / count + if count > 0 + else 0, + "avg_input_tokens": metrics["input_tokens"] / count + if count > 0 + else 0, + "avg_output_tokens": metrics["output_tokens"] / count + if count > 0 + else 0, + "percentage_of_total_cost": ( + metrics["cost"] / total_cost * 100 + ) + if total_cost > 0 + else 0, + } + + return enhanced_metrics + except Exception as e: + print(f"[red]Error getting prompt type statistics: {e}[/red]") + return {} + + +if __name__ == "__main__": + print( + "[bold cyan]ZenML Deep Research - Tracing Metadata Utilities Demo[/bold cyan]\n" + ) + + try: + # Fetch recent traces + print("[yellow]Fetching recent traces...[/yellow]") + traces = fetch_traces_safe(limit=5) + + if not traces: + print("[red]No traces found![/red]") + exit(1) + except ApiError as e: + if e.status_code == 429: + print("[red]Rate limit exceeded. 
Please try again later.[/red]") + print( + "[yellow]Tip: Consider upgrading your Langfuse tier for higher rate limits.[/yellow]" + ) + else: + print(f"[red]API Error: {e}[/red]") + exit(1) + except Exception as e: + print(f"[red]Error fetching traces: {e}[/red]") + exit(1) + + # Demo 1: Get stats for a single trace + print("\n[bold]1. Single Trace Statistics:[/bold]") + first_trace = traces[0] + stats = get_trace_stats(first_trace) + + console.print(f"Trace ID: [cyan]{stats['trace_id']}[/cyan]") + console.print(f"Timestamp: [yellow]{stats['timestamp']}[/yellow]") + console.print(f"Total Cost: [green]${stats['total_cost']:.4f}[/green]") + console.print( + f"Tokens - Input: [blue]{stats['input_tokens']:,}[/blue], Output: [blue]{stats['output_tokens']:,}[/blue]" + ) + console.print(f"Latency: [magenta]{stats['latency_formatted']}[/magenta]") + console.print(f"Observations: [white]{stats['observation_count']}[/white]") + console.print( + f"Models Used: [cyan]{', '.join(stats['models_used'])}[/cyan]" + ) + + # Demo 2: Get traces from last 7 days + print("\n[bold]2. Traces from Last 7 Days:[/bold]") + recent_traces = get_traces_last_n_days(7, limit=10) + print( + f"Found [green]{len(recent_traces)}[/green] traces in the last 7 days" + ) + + if recent_traces: + display_trace_stats_table(recent_traces, "Last 7 Days Traces") + + # Demo 3: Filter traces by date range + print("\n[bold]3. Filter Traces by Date Range:[/bold]") + end_date = datetime.now(timezone.utc) + start_date = end_date - timedelta(days=3) + + filtered_traces = filter_traces_by_date_range(start_date, end_date) + print( + f"Found [green]{len(filtered_traces)}[/green] traces between {start_date.strftime('%Y-%m-%d')} and {end_date.strftime('%Y-%m-%d')}" + ) + + # Demo 4: Aggregate statistics + print("\n[bold]4. Aggregate Statistics for All Recent Traces:[/bold]") + agg_stats = get_aggregate_stats_for_traces(traces) + + table = Table( + title="Aggregate Statistics", + show_header=True, + header_style="bold magenta", + ) + table.add_column("Metric", style="cyan") + table.add_column("Value", justify="right", style="yellow") + + table.add_row("Total Traces", str(agg_stats["trace_count"])) + table.add_row("Total Cost", f"${agg_stats['total_cost']:.4f}") + table.add_row( + "Average Cost per Trace", f"${agg_stats['average_cost_per_trace']:.4f}" + ) + table.add_row("Total Input Tokens", f"{agg_stats['total_input_tokens']:,}") + table.add_row( + "Total Output Tokens", f"{agg_stats['total_output_tokens']:,}" + ) + table.add_row("Total Tokens", f"{agg_stats['total_tokens']:,}") + table.add_row( + "Average Latency", f"{agg_stats['average_latency_seconds']:.1f}s" + ) + table.add_row("Total Observations", str(agg_stats["total_observations"])) + + console.print(table) + + # Demo 5: Cost breakdown by observation + print("\n[bold]5. 
Cost Breakdown for First Trace:[/bold]") + observations = fetch_observations_safe(trace_id=first_trace.id) + + if observations: + table = Table( + title="Observation Cost Breakdown", + show_header=True, + header_style="bold magenta", + ) + table.add_column("Observation", style="cyan", no_wrap=True) + table.add_column("Model", style="yellow") + table.add_column("Tokens (In/Out)", justify="right") + table.add_column("Cost", justify="right", style="green") + + for i, obs in enumerate(observations[:5]): # Show first 5 + cost = 0.0 + if hasattr(obs, "calculated_total_cost"): + cost = obs.calculated_total_cost or 0.0 + + in_tokens = 0 + out_tokens = 0 + if hasattr(obs, "usage") and obs.usage: + in_tokens = obs.usage.input or 0 + out_tokens = obs.usage.output or 0 + elif hasattr(obs, "promptTokens"): + in_tokens = obs.promptTokens or 0 + out_tokens = obs.completionTokens or 0 + + table.add_row( + f"Obs {i + 1}", + obs.model if hasattr(obs, "model") else "Unknown", + f"{in_tokens:,}/{out_tokens:,}", + f"${cost:.4f}", + ) + + console.print(table)
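
The `SearchData` artifact above is designed to be built per sub-question and then fanned back in with `merge()`. As a rough illustration of that flow (not the pipeline's actual step code), the sketch below wires `search_and_extract_results` into per-question `SearchData` objects and merges them; the sub-questions are placeholders, the import paths mirror the `deep_research/utils/` layout used in this diff, and `EXA_API_KEY` is assumed to be set.

```python
# Minimal usage sketch, assuming the utils/ package layout from this diff
# and a configured EXA_API_KEY. Queries are placeholders.
import time

from utils.pydantic_models import SearchCostDetail, SearchData
from utils.search_utils import search_and_extract_results

sub_questions = [
    "What are the main approaches to quantum error correction?",
    "How close are current devices to fault tolerance?",
]

per_question_data = []
for sub_q in sub_questions:
    results, cost = search_and_extract_results(
        query=sub_q, max_results=3, provider="exa"
    )
    per_question_data.append(
        SearchData(
            search_results={sub_q: results},
            search_costs={"exa": cost} if cost else {},
            search_cost_details=[
                SearchCostDetail(
                    provider="exa",
                    query=sub_q,
                    cost=cost,
                    timestamp=time.time(),
                    step="gather_information",
                    sub_question=sub_q,
                )
            ]
            if cost
            else [],
            total_searches=1,
        )
    )

# Fan-in: merge per-question artifacts, mirroring what a merge step
# would do after parallel information gathering.
merged = SearchData()
for data in per_question_data:
    merged.merge(data)

print(
    f"{merged.total_searches} searches, "
    f"total Exa cost: ${merged.search_costs.get('exa', 0.0):.4f}"
)
```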
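
When `unified_search` is called with `provider="both"` and `compare_results=True`, it keeps the Tavily and Exa result lists separate instead of interleaving them, which is handy for spot-checking provider quality. A small sketch, assuming both `TAVILY_API_KEY` and `EXA_API_KEY` are set and using a placeholder query:

```python
# Hedged sketch: compare providers side by side with unified_search.
from utils.search_utils import unified_search

by_provider = unified_search(
    query="impact of quantum computing on cryptography",
    provider="both",
    max_results=3,
    compare_results=True,  # keep tavily and exa results separate
)

for provider_name, results in by_provider.items():
    print(f"=== {provider_name} ({len(results)} results) ===")
    for r in results:
        print(f"- {r.title or r.url}")
        print(f"  {len(r.content)} chars of content")
```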
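
The `__main__` demo above stops at per-observation costs; the prompt-type helpers (`identify_prompt_type`, `get_prompt_type_statistics`) are defined but not exercised there. The following sketch shows one way they could be rendered with the same `rich` table style, assuming Langfuse credentials are configured and at least one trace exists.

```python
# Sketch only: render a per-prompt-type cost breakdown for the most
# recent trace, using helpers defined in tracing_metadata_utils.py.
from rich.console import Console
from rich.table import Table

from utils.tracing_metadata_utils import (
    fetch_traces_safe,
    get_prompt_type_statistics,
)

console = Console()

traces = fetch_traces_safe(limit=1)
if traces:
    stats = get_prompt_type_statistics(traces[0].id)

    table = Table(title="Cost by Prompt Type", header_style="bold magenta")
    table.add_column("Prompt Type", style="cyan")
    table.add_column("Calls", justify="right")
    table.add_column("Cost ($)", justify="right", style="green")
    table.add_column("% of Total", justify="right")

    # Sort prompt types by cost, most expensive first.
    for prompt_type, metrics in sorted(
        stats.items(), key=lambda kv: kv[1]["cost"], reverse=True
    ):
        table.add_row(
            prompt_type,
            str(metrics["count"]),
            f"${metrics['cost']:.4f}",
            f"{metrics['percentage_of_total_cost']:.1f}%",
        )

    console.print(table)
```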