From db35eb8f853e9a73e225f8b017ed643e536d4fbe Mon Sep 17 00:00:00 2001 From: Hamza Tahir Date: Thu, 21 Aug 2025 19:34:46 +0200 Subject: [PATCH 01/14] Pydantic AI eda --- examples/pydantic_ai_eda/README.md | 343 +++++++ examples/pydantic_ai_eda/__init__.py | 5 + examples/pydantic_ai_eda/models.py | 171 ++++ .../pydantic_ai_eda/pipelines/__init__.py | 12 + .../pydantic_ai_eda/pipelines/eda_pipeline.py | 106 +++ examples/pydantic_ai_eda/requirements.txt | 21 + examples/pydantic_ai_eda/run.py | 299 ++++++ examples/pydantic_ai_eda/steps/__init__.py | 21 + examples/pydantic_ai_eda/steps/agent_tools.py | 434 +++++++++ examples/pydantic_ai_eda/steps/eda_agent.py | 877 ++++++++++++++++++ examples/pydantic_ai_eda/steps/ingest.py | 374 ++++++++ .../pydantic_ai_eda/steps/quality_gate.py | 325 +++++++ 12 files changed, 2988 insertions(+) create mode 100644 examples/pydantic_ai_eda/README.md create mode 100644 examples/pydantic_ai_eda/__init__.py create mode 100644 examples/pydantic_ai_eda/models.py create mode 100644 examples/pydantic_ai_eda/pipelines/__init__.py create mode 100644 examples/pydantic_ai_eda/pipelines/eda_pipeline.py create mode 100644 examples/pydantic_ai_eda/requirements.txt create mode 100644 examples/pydantic_ai_eda/run.py create mode 100644 examples/pydantic_ai_eda/steps/__init__.py create mode 100644 examples/pydantic_ai_eda/steps/agent_tools.py create mode 100644 examples/pydantic_ai_eda/steps/eda_agent.py create mode 100644 examples/pydantic_ai_eda/steps/ingest.py create mode 100644 examples/pydantic_ai_eda/steps/quality_gate.py diff --git a/examples/pydantic_ai_eda/README.md b/examples/pydantic_ai_eda/README.md new file mode 100644 index 00000000000..824be3fce23 --- /dev/null +++ b/examples/pydantic_ai_eda/README.md @@ -0,0 +1,343 @@ +# Pydantic AI EDA Pipeline + +This example demonstrates how to build an AI-powered Exploratory Data Analysis (EDA) pipeline using **ZenML** and **Pydantic AI**. The pipeline automatically analyzes datasets, generates comprehensive reports, and makes data quality decisions for downstream processing. + +## Architecture + +``` +ingest โ†’ eda_agent โ†’ quality_gate โ†’ routing +``` + +## Key Features + +- **๐Ÿค– AI-Powered Analysis**: Uses Pydantic AI with GPT-4 or Claude for intelligent data exploration +- **๐Ÿ“Š SQL-Based EDA**: Agent performs analysis through DuckDB SQL queries with safety guards +- **โœ… Quality Gates**: Automated data quality assessment with configurable thresholds +- **๐ŸŒ Multiple Data Sources**: Support for HuggingFace, local files, and data warehouses +- **๐Ÿ“ˆ Comprehensive Reporting**: Structured JSON reports and human-readable markdown + +## What's Included + +### Pipeline Steps +- **`ingest_data`**: Load data from HuggingFace, local files, or warehouses +- **`run_eda_agent`**: AI agent performs comprehensive EDA using SQL analysis +- **`evaluate_quality_gate`**: Assess data quality against configurable thresholds +- **`route_based_on_quality`**: Make pipeline routing decisions based on quality + +### AI Agent Capabilities +- Statistical analysis and profiling +- Missing data pattern detection +- Correlation analysis +- Outlier identification +- Data quality scoring (0-100) +- Actionable remediation recommendations +- SQL query logging for reproducibility + +### CLI Interface +- **Command-line Runner**: Easy execution with various configuration options +- **Quality Assessment**: Quick quality checks without full analysis +- **Multiple Output Formats**: JSON, CSV, and text reporting + +## Quick Start + +### Prerequisites + +```bash +pip install "zenml[server]" +zenml init +``` + +### Install Dependencies + +```bash +git clone https://github.com/zenml-io/zenml.git +cd zenml/examples/pydantic_ai_eda +pip install -r requirements.txt +``` + +### Set API Keys + +```bash +# For OpenAI (recommended) +export OPENAI_API_KEY="your-openai-key" + +# Or for Anthropic +export ANTHROPIC_API_KEY="your-anthropic-key" +``` + +### Quick Example + +```bash +# Run simple example +python example.py +``` + +### CLI Usage + +```bash +# Analyze HuggingFace dataset +python run_pipeline.py --source-type hf --source-path "scikit-learn/adult-census-income" --target-column "class" + +# Analyze local file +python run_pipeline.py --source-type local --source-path "/path/to/data.csv" --target-column "target" + +# Quality-only assessment +python run_quality_check.py --source-path "/path/to/data.csv" --min-quality-score 80 +``` + +## Example Usage + +### Python API + +```python +from models import DataSourceConfig, AgentConfig +from pipelines.eda_pipeline import eda_pipeline + +# Configure data source +source_config = DataSourceConfig( + source_type="hf", + source_path="scikit-learn/adult-census-income", + target_column="class", + sample_size=10000 +) + +# Configure AI agent +agent_config = AgentConfig( + model_name="gpt-5", + max_tool_calls=50, + sql_guard_enabled=True +) + +# Run pipeline +results = eda_pipeline( + source_config=source_config, + agent_config=agent_config, + min_quality_score=70.0 +) + +print(f"Quality Score: {results['quality_decision'].quality_score}") +print(f"Quality Gate: {'PASSED' if results['quality_decision'].passed else 'FAILED'}") +``` + +## Pipeline Configuration + +### Data Sources + +**HuggingFace Datasets:** +```python +source_config = DataSourceConfig( + source_type="hf", + source_path="scikit-learn/adult-census-income", + sampling_strategy="random", + sample_size=50000 +) +``` + +**Local Files:** +```python +source_config = DataSourceConfig( + source_type="local", + source_path="/path/to/data.csv", + target_column="target" +) +``` + +**Data Warehouses:** +```python +source_config = DataSourceConfig( + source_type="warehouse", + source_path="SELECT * FROM customer_data LIMIT 100000", + warehouse_config={ + "type": "bigquery", + "project_id": "my-project" + } +) +``` + +### AI Agent Configuration + +```python +agent_config = AgentConfig( + model_name="gpt-5", # or "claude-4" + max_tool_calls=100, + sql_guard_enabled=True, + preview_limit=20, + timeout_seconds=600 +) +``` + +### Quality Gate Thresholds + +```python +quality_decision = evaluate_quality_gate( + report_json=report, + min_quality_score=75.0, + block_on_high_severity=True, + max_missing_data_pct=25.0, + require_target_column=True +) +``` + +## Analysis Outputs + +### EDA Report Structure +```json +{ + "headline": "Dataset contains 32,561 rows with moderate data quality issues", + "key_findings": [ + "Found 6 numeric columns suitable for quantitative analysis", + "Missing data is 7.3% overall, within acceptable range", + "Strong correlation detected between age and hours-per-week (0.89)" + ], + "risks": ["Potential class imbalance in target variable"], + "fixes": [ + { + "title": "Address missing values in workclass column", + "severity": "medium", + "code_snippet": "df['workclass'].fillna(df['workclass'].mode()[0])", + "estimated_impact": 0.15 + } + ], + "data_quality_score": 78.5, + "correlation_insights": [...], + "missing_data_analysis": {...} +} +``` + +### Quality Gate Decision +```json +{ + "passed": true, + "quality_score": 78.5, + "decision_reason": "All quality checks passed", + "blocking_issues": [], + "recommendations": [ + "Data quality is acceptable for downstream processing", + "Consider implementing monitoring for quality regression" + ] +} +``` + +## Data Security + +### Quality Configuration +```python +# Configure quality thresholds +results = eda_pipeline( + source_config=source_config, + min_quality_score=80.0, + max_missing_data_pct=15.0 +) +``` + +### SQL Safety Guards +- Only `SELECT` and `WITH` statements allowed +- Prohibited operations: `DROP`, `DELETE`, `INSERT`, `UPDATE` +- Auto-injection of `LIMIT` clauses for large result sets +- Query logging for full auditability + +## Production Deployment + +### Remote Orchestration +```python +# Configure ZenML stack for cloud deployment +zenml stack register remote_stack \ + --orchestrator=kubernetes \ + --artifact_store=s3 \ + --container_registry=ecr + +# Run with remote stack +zenml stack set remote_stack +python run_pipeline.py --source-path "s3://my-bucket/data.parquet" +``` + +### Monitoring & Alerts +- Pipeline execution tracking via ZenML dashboard +- Quality gate failure notifications +- Data drift detection capabilities +- Token usage and cost monitoring + +## Examples Gallery + +### Customer Segmentation Analysis +```bash +python run_pipeline.py \ + --source-type hf \ + --source-path "scikit-learn/adult-census-income" \ + --target-column "class" \ + --min-quality-score 80 +``` + +### Financial Risk Assessment +```bash +python run_pipeline.py \ + --source-type local \ + --source-path "financial_data.csv" \ + --min-quality-score 90 \ + --require-target-column \ + --target-column "risk_score" +``` + +### Time Series Data Quality Check +```bash +python run_quality_check.py \ + --source-path "time_series.parquet" \ + --max-missing-data-pct 10 \ + --require-target-column \ + --target-column "value" +``` + +## Advanced Features + +### Custom Data Warehouses +Support for BigQuery, Snowflake, Redshift, and generic SQL connections. + +### Multi-Model Analysis +Switch between OpenAI GPT-4, Anthropic Claude, and other providers. + +### Pipeline Caching +Automatic caching of expensive operations for faster iterations. + +### Artifact Lineage +Full traceability of data transformations and analysis steps. + +## Troubleshooting + +### Common Issues + +**Missing API Keys:** +```bash +export OPENAI_API_KEY="your-key" +# or +export ANTHROPIC_API_KEY="your-key" +``` + +**DuckDB Import Errors:** +```bash +pip install duckdb>=1.0.0 +``` + +**Pydantic AI Installation:** +```bash +pip install pydantic-ai>=0.0.13 +``` + +**Large Dataset Memory Issues:** +- Reduce `sample_size` in DataSourceConfig +- Use `enable_masking=True` to reduce memory footprint +- Consider using `quality_only_pipeline` for quick checks + +### Performance Optimization + +- Use `gpt-4o-mini` instead of `gpt-5` for faster analysis +- Limit `max_tool_calls` for time-constrained scenarios +- Enable snapshot caching for repeated analysis +- Use stratified sampling for large datasets + +## Contributing + +This example demonstrates the integration patterns between ZenML and Pydantic AI. Contributions for additional data sources, quality checks, and analysis capabilities are welcome. + +## License + +This example is part of the ZenML project and follows the Apache 2.0 license. \ No newline at end of file diff --git a/examples/pydantic_ai_eda/__init__.py b/examples/pydantic_ai_eda/__init__.py new file mode 100644 index 00000000000..f25ceb38faa --- /dev/null +++ b/examples/pydantic_ai_eda/__init__.py @@ -0,0 +1,5 @@ +"""Pydantic AI EDA pipeline example for ZenML. + +This example demonstrates how to build an AI-powered Exploratory Data Analysis +pipeline using ZenML and Pydantic AI for automated data analysis and quality assessment. +""" \ No newline at end of file diff --git a/examples/pydantic_ai_eda/models.py b/examples/pydantic_ai_eda/models.py new file mode 100644 index 00000000000..30253e03480 --- /dev/null +++ b/examples/pydantic_ai_eda/models.py @@ -0,0 +1,171 @@ +"""Data models for EDA pipeline with Pydantic AI. + +This module defines Pydantic models used throughout the EDA pipeline +for request/response handling, analysis results, and evaluation. +""" + +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class DataSourceConfig(BaseModel): + """Configuration for data source ingestion. + + Supports HuggingFace datasets, local files, and warehouse connections. + + Attributes: + source_type: Type of data source (hf, local, warehouse) + source_path: Path/identifier for the data source + target_column: Optional target column for analysis focus + sampling_strategy: How to sample the data (random, stratified, first_n) + sample_size: Number of rows to sample (None for all data) + warehouse_config: Additional config for warehouse connections + """ + + source_type: str = Field( + description="Data source type: hf, local, or warehouse" + ) + source_path: str = Field( + description="Path or identifier for the data source" + ) + target_column: Optional[str] = Field( + None, description="Optional target column name" + ) + sampling_strategy: str = Field("random", description="Sampling strategy") + sample_size: Optional[int] = Field( + None, description="Number of rows to sample" + ) + warehouse_config: Optional[Dict[str, Any]] = Field( + None, description="Warehouse connection config" + ) + + +class DataQualityFix(BaseModel): + """Model for data quality issues and suggested fixes. + + Represents a specific data quality problem identified during analysis + along with recommended remediation actions. + + Attributes: + title: Short description of the issue + rationale: Explanation of why this is a problem + severity: Impact level (low, medium, high, critical) + code_snippet: Optional code to address the issue + affected_columns: Columns affected by this issue + estimated_impact: Estimated impact on data quality (0-1) + """ + + title: str = Field( + description="Short description of the data quality issue" + ) + rationale: str = Field(description="Explanation of why this is a problem") + severity: str = Field( + description="Severity level: low, medium, high, critical" + ) + code_snippet: Optional[str] = Field( + None, description="Code to fix the issue" + ) + affected_columns: List[str] = Field( + description="Columns affected by this issue" + ) + estimated_impact: float = Field( + description="Estimated impact on data quality (0-1)" + ) + + +class EDAReport(BaseModel): + """Model for comprehensive EDA report results. + + Contains the structured output from Pydantic AI analysis including + findings, quality assessment, and recommendations. + + Attributes: + headline: Executive summary of key findings + key_findings: List of important discoveries about the data + risks: Potential data quality or analysis risks identified + fixes: Recommended fixes for data quality issues + data_quality_score: Overall data quality score (0-100) + markdown: Full markdown report for human consumption + column_profiles: Statistical profiles for each column + correlation_insights: Key correlation findings + missing_data_analysis: Analysis of missing data patterns + """ + + headline: str = Field(description="Executive summary of key findings") + key_findings: List[str] = Field( + description="Important discoveries about the data" + ) + risks: List[str] = Field( + description="Potential risks identified in the data" + ) + fixes: List[DataQualityFix] = Field( + description="Recommended data quality fixes" + ) + data_quality_score: float = Field( + description="Overall quality score (0-100)" + ) + markdown: str = Field(description="Full markdown report") + column_profiles: Dict[str, Dict[str, Any]] = Field( + description="Statistical profiles per column" + ) + correlation_insights: List[str] = Field( + description="Key correlation findings" + ) + missing_data_analysis: Dict[str, Any] = Field( + description="Missing data patterns" + ) + + +class QualityGateDecision(BaseModel): + """Model for quality gate decision results. + + Represents the outcome of evaluating whether data quality meets + requirements for downstream processing or model training. + + Attributes: + passed: Whether the quality gate check passed + quality_score: The computed data quality score + decision_reason: Explanation for the pass/fail decision + blocking_issues: Issues that caused failure (if failed) + recommendations: Suggested next steps + metadata: Additional decision metadata + """ + + passed: bool = Field(description="Whether quality gate passed") + quality_score: float = Field(description="Computed data quality score") + decision_reason: str = Field(description="Explanation for the decision") + blocking_issues: List[str] = Field( + description="Issues that caused failure" + ) + recommendations: List[str] = Field(description="Recommended next steps") + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional metadata" + ) + + +class AgentConfig(BaseModel): + """Configuration for Pydantic AI agent behavior. + + Controls how the EDA agent operates including model selection, + tool usage limits, and safety constraints. + + Attributes: + model_name: Name of the language model to use + max_tool_calls: Maximum number of tool calls allowed + sql_guard_enabled: Whether to enable SQL safety guards + preview_limit: Maximum rows to show in data previews + enable_plotting: Whether to enable chart/plot generation + timeout_seconds: Maximum execution time in seconds + """ + + model_name: str = Field("gpt-5", description="Language model to use") + max_tool_calls: int = Field(50, description="Maximum tool calls allowed") + sql_guard_enabled: bool = Field( + True, description="Enable SQL safety guards" + ) + preview_limit: int = Field(10, description="Max rows in data previews") + enable_plotting: bool = Field( + False, description="Enable plotting capabilities" + ) + timeout_seconds: int = Field(300, description="Max execution time") diff --git a/examples/pydantic_ai_eda/pipelines/__init__.py b/examples/pydantic_ai_eda/pipelines/__init__.py new file mode 100644 index 00000000000..8806b28c264 --- /dev/null +++ b/examples/pydantic_ai_eda/pipelines/__init__.py @@ -0,0 +1,12 @@ +"""ZenML pipeline for Pydantic AI EDA workflow. + +This module contains the pipeline definition for the EDA workflow: + +- eda_pipeline.py: Complete EDA pipeline with AI analysis and quality gates +""" + +from .eda_pipeline import eda_pipeline + +__all__ = [ + "eda_pipeline", +] \ No newline at end of file diff --git a/examples/pydantic_ai_eda/pipelines/eda_pipeline.py b/examples/pydantic_ai_eda/pipelines/eda_pipeline.py new file mode 100644 index 00000000000..78487c5d27f --- /dev/null +++ b/examples/pydantic_ai_eda/pipelines/eda_pipeline.py @@ -0,0 +1,106 @@ +"""EDA pipeline using Pydantic AI for automated data analysis. + +This pipeline orchestrates the complete EDA workflow: +1. Data ingestion from various sources +2. AI-powered EDA analysis with Pydantic AI +3. Quality gate evaluation for pipeline routing +""" + +from typing import Any, Dict, Optional + +from models import AgentConfig, DataSourceConfig +from steps import ( + evaluate_quality_gate, + ingest_data, + route_based_on_quality, + run_eda_agent, +) + +from zenml import pipeline +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@pipeline +def eda_pipeline( + source_config: DataSourceConfig, + agent_config: Optional[AgentConfig] = None, + min_quality_score: float = 70.0, + block_on_high_severity: bool = True, + max_missing_data_pct: float = 30.0, + require_target_column: bool = False, +) -> Dict[str, Any]: + """Complete EDA pipeline with AI-powered analysis and quality gating. + + Performs end-to-end exploratory data analysis using Pydantic AI, + from data ingestion through quality assessment and routing decisions. + + Args: + source_config: Configuration for data source (HuggingFace/local/warehouse) + agent_config: Configuration for Pydantic AI agent behavior + min_quality_score: Minimum quality score for passing quality gate + block_on_high_severity: Whether high-severity issues block the pipeline + max_missing_data_pct: Maximum allowable missing data percentage + require_target_column: Whether to require a target column for analysis + + Returns: + Dictionary containing all pipeline outputs and routing decisions + """ + logger.info( + f"Starting EDA pipeline for {source_config.source_type}:{source_config.source_path}" + ) + + # Step 1: Ingest data from configured source + raw_df, ingestion_metadata = ingest_data(source_config=source_config) + + # Step 2: Run AI-powered EDA analysis + report_markdown, report_json, sql_log, analysis_tables = run_eda_agent( + dataset_df=raw_df, + dataset_metadata=ingestion_metadata, + agent_config=agent_config, + ) + + # Step 3: Evaluate data quality gate + quality_decision = evaluate_quality_gate( + report_json=report_json, + min_quality_score=min_quality_score, + block_on_high_severity=block_on_high_severity, + max_missing_data_pct=max_missing_data_pct, + require_target_column=require_target_column, + target_column=source_config.target_column, + ) + + # Step 4: Route based on quality assessment + routing_message = route_based_on_quality( + decision=quality_decision, + on_pass_message="Data quality acceptable - ready for downstream processing", + on_fail_message="Data quality insufficient - requires remediation before use", + ) + + # Log pipeline summary (note: artifacts are returned, actual values logged in steps) + logger.info("Pipeline steps completed successfully") + logger.info("Check step outputs for detailed analysis results") + + # Return comprehensive results + return { + # Core analysis outputs + "report_markdown": report_markdown, + "report_json": report_json, + "analysis_tables": analysis_tables, + "sql_log": sql_log, + # Graph visualization removed + # Quality assessment + "quality_decision": quality_decision, + "routing_message": routing_message, + # Pipeline metadata + "source_config": source_config, + "ingestion_metadata": ingestion_metadata, + "agent_config": agent_config, + # Summary metrics (basic info only, artifacts available separately) + "pipeline_summary": { + "data_source": f"{source_config.source_type}:{source_config.source_path}", + "target_column": source_config.target_column, + "timestamp": ingestion_metadata, # This will be the artifact + }, + } diff --git a/examples/pydantic_ai_eda/requirements.txt b/examples/pydantic_ai_eda/requirements.txt new file mode 100644 index 00000000000..94380d9fcab --- /dev/null +++ b/examples/pydantic_ai_eda/requirements.txt @@ -0,0 +1,21 @@ +# ZenML (use existing installation) +# zenml + +# Core dependencies with version constraints for compatibility +pandas>=2.0.0,<3.0.0 +numpy>=1.24.0,<2.0.0 +duckdb>=1.0.0,<2.0.0 + +# Pydantic compatibility - CRITICAL: Must be <2.10 for ZenML compatibility +pydantic>=2.8.0,<2.10.0 + +# AI/ML frameworks +pydantic-ai[logfire]>=0.7.0 +openai>=1.0.0,<2.0.0 +anthropic>=0.30.0 # Alternative LLM provider + +# Data source connectors +datasets>=2.14.0,<3.0.0 # HuggingFace datasets (optional for local files) + +# Utility libraries +nest-asyncio>=1.5.6,<2.0.0 # For async compatibility in notebooks \ No newline at end of file diff --git a/examples/pydantic_ai_eda/run.py b/examples/pydantic_ai_eda/run.py new file mode 100644 index 00000000000..67d7fbe840c --- /dev/null +++ b/examples/pydantic_ai_eda/run.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 +"""Run the Pydantic AI EDA pipeline. + +This script provides multiple ways to run the EDA pipeline: +- With HuggingFace datasets (default) +- With local CSV files +- With different quality thresholds +- For testing and production scenarios + +Works with or without API keys (falls back to statistical analysis). +""" + +import argparse +import os +import sys +from pathlib import Path +from typing import Optional + +from models import AgentConfig, DataSourceConfig +from pipelines.eda_pipeline import eda_pipeline + + +def create_sample_dataset(): + """Create a sample iris dataset CSV for local testing.""" + try: + import pandas as pd + from sklearn.datasets import load_iris + + print("๐Ÿ“ Creating sample dataset...") + iris = load_iris() + df = pd.DataFrame(iris.data, columns=iris.feature_names) + df["target"] = iris.target + + df.to_csv("iris_sample.csv", index=False) + print(f"โœ… Created iris_sample.csv with {len(df)} rows") + return "iris_sample.csv" + except ImportError: + print("โŒ sklearn not available for sample dataset creation") + return None + + +def check_api_keys(): + """Check for available API keys and return provider info.""" + has_openai = bool(os.getenv("OPENAI_API_KEY")) + has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY")) + + if has_openai and has_anthropic: + print("๐Ÿค– Both OpenAI and Anthropic API keys detected") + return "both" + elif has_openai: + print("๐Ÿค– OpenAI API key detected - will use GPT models") + return "openai" + elif has_anthropic: + print("๐Ÿค– Anthropic API key detected - will use Claude models") + return "anthropic" + else: + print("โš ๏ธ No API keys found - will use statistical fallback analysis") + print( + " Set OPENAI_API_KEY or ANTHROPIC_API_KEY for full AI features" + ) + return None + + +def run_pipeline( + source_type: str = "hf", + source_path: str = "scikit-learn/iris", + target_column: Optional[str] = "target", + min_quality_score: float = 70.0, + ai_provider: Optional[str] = None, + timeout: int = 300, + sample_size: Optional[int] = None, + verbose: bool = False, +): + """Run the EDA pipeline with specified configuration.""" + + # Configure data source + source_config = DataSourceConfig( + source_type=source_type, + source_path=source_path, + target_column=target_column, + sample_size=sample_size, + ) + + # Configure AI agent based on available providers + if ai_provider == "anthropic": + model_name = "claude-4" + elif ai_provider == "openai": + model_name = "gpt-5" + else: + model_name = "gpt-5" # Default fallback + + agent_config = AgentConfig( + model_name=model_name, + max_tool_calls=15, # Reduced to prevent infinite loops + sql_guard_enabled=True, + preview_limit=10, + timeout_seconds=timeout, + temperature=0.1, + ) + + print(f"๐Ÿ“Š Analyzing dataset: {source_config.source_path}") + if target_column: + print(f"๐ŸŽฏ Target column: {target_column}") + print(f"๐Ÿ“ Quality threshold: {min_quality_score}") + + try: + print("๐Ÿš€ Running EDA pipeline") + results = eda_pipeline.with_options(enable_cache=False)( + source_config=source_config, + agent_config=agent_config, + min_quality_score=min_quality_score, + block_on_high_severity=False, # Don't block for demo + max_missing_data_pct=30.0, + require_target_column=bool(target_column), + ) + + print("โœ… Pipeline completed successfully!") + + # Display results summary + print(f"\n{'=' * 60}") + print("๐Ÿ“‹ PIPELINE RESULTS") + print("=" * 60) + + # Show pipeline run info + if hasattr(results, "id"): + print(f"๐Ÿ“ Pipeline run ID: {results.id}") + if hasattr(results, "status"): + print(f"๐Ÿ“Š Status: {results.status}") + if hasattr(results, "name"): + print(f"๐Ÿท๏ธ Name: {results.name}") + + # Show artifact locations + print(f"\n๐Ÿ“ฆ Generated Artifacts:") + print(f" โ€ข EDA report (markdown): Available in ZenML dashboard") + print(f" โ€ข Analysis results (JSON): Available in ZenML dashboard") + print(f" โ€ข Quality assessment: Available in ZenML dashboard") + print(f" โ€ข SQL execution log: Available in ZenML dashboard") + print(f" โ€ข Analysis tables: Available in ZenML dashboard") + + # Show next steps + print(f"\n๐Ÿ“– Next Steps:") + print(f" โ€ข View full results in ZenML dashboard") + print( + f" โ€ข Access artifacts: results.steps['step_name'].outputs['artifact_name'].load()" + ) + print(f" โ€ข Run with different parameters using command line options") + + if not ai_provider: + print(f"\n๐Ÿ”‘ For AI-powered analysis:") + print(f" โ€ข Set: export OPENAI_API_KEY='your-key'") + print(f" โ€ข Or: export ANTHROPIC_API_KEY='your-key'") + print(f" โ€ข Then re-run for intelligent insights") + + return results + + except Exception as e: + print(f"โŒ Pipeline failed: {e}") + + if verbose: + import traceback + + print(f"\n๐Ÿ” Full error traceback:") + traceback.print_exc() + + print(f"\n๐Ÿ”ง Troubleshooting:") + if source_type == "hf": + print(f" โ€ข Check internet connection for HuggingFace datasets") + print( + f" โ€ข Try local mode: python run.py --source-type local --create-sample" + ) + elif source_type == "local": + print(f" โ€ข Check file exists: {source_path}") + print(f" โ€ข Ensure file is valid CSV format") + + print(f" โ€ข Ensure ZenML is initialized: zenml init") + print(f" โ€ข Check ZenML stack: zenml stack list") + print(f" โ€ข Install dependencies: pip install -r requirements.txt") + + return None + + +def main(): + """Main CLI interface.""" + parser = argparse.ArgumentParser( + description="Run Pydantic AI EDA Pipeline", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic usage with HuggingFace dataset + python run.py + + # Use local CSV file + python run.py --source-type local --source-path data.csv --target target_col + + # Create and use sample dataset + python run.py --source-type local --create-sample + + # Custom quality threshold + python run.py --min-quality-score 80 + + # Custom dataset with specific settings + python run.py --source-path username/dataset --sample-size 1000 --timeout 600 + """, + ) + + # Data source options + parser.add_argument( + "--source-type", + choices=["hf", "local", "warehouse"], + default="hf", + help="Data source type (default: hf)", + ) + parser.add_argument( + "--source-path", + default="scikit-learn/iris", + help="Dataset path (HF dataset name or file path) (default: scikit-learn/iris)", + ) + parser.add_argument( + "--target-column", + default="target", + help="Target column name (default: target)", + ) + parser.add_argument( + "--sample-size", + type=int, + help="Limit dataset to N rows for faster processing", + ) + + # Pipeline options + parser.add_argument( + "--min-quality-score", + type=float, + default=70.0, + help="Minimum quality score threshold (default: 70.0)", + ) + parser.add_argument( + "--timeout", + type=int, + default=300, + help="AI agent timeout in seconds (default: 300)", + ) + + # Utility options + parser.add_argument( + "--create-sample", + action="store_true", + help="Create iris_sample.csv for local testing", + ) + parser.add_argument( + "--verbose", action="store_true", help="Show detailed error traces" + ) + + args = parser.parse_args() + + print("๐Ÿš€ Pydantic AI EDA Pipeline") + print("=" * 40) + + # Create sample dataset if requested + if args.create_sample: + sample_file = create_sample_dataset() + if sample_file and args.source_type == "local": + args.source_path = sample_file + print(f"๐Ÿ”„ Switched to created sample: {sample_file}") + + # Check API key availability + ai_provider = check_api_keys() + + # Validate local file exists + if args.source_type == "local": + if not Path(args.source_path).exists(): + print(f"โŒ Local file not found: {args.source_path}") + if not args.create_sample: + print( + f"๐Ÿ’ก Try: python run.py --source-type local --create-sample" + ) + sys.exit(1) + + # Run the pipeline + results = run_pipeline( + source_type=args.source_type, + source_path=args.source_path, + target_column=args.target_column, + min_quality_score=args.min_quality_score, + ai_provider=ai_provider, + timeout=args.timeout, + sample_size=args.sample_size, + verbose=args.verbose, + ) + + if results: + print(f"\n๐ŸŽ‰ Pipeline completed successfully!") + sys.exit(0) + else: + print(f"\n๐Ÿ’ฅ Pipeline failed!") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/examples/pydantic_ai_eda/steps/__init__.py b/examples/pydantic_ai_eda/steps/__init__.py new file mode 100644 index 00000000000..92f58660db5 --- /dev/null +++ b/examples/pydantic_ai_eda/steps/__init__.py @@ -0,0 +1,21 @@ +"""ZenML steps for Pydantic AI EDA pipeline. + +This module contains all the step functions used in the EDA pipeline: + +- ingest.py: Data ingestion from multiple sources (HF, local, warehouse) +- snapshot.py: Data snapshot creation with optional masking +- agent_tools.py: Pydantic AI agent tools and dependencies +- eda_agent.py: AI-powered EDA analysis step +- quality_gate.py: Data quality assessment and routing steps +""" + +from .eda_agent import run_eda_agent +from .ingest import ingest_data +from .quality_gate import evaluate_quality_gate, route_based_on_quality + +__all__ = [ + "ingest_data", + "run_eda_agent", + "evaluate_quality_gate", + "route_based_on_quality", +] \ No newline at end of file diff --git a/examples/pydantic_ai_eda/steps/agent_tools.py b/examples/pydantic_ai_eda/steps/agent_tools.py new file mode 100644 index 00000000000..34a8f8edd43 --- /dev/null +++ b/examples/pydantic_ai_eda/steps/agent_tools.py @@ -0,0 +1,434 @@ +"""Pydantic AI agent tools for SQL-based EDA analysis. + +This module provides the tools and dependencies that the Pydantic AI agent +uses to perform exploratory data analysis through SQL queries. +""" + +import logging +import re +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional + +import duckdb +import pandas as pd +from pydantic import BaseModel + +# Import RunContext for tool signatures +try: + from pydantic_ai import RunContext +except ImportError: + # Define a fallback if not available + class RunContext: + def __init__(self, deps): + self.deps = deps + + +logger = logging.getLogger(__name__) + + +class AnalystAgentDeps(BaseModel): + """Dependencies for the EDA analyst agent. + + Manages the agent's state including datasets, query history, + and analysis outputs. Acts as the context/memory for the agent. + + Attributes: + datasets: Mapping of reference names to DataFrames + query_history: Log of executed SQL queries + output_counter: Counter for generating unique output references + """ + + datasets: Dict[str, pd.DataFrame] = {} + query_history: List[Dict[str, Any]] = [] + output_counter: int = 0 + + class Config: + arbitrary_types_allowed = True + + def store_dataset(self, df: pd.DataFrame, ref_name: str = None) -> str: + """Store a dataset and return its reference name. + + Args: + df: DataFrame to store + ref_name: Optional custom reference name + + Returns: + Reference name for the stored dataset + """ + if ref_name is None: + self.output_counter += 1 + ref_name = f"Out[{self.output_counter}]" + + self.datasets[ref_name] = df.copy() + logger.info(f"Stored dataset as {ref_name} with shape {df.shape}") + return ref_name + + def get_dataset(self, ref_name: str) -> Optional[pd.DataFrame]: + """Retrieve a dataset by reference name.""" + return self.datasets.get(ref_name) + + def list_datasets(self) -> List[str]: + """List all available dataset references.""" + return list(self.datasets.keys()) + + def log_query( + self, + sql: str, + result_ref: str, + rows_returned: int, + execution_time_ms: float, + ): + """Log an executed SQL query.""" + self.query_history.append( + { + "sql": sql, + "result_ref": result_ref, + "rows_returned": rows_returned, + "execution_time_ms": execution_time_ms, + "timestamp": pd.Timestamp.now().isoformat(), + } + ) + + +def run_duckdb_query( + ctx: RunContext[AnalystAgentDeps], dataset_ref: str, sql: str +) -> str: + """Execute SQL query against a dataset using DuckDB. + + This is the primary tool the agent uses for data analysis. + Provides read-only access with safety guards against harmful SQL. + + Args: + deps: Agent dependencies containing datasets + dataset_ref: Reference to the dataset to query + sql: SQL query to execute + + Returns: + Message describing the query execution and result location + """ + import time + + start_time = time.time() + + deps = ctx.deps + try: + # Get the dataset + df = deps.get_dataset(dataset_ref) + if df is None: + return f"Error: Dataset '{dataset_ref}' not found. Available: {deps.list_datasets()}" + + # Validate SQL query for safety + if not _is_safe_sql(sql): + return "Error: SQL query contains prohibited operations. Only SELECT queries are allowed." + + # Auto-inject LIMIT if missing and query might return large results + modified_sql = _maybe_add_limit(sql, df) + + # Execute query with DuckDB + conn = duckdb.connect(":memory:") + + # Register the dataset + conn.register("dataset", df) + + # Execute the query + try: + result = conn.execute(modified_sql).fetchdf() + execution_time = (time.time() - start_time) * 1000 + + # Store result + result_ref = deps.store_dataset(result) + + # Log the query + deps.log_query( + modified_sql, result_ref, len(result), execution_time + ) + + # Return success message + rows_msg = f"{len(result)} row(s)" if len(result) != 1 else "1 row" + return f"Executed SQL successfully. Result stored as {result_ref} with {rows_msg}. Use display('{result_ref}') to view." + + except Exception as e: + return f"SQL execution error: {str(e)}" + finally: + conn.close() + + except Exception as e: + logger.error(f"Error in run_duckdb_query: {e}") + return f"Tool error: {str(e)}" + + +def display_data( + ctx: RunContext[AnalystAgentDeps], dataset_ref: str, max_rows: int = 10 +) -> str: + """Display a preview of a dataset. + + Shows the first few rows of a dataset in a readable format. + Used by the agent to peek at data content. + + Args: + deps: Agent dependencies + dataset_ref: Reference to dataset to display + max_rows: Maximum number of rows to show + + Returns: + String representation of the dataset preview + """ + deps = ctx.deps + try: + df = deps.get_dataset(dataset_ref) + if df is None: + return f"Error: Dataset '{dataset_ref}' not found. Available: {deps.list_datasets()}" + + if len(df) == 0: + return f"Dataset {dataset_ref} is empty (0 rows, {len(df.columns)} columns)." + + # Show basic info + preview_rows = min(max_rows, len(df)) + info = f"Dataset {dataset_ref}: {len(df)} rows ร— {len(df.columns)} columns\n\n" + + # Show column info + info += "Columns and types:\n" + for col, dtype in df.dtypes.items(): + null_count = df[col].isnull().sum() + null_pct = null_count / len(df) * 100 if len(df) > 0 else 0 + info += f" {col}: {dtype} ({null_count} nulls, {null_pct:.1f}%)\n" + + info += f"\nFirst {preview_rows} rows:\n" + info += df.head(preview_rows).to_string(max_cols=10, max_colwidth=50) + + if len(df) > preview_rows: + info += f"\n... ({len(df) - preview_rows} more rows)" + + return info + + except Exception as e: + return f"Display error: {str(e)}" + + +def profile_data(ctx: RunContext[AnalystAgentDeps], dataset_ref: str) -> str: + """Generate statistical profile of a dataset. + + Provides comprehensive statistics about the dataset including + distributions, correlations, and data quality metrics. + + Args: + ctx: Agent context with dependencies + dataset_ref: Reference to dataset to profile + + Returns: + Detailed statistical profile as string + """ + deps = ctx.deps + try: + df = deps.get_dataset(dataset_ref) + if df is None: + return f"Error: Dataset '{dataset_ref}' not found. Available: {deps.list_datasets()}" + + if len(df) == 0: + return f"Cannot profile empty dataset {dataset_ref}." + + profile = f"Statistical Profile for {dataset_ref}\n" + profile += "=" * 50 + "\n\n" + + # Basic info + profile += f"Shape: {df.shape[0]:,} rows ร— {df.shape[1]} columns\n" + profile += f"Memory usage: {df.memory_usage(deep=True).sum() / 1024**2:.2f} MB\n\n" + + # Numeric columns summary + numeric_cols = df.select_dtypes(include=["number"]).columns.tolist() + if numeric_cols: + profile += "Numeric Columns:\n" + desc = df[numeric_cols].describe() + profile += desc.to_string() + "\n\n" + + # Categorical columns summary + cat_cols = df.select_dtypes( + include=["object", "category"] + ).columns.tolist() + if cat_cols: + profile += "Categorical Columns:\n" + for col in cat_cols[ + :5 + ]: # Limit to first 5 to avoid too much output + unique_count = df[col].nunique() + null_count = df[col].isnull().sum() + most_common = df[col].value_counts().head(3) + + profile += f" {col}:\n" + profile += f" Unique values: {unique_count}\n" + profile += f" Null values: {null_count}\n" + profile += f" Most common: {dict(most_common)}\n\n" + + # Missing data analysis + missing_data = df.isnull().sum() + if missing_data.sum() > 0: + profile += "Missing Data:\n" + for col, missing in missing_data[missing_data > 0].items(): + pct = missing / len(df) * 100 + profile += f" {col}: {missing} ({pct:.1f}%)\n" + profile += "\n" + + # Correlation for numeric columns (if more than 1 numeric column) + if len(numeric_cols) > 1: + corr_matrix = df[numeric_cols].corr() + # Find high correlations + high_corrs = [] + for i, col1 in enumerate(numeric_cols): + for j, col2 in enumerate(numeric_cols[i + 1 :], i + 1): + corr_val = corr_matrix.iloc[i, j] + if abs(corr_val) > 0.7: # High correlation threshold + high_corrs.append((col1, col2, corr_val)) + + if high_corrs: + profile += "High Correlations (>0.7):\n" + for col1, col2, corr in high_corrs: + profile += f" {col1} โ†” {col2}: {corr:.3f}\n" + profile += "\n" + + return profile + + except Exception as e: + return f"Profile error: {str(e)}" + + +def save_table_as_csv( + ctx: RunContext[AnalystAgentDeps], dataset_ref: str, filename: str = None +) -> str: + """Save a dataset as CSV file. + + Optional tool for exporting analysis results. Saves to temporary + directory and returns the file path. + + Args: + ctx: Agent context with dependencies + dataset_ref: Reference to dataset to save + filename: Optional filename (auto-generated if not provided) + + Returns: + Path to saved CSV file or error message + """ + deps = ctx.deps + try: + df = deps.get_dataset(dataset_ref) + if df is None: + return f"Error: Dataset '{dataset_ref}' not found. Available: {deps.list_datasets()}" + + # Generate filename if not provided + if filename is None: + filename = f"eda_export_{dataset_ref.replace('[', '').replace(']', '')}.csv" + + # Ensure .csv extension + if not filename.endswith(".csv"): + filename += ".csv" + + # Save to temporary directory + temp_dir = Path(tempfile.gettempdir()) / "zenml_eda_exports" + temp_dir.mkdir(exist_ok=True) + + file_path = temp_dir / filename + df.to_csv(file_path, index=False) + + return f"Saved {len(df)} rows to: {file_path}" + + except Exception as e: + return f"Save error: {str(e)}" + + +def _is_safe_sql(sql: str) -> bool: + """Check if SQL query is safe (read-only operations). + + Blocks potentially harmful SQL operations to ensure the agent + can only perform analysis, not modify data or system state. + """ + sql_upper = sql.upper().strip() + + # Allow only SELECT statements + if not sql_upper.startswith("SELECT") and not sql_upper.startswith("WITH"): + return False + + # Block dangerous keywords + prohibited_keywords = [ + "DROP", + "DELETE", + "INSERT", + "UPDATE", + "ALTER", + "CREATE", + "TRUNCATE", + "REPLACE", + "MERGE", + "EXEC", + "EXECUTE", + "ATTACH", + "DETACH", + "PRAGMA", + "COPY", + "IMPORT", + "EXPORT", + "LOAD", + "INSTALL", + "SET GLOBAL", + "SET PERSIST", + ] + + for keyword in prohibited_keywords: + # Use word boundaries to avoid false positives + pattern = r"\b" + re.escape(keyword) + r"\b" + if re.search(pattern, sql_upper): + return False + + return True + + +def _maybe_add_limit( + sql: str, df: pd.DataFrame, default_limit: int = 1000 +) -> str: + """Add LIMIT clause to queries that might return large results. + + Prevents the agent from accidentally creating huge result sets + that could cause memory issues or slow performance. + """ + sql_upper = sql.upper().strip() + + # If LIMIT already present, don't modify + if "LIMIT" in sql_upper: + return sql + + # If dataset is small, no need to limit + if len(df) <= default_limit: + return sql + + # Check if query might return large results + # (GROUP BY, aggregation functions usually return smaller results) + has_aggregation = any( + keyword in sql_upper + for keyword in [ + "GROUP BY", + "COUNT(", + "SUM(", + "AVG(", + "MAX(", + "MIN(", + "DISTINCT", + ] + ) + + if has_aggregation: + return sql + + # Add LIMIT clause + modified_sql = sql.rstrip() + if modified_sql.endswith(";"): + modified_sql = modified_sql[:-1] + + return f"{modified_sql} LIMIT {default_limit}" + + +# Registry of available tools for the agent +AGENT_TOOLS = { + "run_duckdb": run_duckdb_query, + "display": display_data, + "profile": profile_data, + "save_csv": save_table_as_csv, +} diff --git a/examples/pydantic_ai_eda/steps/eda_agent.py b/examples/pydantic_ai_eda/steps/eda_agent.py new file mode 100644 index 00000000000..e540b6d8f40 --- /dev/null +++ b/examples/pydantic_ai_eda/steps/eda_agent.py @@ -0,0 +1,877 @@ +"""EDA agent step using Pydantic AI for automated data analysis. + +This step implements the core EDA agent that uses Pydantic AI to perform +intelligent exploratory data analysis through SQL queries and structured reporting. +""" + +import logging +import time +from typing import Annotated, Any, Dict, List, Tuple + +import pandas as pd + +from zenml import log_metadata, step +from zenml.types import MarkdownString + +# Logfire for observability +try: + import logfire + + LOGFIRE_AVAILABLE = True +except ImportError: + LOGFIRE_AVAILABLE = False + +from models import AgentConfig, DataQualityFix, EDAReport +from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps + +logger = logging.getLogger(__name__) + +# (HTML visualization removed as per simplification decision) + + +# EDA analysis prompt template +EDA_SYSTEM_PROMPT = """You are an expert data analyst. Your task is to perform a focused exploratory data analysis (EDA) and produce a final EDAReport within 15 tool calls maximum. + +## CRITICAL: Follow this exact workflow: +1. **Initial Assessment** (2-3 tool calls): + - display('dataset') to see the data structure + - profile('dataset') for basic statistics + - run_duckdb('dataset', 'SELECT COUNT(*), COUNT(DISTINCT *) as unique_rows FROM dataset') for size/duplicates + +2. **Data Quality Analysis** (3-4 tool calls): + - Check missing data: SELECT column_name, COUNT(*) - COUNT(column_name) as nulls FROM (SELECT * FROM dataset LIMIT 1) CROSS JOIN (SELECT column_name FROM information_schema.columns WHERE table_name = 'dataset') + - Identify duplicates if any found in step 1 + - Check for obvious outliers in numeric columns (use percentiles) + +3. **Key Insights** (2-3 tool calls): + - Calculate correlations between numeric columns if >1 numeric column exists + - Analyze categorical distributions for top categories + - Identify most important patterns or issues + +4. **STOP and Generate Report** (1 tool call): + - Produce the final EDAReport with all required fields + - Do NOT continue exploring after generating the report + +## Available Tools: +- run_duckdb(dataset_ref, sql): Execute SQL against 'dataset' table (read-only) +- display(dataset_ref): Show first 10 rows +- profile(dataset_ref): Get statistical summary +- save_csv(dataset_ref, filename): Save query results + +## SQL Guidelines: +- Table name is 'dataset' +- Use efficient aggregations, avoid SELECT * +- Limit large result sets with LIMIT clause +- Focus on summary statistics, not raw data exploration + +## REQUIRED Output Format: +You MUST produce an EDAReport with: +- headline: Executive summary (1 sentence) +- key_findings: 3-5 critical discoveries +- risks: Data quality issues found +- fixes: Specific DataQualityFix objects for issues +- data_quality_score: 0-100 score based on: + * Missing data: 0-15%=good(30pts), 16-30%=fair(20pts), >30%=poor(10pts) + * Duplicates: 0-5%=good(25pts), 6-15%=fair(15pts), >15%=poor(5pts) + * Schema quality: All columns have data=good(25pts), some empty=fair(15pts) + * Consistency: Clean data=good(20pts), issues found=poor(10pts) +- markdown: Summary report for humans +- column_profiles: Per-column statistics from profiling +- correlation_insights: Key relationships found +- missing_data_analysis: Missing data summary + +## EFFICIENCY RULES: +- Maximum 15 tool calls total +- Stop analysis once you have enough information for the report +- Focus on critical issues, not exhaustive exploration +- Generate the final report as soon as you have sufficient insights +- Do NOT keep exploring after finding basic patterns""" + + +@step +def run_eda_agent( + dataset_df: pd.DataFrame, + dataset_metadata: Dict[str, Any], + agent_config: AgentConfig = None, +) -> Tuple[ + Annotated[MarkdownString, "eda_report_markdown"], + Annotated[Dict[str, Any], "eda_report_json"], + Annotated[List[Dict[str, str]], "sql_execution_log"], + Annotated[Dict[str, pd.DataFrame], "analysis_tables"], +]: + """Run Pydantic AI agent for EDA analysis. + + Executes an AI agent that performs comprehensive exploratory data analysis + on the provided dataset using SQL queries and statistical analysis. + + Args: + dataset_df: The dataset to analyze + dataset_metadata: Metadata about the dataset + agent_config: Configuration for agent behavior + + Returns: + Tuple of (report_markdown, report_json, sql_log, analysis_tables) + containing all artifacts generated during the EDA analysis + """ + start_time = time.time() + + # Configure Logfire with explicit token + if LOGFIRE_AVAILABLE: + try: + logfire.configure() + logfire.instrument_pydantic_ai() + logfire.info("EDA agent starting", dataset_shape=dataset_df.shape) + except Exception as e: + logger.warning(f"Failed to configure Logfire: {e}") + + if agent_config is None: + agent_config = AgentConfig() + + logger.info(f"Starting EDA analysis with {agent_config.model_name}") + logger.info(f"Dataset shape: {dataset_df.shape}") + + try: + # Initialize agent dependencies + deps = AnalystAgentDeps() + + # Store the main dataset as Out[1] + main_ref = deps.store_dataset(dataset_df, "Out[1]") + logger.info(f"Stored main dataset as {main_ref}") + + # Initialize and run the Pydantic AI agent + try: + import asyncio + + import nest_asyncio + from pydantic_ai import Agent + from pydantic_ai.models.anthropic import AnthropicModel + from pydantic_ai.models.openai import OpenAIModel + + # Select model based on configuration + if agent_config.model_name.startswith("gpt"): + model = OpenAIModel(agent_config.model_name) + elif agent_config.model_name.startswith("claude"): + model = AnthropicModel(agent_config.model_name) + else: + # Default to OpenAI + model = OpenAIModel("gpt-5") + logger.warning( + f"Unknown model {agent_config.model_name}, defaulting to gpt-5" + ) + + # Create the agent with tools and stricter limits + agent = Agent( + model=model, + system_prompt=EDA_SYSTEM_PROMPT, + deps_type=AnalystAgentDeps, + result_type=EDAReport, + ) + + # Set strict tool call limits + if hasattr(agent, "max_tool_calls"): + agent.max_tool_calls = 15 + elif hasattr(agent, "_max_tool_calls"): + agent._max_tool_calls = 15 + + # Register tools + for tool_func in AGENT_TOOLS.values(): + agent.tool(tool_func) + + # Prepare initial context with clear instructions + initial_prompt = f"""ANALYZE THIS DATASET EFFICIENTLY - Maximum 15 tool calls. + +Dataset Information: +- Shape: {dataset_df.shape} +- Columns: {list(dataset_df.columns)} +- Source: {dataset_metadata.get("source_type", "unknown")} + +WORKFLOW (stick to this exactly): +1. display('dataset') - see the data +2. profile('dataset') - get basic stats +3. run_duckdb('dataset', 'SELECT COUNT(*), COUNT(DISTINCT *) FROM dataset') - check duplicates +4. Check missing data patterns with SQL +5. Analyze key relationships/distributions (max 3 queries) +6. STOP and generate final EDAReport + +Do NOT over-analyze. Focus on critical issues only. Generate your report once you have the essential insights.""" + + # Run the agent with timeout + try: + + async def run_agent(): + result = await agent.run(initial_prompt, deps=deps) + return result + + # Run with timeout + result = asyncio.wait_for( + run_agent(), timeout=agent_config.timeout_seconds + ) + + # If we're in a sync context, run it + try: + result = asyncio.run(result) + except RuntimeError: + # Already in an event loop + nest_asyncio.apply() + result = asyncio.run(result) + + except ImportError: + # Fallback to mock analysis if pydantic-ai not available + logger.warning( + "Pydantic AI not available, running fallback analysis" + ) + result = _run_fallback_analysis(dataset_df, deps) + + except ImportError: + logger.warning( + "Pydantic AI not available, running fallback analysis" + ) + result = _run_fallback_analysis(dataset_df, deps) + + # Extract results + if hasattr(result, "data"): + eda_report = result.data + else: + eda_report = result + + processing_time_ms = int((time.time() - start_time) * 1000) + + # Prepare return artifacts + report_markdown = eda_report.markdown + + report_json = { + "headline": eda_report.headline, + "key_findings": eda_report.key_findings, + "risks": eda_report.risks, + "fixes": [fix.model_dump() for fix in eda_report.fixes], + "data_quality_score": eda_report.data_quality_score, + "column_profiles": eda_report.column_profiles, + "correlation_insights": eda_report.correlation_insights, + "missing_data_analysis": eda_report.missing_data_analysis, + "processing_time_ms": processing_time_ms, + "agent_metadata": { + "model": agent_config.model_name, + "tool_calls": len(deps.query_history), + "datasets_created": len(deps.datasets), + }, + } + + # Get SQL execution log + sql_log = deps.query_history.copy() + + # Filter analysis tables (keep only reasonably sized ones) + analysis_tables = {} + for ref, df in deps.datasets.items(): + if ( + ref != "Out[1]" and len(df) <= 10000 + ): # Don't return huge tables + analysis_tables[ref] = df + + logger.info(f"EDA analysis completed in {processing_time_ms}ms") + logger.info(f"Generated {len(analysis_tables)} analysis tables") + logger.info(f"Data quality score: {eda_report.data_quality_score}") + + # Log enhanced metadata for ZenML dashboard + _log_agent_metadata( + agent_config=agent_config, + eda_report=eda_report, + processing_time_ms=processing_time_ms, + sql_log=sql_log, + analysis_tables=analysis_tables, + result=result if "result" in locals() else None, + ) + + # Determine tool names dynamically when possible + tool_names: List[str] = list(AGENT_TOOLS.keys()) + try: + if "agent" in locals(): + tools_attr = getattr(agent, "tools", None) + if tools_attr is None: + tools_attr = getattr(agent, "_tools", None) + if tools_attr is not None: + if isinstance(tools_attr, dict): + tool_names = list(tools_attr.keys()) + else: + possible_keys = getattr(tools_attr, "keys", None) + if callable(possible_keys): + tool_names = list(possible_keys()) + except Exception: + # Best-effort fallback + pass + + # Convert markdown to MarkdownString for proper rendering + markdown_artifact = MarkdownString(report_markdown) + + return markdown_artifact, report_json, sql_log, analysis_tables + + except Exception as e: + logger.error(f"EDA agent failed: {e}") + + # Run fallback analysis to preserve any existing context and generate basic results + logger.info("Running fallback analysis after agent failure") + fallback_report = _run_fallback_analysis(dataset_df, deps) + processing_time_ms = int((time.time() - start_time) * 1000) + + # Prepare return artifacts with fallback data + report_markdown = ( + fallback_report.markdown + + f"\n\n**Note:** Analysis completed in fallback mode after agent error: {str(e)}" + ) + + report_json = { + "headline": f"EDA analysis completed (fallback mode after error: {str(e)})", + "key_findings": fallback_report.key_findings, + "risks": fallback_report.risks + + [f"Original agent failed: {str(e)}"], + "fixes": [fix.model_dump() for fix in fallback_report.fixes], + "data_quality_score": fallback_report.data_quality_score, + "column_profiles": fallback_report.column_profiles, + "correlation_insights": fallback_report.correlation_insights, + "missing_data_analysis": fallback_report.missing_data_analysis, + "processing_time_ms": processing_time_ms, + "agent_metadata": { + "model": "fallback_after_error", + "tool_calls": len(deps.query_history), + "datasets_created": len(deps.datasets), + "error": str(e), + }, + } + + # Get SQL execution log (preserving any queries that were executed before failure) + sql_log = deps.query_history.copy() + + # Filter analysis tables (preserving any tables created before failure) + analysis_tables = {} + for ref, df in deps.datasets.items(): + if ref != "Out[1]" and ref != "main_dataset" and len(df) <= 10000: + analysis_tables[ref] = df + + # Convert markdown to MarkdownString + error_markdown_artifact = MarkdownString(report_markdown) + + logger.info( + f"Fallback analysis preserved {len(sql_log)} SQL queries and {len(analysis_tables)} analysis tables" + ) + + # Log error metadata + log_metadata( + { + "ai_agent_execution": { + "model_name": agent_config.model_name + if agent_config + else "unknown", + "processing_time_ms": processing_time_ms, + "success": False, + "error_message": str(e), + "tool_calls_made": len(sql_log), + "datasets_created": len(analysis_tables), + "fallback_mode": True, + }, + "data_quality_assessment": { + "quality_score": fallback_report.data_quality_score, + "issues_found": len(fallback_report.fixes), + "risk_count": len(fallback_report.risks), + "key_findings_count": len(fallback_report.key_findings), + }, + } + ) + + return ( + error_markdown_artifact, + report_json, + sql_log, + analysis_tables, + ) + + +def _run_fallback_analysis( + dataset_df: pd.DataFrame, deps: AnalystAgentDeps +) -> EDAReport: + """Fallback analysis when Pydantic AI is not available. + + Performs basic statistical analysis and generates a simple report + using pandas operations instead of AI-driven analysis. Also simulates + some SQL queries to populate logs and analysis tables. + """ + logger.info("Running fallback EDA analysis") + + # Store the main dataset + deps.store_dataset(dataset_df, "main_dataset") + + # Run some basic SQL-like analysis to populate logs and tables + _run_fallback_sql_analysis(dataset_df, deps) + + # Basic statistics + numeric_cols = dataset_df.select_dtypes( + include=["number"] + ).columns.tolist() + categorical_cols = dataset_df.select_dtypes( + include=["object", "category"] + ).columns.tolist() + + # Calculate missing data + missing_counts = dataset_df.isnull().sum() + missing_pct = (missing_counts / len(dataset_df) * 100).round(2) + + # Calculate quality score + quality_factors = [] + + # Missing data factor (0-40 points) + overall_missing_pct = ( + dataset_df.isnull().sum().sum() + / (len(dataset_df) * len(dataset_df.columns)) + * 100 + ) + if overall_missing_pct <= 5: + quality_factors.append(40) + elif overall_missing_pct <= 15: + quality_factors.append(30) + elif overall_missing_pct <= 30: + quality_factors.append(20) + else: + quality_factors.append(10) + + # Duplicate factor (0-30 points) + duplicate_count = dataset_df.duplicated().sum() + duplicate_pct = duplicate_count / len(dataset_df) * 100 + if duplicate_pct <= 1: + quality_factors.append(30) + elif duplicate_pct <= 5: + quality_factors.append(25) + elif duplicate_pct <= 15: + quality_factors.append(15) + else: + quality_factors.append(5) + + # Schema completeness factor (0-30 points) + if len(dataset_df.columns) > 0: + non_empty_cols = (dataset_df.notna().any()).sum() + schema_completeness = non_empty_cols / len(dataset_df.columns) * 30 + quality_factors.append(int(schema_completeness)) + else: + quality_factors.append(0) + + data_quality_score = sum(quality_factors) + + # Generate key findings + key_findings = [] + key_findings.append( + f"Dataset contains {len(dataset_df):,} rows and {len(dataset_df.columns)} columns" + ) + + if numeric_cols: + key_findings.append( + f"Found {len(numeric_cols)} numeric columns for quantitative analysis" + ) + + if categorical_cols: + key_findings.append( + f"Found {len(categorical_cols)} categorical columns for segmentation analysis" + ) + + if overall_missing_pct > 10: + key_findings.append( + f"Missing data is {overall_missing_pct:.1f}% overall, requiring attention" + ) + + if duplicate_count > 0: + key_findings.append( + f"Found {duplicate_count:,} duplicate rows ({duplicate_pct:.1f}%)" + ) + + # Generate risks + risks = [] + if overall_missing_pct > 20: + risks.append( + "High percentage of missing data may impact analysis quality" + ) + + if duplicate_pct > 10: + risks.append( + "Significant duplicate data may skew statistical analysis" + ) + + if len(numeric_cols) == 0: + risks.append("No numeric columns found for quantitative analysis") + + # Generate fixes + fixes = [] + + high_missing_cols = missing_pct[missing_pct > 30].index.tolist() + if high_missing_cols: + fixes.append( + DataQualityFix( + title="Address high missing data in key columns", + rationale=f"Columns {high_missing_cols} have >30% missing values", + severity="high", + code_snippet="df.dropna(subset=['high_missing_col']) or df.fillna(method='forward')", + affected_columns=high_missing_cols, + estimated_impact=0.3, + ) + ) + + if duplicate_count > 0: + fixes.append( + DataQualityFix( + title="Remove duplicate records", + rationale=f"Found {duplicate_count:,} duplicate rows affecting data integrity", + severity="medium" if duplicate_pct < 10 else "high", + code_snippet="df.drop_duplicates(inplace=True)", + affected_columns=list(dataset_df.columns), + estimated_impact=duplicate_pct / 100, + ) + ) + + # Column profiles + column_profiles = {} + for col in dataset_df.columns: + profile = { + "dtype": str(dataset_df[col].dtype), + "null_count": int(missing_counts[col]), + "null_percentage": float(missing_pct[col]), + "unique_count": int(dataset_df[col].nunique()), + } + + if col in numeric_cols and dataset_df[col].notna().sum() > 0: + profile.update( + { + "mean": float(dataset_df[col].mean()), + "std": float(dataset_df[col].std()), + "min": float(dataset_df[col].min()), + "max": float(dataset_df[col].max()), + "median": float(dataset_df[col].median()), + } + ) + elif col in categorical_cols and dataset_df[col].notna().sum() > 0: + value_counts = dataset_df[col].value_counts().head(5) + profile["top_values"] = value_counts.to_dict() + + column_profiles[col] = profile + + # Correlation insights + correlation_insights = [] + if len(numeric_cols) > 1: + corr_matrix = dataset_df[numeric_cols].corr() + high_corrs = [] + for i, col1 in enumerate(numeric_cols): + for j, col2 in enumerate(numeric_cols[i + 1 :], i + 1): + corr_val = corr_matrix.iloc[i, j] + if abs(corr_val) > 0.7: + high_corrs.append((col1, col2, corr_val)) + + if high_corrs: + correlation_insights.append( + f"Found {len(high_corrs)} high correlations (>0.7)" + ) + for col1, col2, corr in high_corrs[:3]: # Show top 3 + correlation_insights.append( + f"{col1} and {col2} are strongly correlated ({corr:.3f})" + ) + else: + correlation_insights.append( + "No strong correlations (>0.7) detected between numeric variables" + ) + + # Missing data analysis + missing_data_analysis = { + "total_missing_cells": int(missing_counts.sum()), + "missing_percentage": float(overall_missing_pct), + "columns_with_missing": missing_counts[missing_counts > 0].to_dict(), + "completely_missing_columns": missing_counts[ + missing_counts == len(dataset_df) + ].index.tolist(), + } + + # Generate markdown report + markdown_report = f"""# EDA Report (Fallback Analysis) + +## Executive Summary +{key_findings[0] if key_findings else "Basic dataset analysis completed"} + +**Data Quality Score: {data_quality_score}/100** + +## Key Findings +{chr(10).join(f"- {finding}" for finding in key_findings)} + +## Data Quality Issues +{chr(10).join(f"- {risk}" for risk in risks) if risks else "No major quality issues detected"} + +## Column Overview +| Column | Type | Missing | Unique | +|--------|------|---------|--------| +{chr(10).join(f"| {col} | {profile['dtype']} | {profile['null_percentage']:.1f}% | {profile['unique_count']} |" for col, profile in column_profiles.items())} + +## Recommendations +{chr(10).join(f"- {fix.title}: {fix.rationale}" for fix in fixes) if fixes else "- Dataset appears to be in good condition"} + +*Report generated by ZenML EDA Pipeline (fallback mode)* +""" + + return EDAReport( + headline=key_findings[0] + if key_findings + else "Basic EDA analysis completed", + key_findings=key_findings, + risks=risks, + fixes=fixes, + data_quality_score=data_quality_score, + markdown=markdown_report, + column_profiles=column_profiles, + correlation_insights=correlation_insights, + missing_data_analysis=missing_data_analysis, + ) + + +def _run_fallback_sql_analysis( + dataset_df: pd.DataFrame, deps: AnalystAgentDeps +): + """Run basic SQL-like analysis to populate logs and analysis tables for fallback mode.""" + import time + + import duckdb + + try: + # Create DuckDB connection + conn = duckdb.connect(":memory:") + conn.register("dataset", dataset_df) + + # Run some basic SQL queries to simulate what the AI agent would do + fallback_queries = [ + ("SELECT COUNT(*) as row_count FROM dataset", "basic_stats"), + ( + "SELECT COUNT(*) as col_count FROM (SELECT * FROM dataset LIMIT 1)", + "column_count", + ), + ] + + # Add column-specific queries + for col in dataset_df.columns[:5]: # Limit to first 5 columns + # Escape column names with spaces + escaped_col = f'"{col}"' if " " in col else col + + # Count nulls + fallback_queries.append( + ( + f"SELECT COUNT(*) - COUNT({escaped_col}) as null_count FROM dataset", + f"nulls_{col.replace(' ', '_')}", + ) + ) + + # Get unique counts for non-numeric columns + if dataset_df[col].dtype == "object" or dataset_df[ + col + ].dtype.name.startswith("str"): + fallback_queries.append( + ( + f"SELECT COUNT(DISTINCT {escaped_col}) as unique_count FROM dataset WHERE {escaped_col} IS NOT NULL", + f"unique_{col.replace(' ', '_')}", + ) + ) + else: + # Basic stats for numeric columns + fallback_queries.append( + ( + f"SELECT AVG({escaped_col}) as avg_val, MIN({escaped_col}) as min_val, MAX({escaped_col}) as max_val FROM dataset WHERE {escaped_col} IS NOT NULL", + f"stats_{col.replace(' ', '_')}", + ) + ) + + # Execute queries and log them + for sql, description in fallback_queries: + try: + start_time = time.time() + result_df = conn.execute(sql).fetchdf() + execution_time = (time.time() - start_time) * 1000 + + # Store result table + result_ref = deps.store_dataset( + result_df, f"fallback_{description}" + ) + + # Log the query + deps.log_query( + sql=sql, + result_ref=result_ref, + rows_returned=len(result_df), + execution_time_ms=execution_time, + ) + + except Exception as e: + logger.warning(f"Fallback query failed: {sql} - {e}") + + conn.close() + logger.info( + f"Fallback analysis executed {len(fallback_queries)} SQL queries" + ) + + except Exception as e: + logger.warning(f"Could not run fallback SQL analysis: {e}") + + +def _log_agent_metadata( + agent_config: AgentConfig, + eda_report: EDAReport, + processing_time_ms: int, + sql_log: List[Dict[str, str]], + analysis_tables: Dict[str, pd.DataFrame], + result: Any = None, +) -> None: + """Log enhanced metadata about AI agent execution for ZenML dashboard.""" + + # Calculate token usage if available + token_usage = {} + cost_estimate = None + if result and hasattr(result, "usage"): + token_usage = { + "input_tokens": getattr(result.usage, "prompt_tokens", 0), + "output_tokens": getattr(result.usage, "completion_tokens", 0), + "total_tokens": getattr(result.usage, "total_tokens", 0), + } + cost_estimate = _estimate_cost(result.usage, agent_config.model_name) + + # Calculate SQL execution metrics + sql_metrics = {} + if sql_log: + execution_times = [q.get("execution_time_ms", 0) for q in sql_log] + sql_metrics = { + "total_queries": len(sql_log), + "avg_execution_time_ms": sum(execution_times) + / len(execution_times) + if execution_times + else 0, + "max_execution_time_ms": max(execution_times) + if execution_times + else 0, + "total_rows_processed": sum( + q.get("rows_returned", 0) for q in sql_log + ), + "query_types": _analyze_query_types(sql_log), + } + + # Analyze data quality issues by severity + severity_breakdown = {"low": 0, "medium": 0, "high": 0, "critical": 0} + for fix in eda_report.fixes: + severity = fix.severity.lower() + if severity in severity_breakdown: + severity_breakdown[severity] += 1 + + # Log metadata to ZenML + log_metadata( + { + "ai_agent_execution": { + "model_name": agent_config.model_name, + "processing_time_ms": processing_time_ms, + "success": True, + "tool_calls_made": len(sql_log), + "datasets_created": len(analysis_tables), + }, + "token_usage": token_usage, + "cost_estimate_usd": cost_estimate, + "data_quality_assessment": { + "quality_score": eda_report.data_quality_score, + "issues_found": len(eda_report.fixes), + "severity_breakdown": severity_breakdown, + "risk_count": len(eda_report.risks), + "key_findings_count": len(eda_report.key_findings), + }, + "sql_execution_metrics": sql_metrics, + "analysis_summary": { + "headline": eda_report.headline, + "columns_analyzed": len(eda_report.column_profiles), + "correlation_insights": len(eda_report.correlation_insights), + "missing_data_pct": eda_report.missing_data_analysis.get( + "missing_percentage", 0 + ), + }, + } + ) + + +def _estimate_cost(usage, model_name: str) -> float: + """Estimate cost based on token usage and model pricing.""" + if not hasattr(usage, "total_tokens") or usage.total_tokens == 0: + return 0.0 + + # Rough pricing estimates (per 1M tokens) + pricing = { + "gpt-4": {"input": 30, "output": 60}, + "gpt-4o": {"input": 5, "output": 15}, + "gpt-4o-mini": {"input": 0.15, "output": 0.6}, + "claude-3-5-sonnet": {"input": 3, "output": 15}, + "claude-3-haiku": {"input": 0.25, "output": 1.25}, + } + + # Find matching pricing + model_pricing = None + for model_key, prices in pricing.items(): + if model_key in model_name.lower(): + model_pricing = prices + break + + if not model_pricing: + # Default rough estimate + return usage.total_tokens * 0.00001 + + input_cost = ( + getattr(usage, "prompt_tokens", 0) * model_pricing["input"] / 1_000_000 + ) + output_cost = ( + getattr(usage, "completion_tokens", 0) + * model_pricing["output"] + / 1_000_000 + ) + + return input_cost + output_cost + + +def _analyze_query_types(sql_log: List[Dict[str, str]]) -> Dict[str, int]: + """Analyze the types of SQL queries executed.""" + query_types = {} + + for query in sql_log: + sql = query.get("sql", "").upper().strip() + + # Determine query type + if sql.startswith("SELECT COUNT"): + query_type = "count" + elif sql.startswith("SELECT DISTINCT") or "DISTINCT" in sql: + query_type = "distinct" + elif "GROUP BY" in sql: + query_type = "group_by" + elif "ORDER BY" in sql: + query_type = "ordered" + elif "WHERE" in sql: + query_type = "filtered" + elif sql.startswith("SELECT"): + query_type = "basic_select" + elif sql.startswith("WITH"): + query_type = "cte" + else: + query_type = "other" + + query_types[query_type] = query_types.get(query_type, 0) + 1 + + return query_types + + +def _create_error_report( + error_msg: str, dataset_df: pd.DataFrame +) -> Dict[str, Any]: + """Create an error report when analysis fails.""" + return { + "headline": "EDA analysis failed due to technical error", + "key_findings": [ + f"Analysis failed: {error_msg}", + f"Dataset has {len(dataset_df)} rows and {len(dataset_df.columns)} columns", + ], + "risks": [ + "Analysis could not be completed", + "Manual inspection required", + ], + "fixes": [], + "data_quality_score": 0.0, + "markdown": f"# EDA Analysis Failed\n\n**Error:** {error_msg}\n\nPlease check the dataset and configuration.", + "column_profiles": {}, + "correlation_insights": [], + "missing_data_analysis": {}, + } diff --git a/examples/pydantic_ai_eda/steps/ingest.py b/examples/pydantic_ai_eda/steps/ingest.py new file mode 100644 index 00000000000..77352b5d6db --- /dev/null +++ b/examples/pydantic_ai_eda/steps/ingest.py @@ -0,0 +1,374 @@ +"""Data ingestion step for EDA pipeline. + +Supports loading data from multiple sources including HuggingFace datasets, +local files, and data warehouse connections. +""" + +import hashlib +import logging +from typing import Annotated, Any, Dict, Tuple + +import pandas as pd +from models import DataSourceConfig + +from zenml import step + +logger = logging.getLogger(__name__) + + +@step +def ingest_data( + source_config: DataSourceConfig, +) -> Tuple[ + Annotated[pd.DataFrame, "dataset"], + Annotated[Dict[str, Any], "ingestion_metadata"], +]: + """Ingest data from configured source. + + Loads data from HuggingFace, local files, or warehouse based on + source configuration. Returns both the DataFrame and metadata. + + Args: + source_config: Configuration specifying data source and parameters + + Returns: + Tuple of (raw_df, metadata) where metadata contains schema info, + row count, and content hash for traceability + """ + logger.info( + f"Ingesting data from {source_config.source_type}: {source_config.source_path}" + ) + + # Load data based on source type + if source_config.source_type == "hf": + df = _load_from_huggingface(source_config) + elif source_config.source_type == "local": + df = _load_from_local(source_config) + elif source_config.source_type == "warehouse": + df = _load_from_warehouse(source_config) + else: + raise ValueError( + f"Unsupported source type: {source_config.source_type}" + ) + + # Apply sampling if configured + if source_config.sample_size and len(df) > source_config.sample_size: + if source_config.sampling_strategy == "random": + df = df.sample(n=source_config.sample_size, random_state=42) + elif source_config.sampling_strategy == "first_n": + df = df.head(source_config.sample_size) + elif ( + source_config.sampling_strategy == "stratified" + and source_config.target_column + ): + df = _stratified_sample( + df, source_config.target_column, source_config.sample_size + ) + else: + logger.warning( + f"Unknown sampling strategy: {source_config.sampling_strategy}" + ) + df = df.sample(n=source_config.sample_size, random_state=42) + + # Note: For basic datasets like iris, pandas DataFrames should work fine with ZenML + # The dtype warnings come from ZenML's internal serialization of DataFrame metadata, + # not the actual data. This is normal for pandas DataFrames in ZenML. + + # Generate metadata + metadata = { + "source_type": source_config.source_type, + "source_path": source_config.source_path, + "original_rows": len(df), + "columns": len(df.columns), + "column_names": df.columns.tolist(), + "dtypes": df.dtypes.to_dict(), + "target_column": source_config.target_column, + "content_hash": _compute_content_hash(df), + "memory_usage_mb": df.memory_usage(deep=True).sum() / 1024 / 1024, + } + + # Add target column statistics if specified + if ( + source_config.target_column + and source_config.target_column in df.columns + ): + metadata["target_value_counts"] = ( + df[source_config.target_column].value_counts().to_dict() + ) + metadata["target_null_count"] = ( + df[source_config.target_column].isnull().sum() + ) + + logger.info( + f"Loaded dataset with {len(df)} rows and {len(df.columns)} columns" + ) + + return df, metadata + + +def _load_from_huggingface(config: DataSourceConfig) -> pd.DataFrame: + """Load dataset from HuggingFace Hub.""" + try: + from datasets import load_dataset + + # Parse dataset path (may include config/split) + parts = config.source_path.split("/") + if len(parts) >= 2: + dataset_name = "/".join(parts[:2]) + subset = parts[2] if len(parts) > 2 else None + else: + dataset_name = config.source_path + subset = None + + # Load dataset + dataset = load_dataset(dataset_name, subset, split="train") + df = dataset.to_pandas() + + logger.info(f"Loaded HuggingFace dataset: {dataset_name}") + return df + + except ImportError: + raise ImportError( + "datasets library required for HuggingFace loading. Install with: pip install datasets" + ) + except Exception as e: + raise RuntimeError( + f"Failed to load HuggingFace dataset {config.source_path}: {e}" + ) + + +def _load_from_local(config: DataSourceConfig) -> pd.DataFrame: + """Load dataset from local file.""" + try: + file_path = config.source_path + + if file_path.endswith(".csv"): + df = pd.read_csv(file_path) + elif file_path.endswith(".parquet"): + df = pd.read_parquet(file_path) + elif file_path.endswith(".json"): + df = pd.read_json(file_path) + elif file_path.endswith((".xlsx", ".xls")): + df = pd.read_excel(file_path) + else: + # Try CSV as fallback + df = pd.read_csv(file_path) + + logger.info(f"Loaded local file: {file_path}") + return df + + except Exception as e: + raise RuntimeError( + f"Failed to load local file {config.source_path}: {e}" + ) + + +def _load_from_warehouse(config: DataSourceConfig) -> pd.DataFrame: + """Load dataset from data warehouse connection.""" + try: + warehouse_config = config.warehouse_config or {} + + if "connection_string" in warehouse_config: + # Generic SQL connection + import sqlalchemy + + engine = sqlalchemy.create_engine( + warehouse_config["connection_string"] + ) + df = pd.read_sql(config.source_path, engine) + elif "type" in warehouse_config: + # Specific warehouse type + warehouse_type = warehouse_config["type"].lower() + + if warehouse_type == "bigquery": + df = _load_from_bigquery(config.source_path, warehouse_config) + elif warehouse_type == "snowflake": + df = _load_from_snowflake(config.source_path, warehouse_config) + elif warehouse_type == "redshift": + df = _load_from_redshift(config.source_path, warehouse_config) + else: + raise ValueError( + f"Unsupported warehouse type: {warehouse_type}" + ) + else: + raise ValueError( + "Warehouse config must specify connection_string or type" + ) + + logger.info(f"Loaded from warehouse: {config.source_path}") + return df + + except Exception as e: + raise RuntimeError( + f"Failed to load from warehouse {config.source_path}: {e}" + ) + + +def _load_from_bigquery( + table_path: str, config: Dict[str, Any] +) -> pd.DataFrame: + """Load data from Google BigQuery.""" + try: + import pandas_gbq + + project_id = config.get("project_id") + credentials = config.get("credentials_path") + + if table_path.startswith("SELECT"): + # It's a query + df = pandas_gbq.read_gbq( + table_path, project_id=project_id, credentials=credentials + ) + else: + # It's a table reference + query = f"SELECT * FROM `{table_path}`" + df = pandas_gbq.read_gbq( + query, project_id=project_id, credentials=credentials + ) + + return df + + except ImportError: + raise ImportError( + "pandas-gbq required for BigQuery. Install with: pip install pandas-gbq" + ) + + +def _load_from_snowflake( + table_path: str, config: Dict[str, Any] +) -> pd.DataFrame: + """Load data from Snowflake.""" + try: + import snowflake.connector + from snowflake.connector.pandas_tools import pd_writer + + conn = snowflake.connector.connect( + user=config["user"], + password=config["password"], + account=config["account"], + warehouse=config.get("warehouse"), + database=config.get("database"), + schema=config.get("schema"), + ) + + if table_path.upper().startswith("SELECT"): + query = table_path + else: + query = f"SELECT * FROM {table_path}" + + df = pd.read_sql(query, conn) + conn.close() + + return df + + except ImportError: + raise ImportError( + "snowflake-connector-python required. Install with: pip install snowflake-connector-python" + ) + + +def _load_from_redshift( + table_path: str, config: Dict[str, Any] +) -> pd.DataFrame: + """Load data from Amazon Redshift.""" + try: + import psycopg2 + import sqlalchemy + + connection_string = f"postgresql://{config['user']}:{config['password']}@{config['host']}:{config.get('port', 5439)}/{config['database']}" + engine = sqlalchemy.create_engine(connection_string) + + if table_path.upper().startswith("SELECT"): + query = table_path + else: + query = f"SELECT * FROM {table_path}" + + df = pd.read_sql(query, engine) + + return df + + except ImportError: + raise ImportError( + "psycopg2 required for Redshift. Install with: pip install psycopg2-binary" + ) + + +def _stratified_sample( + df: pd.DataFrame, target_column: str, sample_size: int +) -> pd.DataFrame: + """Perform stratified sampling based on target column.""" + try: + # Calculate proportional sample sizes for each class + value_counts = df[target_column].value_counts() + proportions = value_counts / len(df) + + sampled_dfs = [] + remaining_samples = sample_size + + for value, proportion in proportions.items(): + if remaining_samples <= 0: + break + + # Calculate sample size for this class + class_sample_size = max(1, int(proportion * sample_size)) + class_sample_size = min( + class_sample_size, remaining_samples, value_counts[value] + ) + + # Sample from this class + class_df = df[df[target_column] == value].sample( + n=class_sample_size, random_state=42 + ) + sampled_dfs.append(class_df) + + remaining_samples -= class_sample_size + + # Combine all samples + result_df = pd.concat(sampled_dfs, ignore_index=True) + + # Shuffle the final result + result_df = result_df.sample(frac=1, random_state=42).reset_index( + drop=True + ) + + logger.info( + f"Stratified sampling: {len(result_df)} samples across {len(sampled_dfs)} classes" + ) + return result_df + + except Exception as e: + logger.warning(f"Stratified sampling failed, using random: {e}") + return df.sample(n=sample_size, random_state=42) + + +def _compute_content_hash(df: pd.DataFrame) -> str: + """Compute hash of DataFrame content for change detection.""" + try: + # Create a string representation of the dataframe structure and sample + content_parts = [ + f"shape:{df.shape}", + f"columns:{sorted(df.columns.tolist())}", + f"dtypes:{sorted(df.dtypes.astype(str).tolist())}", + ] + + # Add sample of data if not too large + if len(df) <= 1000: + content_parts.append(f"data:{df.to_string()}") + else: + # Use a sample and summary stats + sample_df = ( + df.sample(n=100, random_state=42) if len(df) > 100 else df + ) + content_parts.extend( + [ + f"sample:{sample_df.to_string()}", + f"describe:{df.describe().to_string()}", + ] + ) + + content_str = "|".join(content_parts) + return hashlib.md5(content_str.encode()).hexdigest() + + except Exception as e: + logger.warning(f"Failed to compute content hash: {e}") + return f"error_{hashlib.md5(str(df.shape).encode()).hexdigest()}" diff --git a/examples/pydantic_ai_eda/steps/quality_gate.py b/examples/pydantic_ai_eda/steps/quality_gate.py new file mode 100644 index 00000000000..5f6a6729a47 --- /dev/null +++ b/examples/pydantic_ai_eda/steps/quality_gate.py @@ -0,0 +1,325 @@ +"""Quality gate step for data quality assessment and pipeline routing. + +Evaluates EDA results against quality thresholds to make pass/fail decisions +for downstream processing or model training workflows. +""" + +import logging +from typing import Annotated, Any, Dict + +from models import QualityGateDecision + +from zenml import step + +logger = logging.getLogger(__name__) + + +@step +def evaluate_quality_gate( + report_json: Dict[str, Any], + min_quality_score: float = 70.0, + block_on_high_severity: bool = True, + max_missing_data_pct: float = 30.0, + require_target_column: bool = False, + target_column: str = None, +) -> Annotated[QualityGateDecision, "quality_gate_decision"]: + """Evaluate data quality and make gate pass/fail decision. + + Analyzes the EDA report results against configurable quality thresholds + to determine if the data meets requirements for downstream processing. + + Args: + report_json: EDA report JSON containing quality metrics and findings + min_quality_score: Minimum data quality score required (0-100) + block_on_high_severity: Whether to fail on high-severity issues + max_missing_data_pct: Maximum allowable missing data percentage + require_target_column: Whether a target column is required + target_column: Expected target column name (if required) + + Returns: + QualityGateDecision with pass/fail result and recommendations + """ + logger.info("Evaluating data quality gate") + + try: + # Extract key metrics from report + quality_score = report_json.get("data_quality_score", 0.0) + fixes = report_json.get("fixes", []) + missing_data_analysis = report_json.get("missing_data_analysis", {}) + column_profiles = report_json.get("column_profiles", {}) + + # Initialize decision components + blocking_issues = [] + recommendations = [] + decision_factors = [] + + # Check 1: Overall quality score + if quality_score < min_quality_score: + blocking_issues.append( + f"Data quality score ({quality_score:.1f}) below minimum threshold ({min_quality_score})" + ) + decision_factors.append( + f"Quality score: {quality_score:.1f}/{min_quality_score} โŒ" + ) + else: + decision_factors.append( + f"Quality score: {quality_score:.1f}/{min_quality_score} โœ…" + ) + + # Check 2: High severity issues + high_severity_fixes = [ + fix for fix in fixes if fix.get("severity") == "high" + ] + if block_on_high_severity and high_severity_fixes: + high_severity_titles = [ + fix.get("title", "Unknown issue") + for fix in high_severity_fixes + ] + blocking_issues.append( + f"High severity issues found: {', '.join(high_severity_titles)}" + ) + decision_factors.append( + f"High severity issues: {len(high_severity_fixes)} โŒ" + ) + else: + decision_factors.append( + f"High severity issues: {len(high_severity_fixes)} โœ…" + ) + + # Check 3: Missing data threshold + overall_missing_pct = missing_data_analysis.get( + "missing_percentage", 0.0 + ) + if overall_missing_pct > max_missing_data_pct: + blocking_issues.append( + f"Missing data percentage ({overall_missing_pct:.1f}%) exceeds threshold ({max_missing_data_pct}%)" + ) + decision_factors.append( + f"Missing data: {overall_missing_pct:.1f}%/{max_missing_data_pct}% โŒ" + ) + else: + decision_factors.append( + f"Missing data: {overall_missing_pct:.1f}%/{max_missing_data_pct}% โœ…" + ) + + # Check 4: Target column requirement + if require_target_column and target_column: + if target_column not in column_profiles: + blocking_issues.append( + f"Required target column '{target_column}' not found in dataset" + ) + decision_factors.append( + f"Target column '{target_column}': Missing โŒ" + ) + else: + # Check target column quality + target_profile = column_profiles[target_column] + target_missing_pct = target_profile.get("null_percentage", 0.0) + + if ( + target_missing_pct > 50 + ): # Target column shouldn't be mostly empty + blocking_issues.append( + f"Target column '{target_column}' has {target_missing_pct:.1f}% missing values" + ) + decision_factors.append( + f"Target column '{target_column}': {target_missing_pct:.1f}% missing โŒ" + ) + else: + decision_factors.append( + f"Target column '{target_column}': Present โœ…" + ) + + # Generate recommendations based on findings + recommendations = _generate_recommendations( + quality_score, + fixes, + missing_data_analysis, + column_profiles, + min_quality_score, + ) + + # Make final decision + passed = len(blocking_issues) == 0 + + if passed: + decision_reason = ( + f"All quality checks passed. {', '.join(decision_factors)}" + ) + logger.info("โœ… Quality gate PASSED") + else: + decision_reason = f"Quality gate failed. Blocking issues: {'; '.join(blocking_issues)}" + logger.warning("โŒ Quality gate FAILED") + + # Log decision details + logger.info(f"Quality score: {quality_score:.1f}") + logger.info(f"Missing data: {overall_missing_pct:.1f}%") + logger.info(f"High severity issues: {len(high_severity_fixes)}") + + return QualityGateDecision( + passed=passed, + quality_score=quality_score, + decision_reason=decision_reason, + blocking_issues=blocking_issues, + recommendations=recommendations, + metadata={ + "decision_factors": decision_factors, + "thresholds": { + "min_quality_score": min_quality_score, + "max_missing_data_pct": max_missing_data_pct, + "block_on_high_severity": block_on_high_severity, + "require_target_column": require_target_column, + }, + "metrics": { + "overall_missing_pct": overall_missing_pct, + "high_severity_count": len(high_severity_fixes), + "total_fixes": len(fixes), + "column_count": len(column_profiles), + }, + }, + ) + + except Exception as e: + logger.error(f"Quality gate evaluation failed: {e}") + + # Return failure decision + return QualityGateDecision( + passed=False, + quality_score=0.0, + decision_reason=f"Quality gate evaluation failed: {str(e)}", + blocking_issues=[f"Technical error during evaluation: {str(e)}"], + recommendations=[ + "Review EDA report format and quality gate configuration" + ], + metadata={"error": str(e)}, + ) + + +def _generate_recommendations( + quality_score: float, + fixes: list, + missing_data_analysis: dict, + column_profiles: dict, + min_threshold: float, +) -> list: + """Generate actionable recommendations based on quality assessment.""" + recommendations = [] + + # Score-based recommendations + if quality_score < min_threshold: + score_gap = min_threshold - quality_score + if score_gap > 30: + recommendations.append( + "Consider data cleaning or alternative data sources due to significant quality issues" + ) + elif score_gap > 15: + recommendations.append( + "Address high-priority data quality issues before proceeding" + ) + else: + recommendations.append( + "Minor quality improvements recommended but data is usable" + ) + + # Fix-based recommendations + critical_fixes = [ + fix for fix in fixes if fix.get("severity") in ["high", "critical"] + ] + if critical_fixes: + recommendations.append( + f"Implement {len(critical_fixes)} critical data quality fixes" + ) + + # Add specific recommendations for common issues + for fix in critical_fixes[:3]: # Top 3 critical fixes + if "missing" in fix.get("title", "").lower(): + recommendations.append( + "Consider imputation strategies for missing data" + ) + elif "duplicate" in fix.get("title", "").lower(): + recommendations.append( + "Remove or consolidate duplicate records" + ) + elif "outlier" in fix.get("title", "").lower(): + recommendations.append("Investigate and handle outlier values") + + # Missing data recommendations + overall_missing_pct = missing_data_analysis.get("missing_percentage", 0.0) + if overall_missing_pct > 20: + recommendations.append( + "High missing data detected - consider data imputation or collection improvements" + ) + elif overall_missing_pct > 10: + recommendations.append( + "Moderate missing data - review imputation strategies" + ) + + # Column-specific recommendations + if column_profiles: + high_missing_cols = [ + col + for col, profile in column_profiles.items() + if profile.get("null_percentage", 0) > 50 + ] + + if high_missing_cols: + if len(high_missing_cols) == 1: + recommendations.append( + f"Column '{high_missing_cols[0]}' has excessive missing data - consider removal or targeted collection" + ) + else: + recommendations.append( + f"{len(high_missing_cols)} columns have >50% missing data - review data collection process" + ) + + # Pipeline-specific recommendations + if quality_score >= min_threshold and len(critical_fixes) == 0: + recommendations.append( + "Data quality is acceptable for downstream processing" + ) + recommendations.append( + "Consider implementing monitoring for quality regression" + ) + elif quality_score >= min_threshold * 0.8: # Close to passing + recommendations.append( + "Data quality is borderline - implement fixes and re-evaluate" + ) + recommendations.append( + "Consider A/B testing with and without quality improvements" + ) + else: + recommendations.append( + "Significant data quality issues require attention before production use" + ) + recommendations.append( + "Consider data pipeline improvements or alternative data sources" + ) + + return recommendations + + +@step +def route_based_on_quality( + decision: QualityGateDecision, + on_pass_message: str = "Proceed to downstream processing", + on_fail_message: str = "Data quality insufficient - halt pipeline", +) -> Annotated[str, "routing_decision"]: + """Route pipeline execution based on quality gate decision. + + Simple routing step that can be used to conditionally execute + downstream steps based on quality gate results. + + Args: + decision: Quality gate decision result + on_pass_message: Message when quality gate passes + on_fail_message: Message when quality gate fails + + Returns: + Routing message indicating next steps + """ + if decision.passed: + logger.info(f"๐Ÿš€ {on_pass_message}") + return on_pass_message + else: + logger.warning(f"๐Ÿ›‘ {on_fail_message}") + return on_fail_message From 44e81902e8622decbbe81b1a8b22adf054bdd7d3 Mon Sep 17 00:00:00 2001 From: Hamza Tahir Date: Thu, 21 Aug 2025 22:42:58 +0200 Subject: [PATCH 02/14] easier example --- examples/pydantic_ai_eda/models.py | 68 +- .../pydantic_ai_eda/pipelines/eda_pipeline.py | 14 +- examples/pydantic_ai_eda/requirements.txt | 6 +- examples/pydantic_ai_eda/run.py | 290 +----- examples/pydantic_ai_eda/steps/__init__.py | 6 +- examples/pydantic_ai_eda/steps/agent_tools.py | 481 ++-------- examples/pydantic_ai_eda/steps/eda_agent.py | 901 ++---------------- .../pydantic_ai_eda/steps/quality_gate.py | 77 +- 8 files changed, 285 insertions(+), 1558 deletions(-) diff --git a/examples/pydantic_ai_eda/models.py b/examples/pydantic_ai_eda/models.py index 30253e03480..e17f0087d4f 100644 --- a/examples/pydantic_ai_eda/models.py +++ b/examples/pydantic_ai_eda/models.py @@ -1,8 +1,4 @@ -"""Data models for EDA pipeline with Pydantic AI. - -This module defines Pydantic models used throughout the EDA pipeline -for request/response handling, analysis results, and evaluation. -""" +"""Simple data models for Pydantic AI EDA pipeline.""" from typing import Any, Dict, List, Optional @@ -10,18 +6,7 @@ class DataSourceConfig(BaseModel): - """Configuration for data source ingestion. - - Supports HuggingFace datasets, local files, and warehouse connections. - - Attributes: - source_type: Type of data source (hf, local, warehouse) - source_path: Path/identifier for the data source - target_column: Optional target column for analysis focus - sampling_strategy: How to sample the data (random, stratified, first_n) - sample_size: Number of rows to sample (None for all data) - warehouse_config: Additional config for warehouse connections - """ + """Simple data source configuration.""" source_type: str = Field( description="Data source type: hf, local, or warehouse" @@ -32,13 +17,9 @@ class DataSourceConfig(BaseModel): target_column: Optional[str] = Field( None, description="Optional target column name" ) - sampling_strategy: str = Field("random", description="Sampling strategy") sample_size: Optional[int] = Field( None, description="Number of rows to sample" ) - warehouse_config: Optional[Dict[str, Any]] = Field( - None, description="Warehouse connection config" - ) class DataQualityFix(BaseModel): @@ -94,26 +75,28 @@ class EDAReport(BaseModel): headline: str = Field(description="Executive summary of key findings") key_findings: List[str] = Field( - description="Important discoveries about the data" + default_factory=list, + description="Important discoveries about the data", ) risks: List[str] = Field( - description="Potential risks identified in the data" + default_factory=list, + description="Potential risks identified in the data", ) fixes: List[DataQualityFix] = Field( - description="Recommended data quality fixes" + default_factory=list, description="Recommended data quality fixes" ) data_quality_score: float = Field( description="Overall quality score (0-100)" ) markdown: str = Field(description="Full markdown report") - column_profiles: Dict[str, Dict[str, Any]] = Field( - description="Statistical profiles per column" + column_profiles: Optional[Dict[str, Dict[str, Any]]] = Field( + default_factory=dict, description="Statistical profiles per column" ) correlation_insights: List[str] = Field( - description="Key correlation findings" + default_factory=list, description="Key correlation findings" ) - missing_data_analysis: Dict[str, Any] = Field( - description="Missing data patterns" + missing_data_analysis: Optional[Dict[str, Any]] = Field( + default_factory=dict, description="Missing data patterns" ) @@ -145,27 +128,10 @@ class QualityGateDecision(BaseModel): class AgentConfig(BaseModel): - """Configuration for Pydantic AI agent behavior. + """Simple configuration for Pydantic AI agent.""" - Controls how the EDA agent operates including model selection, - tool usage limits, and safety constraints. - - Attributes: - model_name: Name of the language model to use - max_tool_calls: Maximum number of tool calls allowed - sql_guard_enabled: Whether to enable SQL safety guards - preview_limit: Maximum rows to show in data previews - enable_plotting: Whether to enable chart/plot generation - timeout_seconds: Maximum execution time in seconds - """ - - model_name: str = Field("gpt-5", description="Language model to use") - max_tool_calls: int = Field(50, description="Maximum tool calls allowed") - sql_guard_enabled: bool = Field( - True, description="Enable SQL safety guards" - ) - preview_limit: int = Field(10, description="Max rows in data previews") - enable_plotting: bool = Field( - False, description="Enable plotting capabilities" + model_name: str = Field("gpt-4o-mini", description="Language model to use") + max_tool_calls: int = Field(6, description="Maximum tool calls allowed") + timeout_seconds: int = Field( + 60, description="Max execution time in seconds" ) - timeout_seconds: int = Field(300, description="Max execution time") diff --git a/examples/pydantic_ai_eda/pipelines/eda_pipeline.py b/examples/pydantic_ai_eda/pipelines/eda_pipeline.py index 78487c5d27f..bf18c99954a 100644 --- a/examples/pydantic_ai_eda/pipelines/eda_pipeline.py +++ b/examples/pydantic_ai_eda/pipelines/eda_pipeline.py @@ -10,9 +10,8 @@ from models import AgentConfig, DataSourceConfig from steps import ( - evaluate_quality_gate, + evaluate_quality_gate_with_routing, ingest_data, - route_based_on_quality, run_eda_agent, ) @@ -61,8 +60,8 @@ def eda_pipeline( agent_config=agent_config, ) - # Step 3: Evaluate data quality gate - quality_decision = evaluate_quality_gate( + # Step 3: Evaluate data quality gate and get routing decision + quality_decision, routing_message = evaluate_quality_gate_with_routing( report_json=report_json, min_quality_score=min_quality_score, block_on_high_severity=block_on_high_severity, @@ -71,13 +70,6 @@ def eda_pipeline( target_column=source_config.target_column, ) - # Step 4: Route based on quality assessment - routing_message = route_based_on_quality( - decision=quality_decision, - on_pass_message="Data quality acceptable - ready for downstream processing", - on_fail_message="Data quality insufficient - requires remediation before use", - ) - # Log pipeline summary (note: artifacts are returned, actual values logged in steps) logger.info("Pipeline steps completed successfully") logger.info("Check step outputs for detailed analysis results") diff --git a/examples/pydantic_ai_eda/requirements.txt b/examples/pydantic_ai_eda/requirements.txt index 94380d9fcab..93b1bee7e3b 100644 --- a/examples/pydantic_ai_eda/requirements.txt +++ b/examples/pydantic_ai_eda/requirements.txt @@ -1,5 +1,5 @@ # ZenML (use existing installation) -# zenml +zenml # Core dependencies with version constraints for compatibility pandas>=2.0.0,<3.0.0 @@ -7,10 +7,10 @@ numpy>=1.24.0,<2.0.0 duckdb>=1.0.0,<2.0.0 # Pydantic compatibility - CRITICAL: Must be <2.10 for ZenML compatibility -pydantic>=2.8.0,<2.10.0 +pydantic>=2.0.0,<2.10.0 # AI/ML frameworks -pydantic-ai[logfire]>=0.7.0 +pydantic-ai[logfire]>=0.4.0 openai>=1.0.0,<2.0.0 anthropic>=0.30.0 # Alternative LLM provider diff --git a/examples/pydantic_ai_eda/run.py b/examples/pydantic_ai_eda/run.py index 67d7fbe840c..efec12111c3 100644 --- a/examples/pydantic_ai_eda/run.py +++ b/examples/pydantic_ai_eda/run.py @@ -1,298 +1,58 @@ #!/usr/bin/env python3 -"""Run the Pydantic AI EDA pipeline. +"""Simple Pydantic AI EDA pipeline runner.""" -This script provides multiple ways to run the EDA pipeline: -- With HuggingFace datasets (default) -- With local CSV files -- With different quality thresholds -- For testing and production scenarios - -Works with or without API keys (falls back to statistical analysis). -""" - -import argparse import os -import sys -from pathlib import Path -from typing import Optional from models import AgentConfig, DataSourceConfig from pipelines.eda_pipeline import eda_pipeline -def create_sample_dataset(): - """Create a sample iris dataset CSV for local testing.""" - try: - import pandas as pd - from sklearn.datasets import load_iris - - print("๐Ÿ“ Creating sample dataset...") - iris = load_iris() - df = pd.DataFrame(iris.data, columns=iris.feature_names) - df["target"] = iris.target - - df.to_csv("iris_sample.csv", index=False) - print(f"โœ… Created iris_sample.csv with {len(df)} rows") - return "iris_sample.csv" - except ImportError: - print("โŒ sklearn not available for sample dataset creation") - return None - +def main(): + """Run the EDA pipeline with simple configuration.""" + print("๐Ÿ” Pydantic AI EDA Pipeline") + print("=" * 30) -def check_api_keys(): - """Check for available API keys and return provider info.""" + # Check for API keys has_openai = bool(os.getenv("OPENAI_API_KEY")) has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY")) - if has_openai and has_anthropic: - print("๐Ÿค– Both OpenAI and Anthropic API keys detected") - return "both" - elif has_openai: - print("๐Ÿค– OpenAI API key detected - will use GPT models") - return "openai" - elif has_anthropic: - print("๐Ÿค– Anthropic API key detected - will use Claude models") - return "anthropic" - else: - print("โš ๏ธ No API keys found - will use statistical fallback analysis") - print( - " Set OPENAI_API_KEY or ANTHROPIC_API_KEY for full AI features" - ) - return None - + if not (has_openai or has_anthropic): + print("โŒ No API keys found!") + print("Set OPENAI_API_KEY or ANTHROPIC_API_KEY environment variable") + return -def run_pipeline( - source_type: str = "hf", - source_path: str = "scikit-learn/iris", - target_column: Optional[str] = "target", - min_quality_score: float = 70.0, - ai_provider: Optional[str] = None, - timeout: int = 300, - sample_size: Optional[int] = None, - verbose: bool = False, -): - """Run the EDA pipeline with specified configuration.""" + model_name = "gpt-4o-mini" if has_openai else "claude-3-haiku-20240307" + print(f"๐Ÿค– Using model: {model_name}") - # Configure data source + # Simple configuration source_config = DataSourceConfig( - source_type=source_type, - source_path=source_path, - target_column=target_column, - sample_size=sample_size, + source_type="hf", + source_path="scikit-learn/iris", + target_column="target", ) - # Configure AI agent based on available providers - if ai_provider == "anthropic": - model_name = "claude-4" - elif ai_provider == "openai": - model_name = "gpt-5" - else: - model_name = "gpt-5" # Default fallback - agent_config = AgentConfig( model_name=model_name, - max_tool_calls=15, # Reduced to prevent infinite loops - sql_guard_enabled=True, - preview_limit=10, - timeout_seconds=timeout, - temperature=0.1, + max_tool_calls=6, # Keep it snappy - just the essentials + timeout_seconds=60, # Quick analysis ) - print(f"๐Ÿ“Š Analyzing dataset: {source_config.source_path}") - if target_column: - print(f"๐ŸŽฏ Target column: {target_column}") - print(f"๐Ÿ“ Quality threshold: {min_quality_score}") + print(f"๐Ÿ“Š Analyzing: {source_config.source_path}") try: - print("๐Ÿš€ Running EDA pipeline") results = eda_pipeline.with_options(enable_cache=False)( source_config=source_config, agent_config=agent_config, - min_quality_score=min_quality_score, - block_on_high_severity=False, # Don't block for demo - max_missing_data_pct=30.0, - require_target_column=bool(target_column), - ) - - print("โœ… Pipeline completed successfully!") - - # Display results summary - print(f"\n{'=' * 60}") - print("๐Ÿ“‹ PIPELINE RESULTS") - print("=" * 60) - - # Show pipeline run info - if hasattr(results, "id"): - print(f"๐Ÿ“ Pipeline run ID: {results.id}") - if hasattr(results, "status"): - print(f"๐Ÿ“Š Status: {results.status}") - if hasattr(results, "name"): - print(f"๐Ÿท๏ธ Name: {results.name}") - - # Show artifact locations - print(f"\n๐Ÿ“ฆ Generated Artifacts:") - print(f" โ€ข EDA report (markdown): Available in ZenML dashboard") - print(f" โ€ข Analysis results (JSON): Available in ZenML dashboard") - print(f" โ€ข Quality assessment: Available in ZenML dashboard") - print(f" โ€ข SQL execution log: Available in ZenML dashboard") - print(f" โ€ข Analysis tables: Available in ZenML dashboard") - - # Show next steps - print(f"\n๐Ÿ“– Next Steps:") - print(f" โ€ข View full results in ZenML dashboard") - print( - f" โ€ข Access artifacts: results.steps['step_name'].outputs['artifact_name'].load()" + min_quality_score=70.0, ) - print(f" โ€ข Run with different parameters using command line options") - - if not ai_provider: - print(f"\n๐Ÿ”‘ For AI-powered analysis:") - print(f" โ€ข Set: export OPENAI_API_KEY='your-key'") - print(f" โ€ข Or: export ANTHROPIC_API_KEY='your-key'") - print(f" โ€ข Then re-run for intelligent insights") - + print("โœ… Pipeline completed! Check ZenML dashboard for results.") return results - except Exception as e: print(f"โŒ Pipeline failed: {e}") - - if verbose: - import traceback - - print(f"\n๐Ÿ” Full error traceback:") - traceback.print_exc() - - print(f"\n๐Ÿ”ง Troubleshooting:") - if source_type == "hf": - print(f" โ€ข Check internet connection for HuggingFace datasets") - print( - f" โ€ข Try local mode: python run.py --source-type local --create-sample" - ) - elif source_type == "local": - print(f" โ€ข Check file exists: {source_path}") - print(f" โ€ข Ensure file is valid CSV format") - - print(f" โ€ข Ensure ZenML is initialized: zenml init") - print(f" โ€ข Check ZenML stack: zenml stack list") - print(f" โ€ข Install dependencies: pip install -r requirements.txt") - - return None - - -def main(): - """Main CLI interface.""" - parser = argparse.ArgumentParser( - description="Run Pydantic AI EDA Pipeline", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Basic usage with HuggingFace dataset - python run.py - - # Use local CSV file - python run.py --source-type local --source-path data.csv --target target_col - - # Create and use sample dataset - python run.py --source-type local --create-sample - - # Custom quality threshold - python run.py --min-quality-score 80 - - # Custom dataset with specific settings - python run.py --source-path username/dataset --sample-size 1000 --timeout 600 - """, - ) - - # Data source options - parser.add_argument( - "--source-type", - choices=["hf", "local", "warehouse"], - default="hf", - help="Data source type (default: hf)", - ) - parser.add_argument( - "--source-path", - default="scikit-learn/iris", - help="Dataset path (HF dataset name or file path) (default: scikit-learn/iris)", - ) - parser.add_argument( - "--target-column", - default="target", - help="Target column name (default: target)", - ) - parser.add_argument( - "--sample-size", - type=int, - help="Limit dataset to N rows for faster processing", - ) - - # Pipeline options - parser.add_argument( - "--min-quality-score", - type=float, - default=70.0, - help="Minimum quality score threshold (default: 70.0)", - ) - parser.add_argument( - "--timeout", - type=int, - default=300, - help="AI agent timeout in seconds (default: 300)", - ) - - # Utility options - parser.add_argument( - "--create-sample", - action="store_true", - help="Create iris_sample.csv for local testing", - ) - parser.add_argument( - "--verbose", action="store_true", help="Show detailed error traces" - ) - - args = parser.parse_args() - - print("๐Ÿš€ Pydantic AI EDA Pipeline") - print("=" * 40) - - # Create sample dataset if requested - if args.create_sample: - sample_file = create_sample_dataset() - if sample_file and args.source_type == "local": - args.source_path = sample_file - print(f"๐Ÿ”„ Switched to created sample: {sample_file}") - - # Check API key availability - ai_provider = check_api_keys() - - # Validate local file exists - if args.source_type == "local": - if not Path(args.source_path).exists(): - print(f"โŒ Local file not found: {args.source_path}") - if not args.create_sample: - print( - f"๐Ÿ’ก Try: python run.py --source-type local --create-sample" - ) - sys.exit(1) - - # Run the pipeline - results = run_pipeline( - source_type=args.source_type, - source_path=args.source_path, - target_column=args.target_column, - min_quality_score=args.min_quality_score, - ai_provider=ai_provider, - timeout=args.timeout, - sample_size=args.sample_size, - verbose=args.verbose, - ) - - if results: - print(f"\n๐ŸŽ‰ Pipeline completed successfully!") - sys.exit(0) - else: - print(f"\n๐Ÿ’ฅ Pipeline failed!") - sys.exit(1) + print("\nTroubleshooting:") + print("- Check your API key is valid") + print("- Ensure ZenML is initialized: zenml init") + print("- Install requirements: pip install -r requirements.txt") if __name__ == "__main__": diff --git a/examples/pydantic_ai_eda/steps/__init__.py b/examples/pydantic_ai_eda/steps/__init__.py index 92f58660db5..f90b91ef614 100644 --- a/examples/pydantic_ai_eda/steps/__init__.py +++ b/examples/pydantic_ai_eda/steps/__init__.py @@ -11,11 +11,11 @@ from .eda_agent import run_eda_agent from .ingest import ingest_data -from .quality_gate import evaluate_quality_gate, route_based_on_quality +from .quality_gate import evaluate_quality_gate, evaluate_quality_gate_with_routing __all__ = [ "ingest_data", "run_eda_agent", - "evaluate_quality_gate", - "route_based_on_quality", + "evaluate_quality_gate", + "evaluate_quality_gate_with_routing", ] \ No newline at end of file diff --git a/examples/pydantic_ai_eda/steps/agent_tools.py b/examples/pydantic_ai_eda/steps/agent_tools.py index 34a8f8edd43..c04309031b2 100644 --- a/examples/pydantic_ai_eda/steps/agent_tools.py +++ b/examples/pydantic_ai_eda/steps/agent_tools.py @@ -1,434 +1,161 @@ -"""Pydantic AI agent tools for SQL-based EDA analysis. +"""Simple Pydantic AI agent tools for EDA analysis.""" -This module provides the tools and dependencies that the Pydantic AI agent -uses to perform exploratory data analysis through SQL queries. -""" - -import logging -import re -import tempfile -from pathlib import Path -from typing import Any, Dict, List, Optional +from dataclasses import dataclass, field +from typing import Any, Dict, List import duckdb import pandas as pd -from pydantic import BaseModel - -# Import RunContext for tool signatures -try: - from pydantic_ai import RunContext -except ImportError: - # Define a fallback if not available - class RunContext: - def __init__(self, deps): - self.deps = deps +from pydantic_ai import ModelRetry, RunContext -logger = logging.getLogger(__name__) +@dataclass +class AnalystAgentDeps: + """Simple storage for analysis results with Out[n] references.""" + output: dict[str, pd.DataFrame] = field(default_factory=dict) + query_history: List[Dict[str, Any]] = field(default_factory=list) -class AnalystAgentDeps(BaseModel): - """Dependencies for the EDA analyst agent. + def store(self, value: pd.DataFrame) -> str: + """Store the output and return reference like Out[1] for the LLM.""" + ref = f"Out[{len(self.output) + 1}]" + self.output[ref] = value + return ref - Manages the agent's state including datasets, query history, - and analysis outputs. Acts as the context/memory for the agent. - - Attributes: - datasets: Mapping of reference names to DataFrames - query_history: Log of executed SQL queries - output_counter: Counter for generating unique output references - """ + def get(self, ref: str) -> pd.DataFrame: + if ref not in self.output: + raise ModelRetry( + f"Error: {ref} is not a valid variable reference. Check the previous messages and try again." + ) + return self.output[ref] - datasets: Dict[str, pd.DataFrame] = {} - query_history: List[Dict[str, Any]] = [] - output_counter: int = 0 - - class Config: - arbitrary_types_allowed = True - - def store_dataset(self, df: pd.DataFrame, ref_name: str = None) -> str: - """Store a dataset and return its reference name. - - Args: - df: DataFrame to store - ref_name: Optional custom reference name - - Returns: - Reference name for the stored dataset - """ - if ref_name is None: - self.output_counter += 1 - ref_name = f"Out[{self.output_counter}]" - - self.datasets[ref_name] = df.copy() - logger.info(f"Stored dataset as {ref_name} with shape {df.shape}") - return ref_name - - def get_dataset(self, ref_name: str) -> Optional[pd.DataFrame]: - """Retrieve a dataset by reference name.""" - return self.datasets.get(ref_name) - - def list_datasets(self) -> List[str]: - """List all available dataset references.""" - return list(self.datasets.keys()) - - def log_query( - self, - sql: str, - result_ref: str, - rows_returned: int, - execution_time_ms: float, - ): - """Log an executed SQL query.""" - self.query_history.append( - { - "sql": sql, - "result_ref": result_ref, - "rows_returned": rows_returned, - "execution_time_ms": execution_time_ms, - "timestamp": pd.Timestamp.now().isoformat(), - } - ) +def run_sql(ctx: RunContext[AnalystAgentDeps], dataset: str, sql: str) -> str: + """Run SQL query on a DataFrame using DuckDB. -def run_duckdb_query( - ctx: RunContext[AnalystAgentDeps], dataset_ref: str, sql: str -) -> str: - """Execute SQL query against a dataset using DuckDB. - - This is the primary tool the agent uses for data analysis. - Provides read-only access with safety guards against harmful SQL. + Note: Use 'dataset' as the table name in your SQL queries. Args: - deps: Agent dependencies containing datasets - dataset_ref: Reference to the dataset to query + ctx: Pydantic AI agent RunContext + dataset: reference to the DataFrame (e.g., 'Out[1]') sql: SQL query to execute - - Returns: - Message describing the query execution and result location """ - import time - - start_time = time.time() - - deps = ctx.deps try: - # Get the dataset - df = deps.get_dataset(dataset_ref) - if df is None: - return f"Error: Dataset '{dataset_ref}' not found. Available: {deps.list_datasets()}" - - # Validate SQL query for safety - if not _is_safe_sql(sql): - return "Error: SQL query contains prohibited operations. Only SELECT queries are allowed." - - # Auto-inject LIMIT if missing and query might return large results - modified_sql = _maybe_add_limit(sql, df) - - # Execute query with DuckDB - conn = duckdb.connect(":memory:") - - # Register the dataset - conn.register("dataset", df) - - # Execute the query - try: - result = conn.execute(modified_sql).fetchdf() - execution_time = (time.time() - start_time) * 1000 - - # Store result - result_ref = deps.store_dataset(result) - - # Log the query - deps.log_query( - modified_sql, result_ref, len(result), execution_time - ) - - # Return success message - rows_msg = f"{len(result)} row(s)" if len(result) != 1 else "1 row" - return f"Executed SQL successfully. Result stored as {result_ref} with {rows_msg}. Use display('{result_ref}') to view." + data = ctx.deps.get(dataset) + result = duckdb.query_df( + df=data, virtual_table_name="dataset", sql_query=sql + ) + ref = ctx.deps.store(result.df()) - except Exception as e: - return f"SQL execution error: {str(e)}" - finally: - conn.close() + # Log the query for tracking + ctx.deps.query_history.append( + {"sql": sql, "result_ref": ref, "rows_returned": len(result.df())} + ) + return f"Query executed successfully. Result stored as `{ref}` ({len(result.df())} rows)." except Exception as e: - logger.error(f"Error in run_duckdb_query: {e}") - return f"Tool error: {str(e)}" + raise ModelRetry(f"SQL query failed: {str(e)}") -def display_data( - ctx: RunContext[AnalystAgentDeps], dataset_ref: str, max_rows: int = 10 +def display( + ctx: RunContext[AnalystAgentDeps], dataset: str, rows: int = 5 ) -> str: - """Display a preview of a dataset. - - Shows the first few rows of a dataset in a readable format. - Used by the agent to peek at data content. + """Display the first few rows of a dataset. Args: - deps: Agent dependencies - dataset_ref: Reference to dataset to display - max_rows: Maximum number of rows to show - - Returns: - String representation of the dataset preview + ctx: Pydantic AI agent RunContext + dataset: reference to the DataFrame + rows: number of rows to display (default: 5) """ - deps = ctx.deps try: - df = deps.get_dataset(dataset_ref) - if df is None: - return f"Error: Dataset '{dataset_ref}' not found. Available: {deps.list_datasets()}" - - if len(df) == 0: - return f"Dataset {dataset_ref} is empty (0 rows, {len(df.columns)} columns)." - - # Show basic info - preview_rows = min(max_rows, len(df)) - info = f"Dataset {dataset_ref}: {len(df)} rows ร— {len(df.columns)} columns\n\n" - - # Show column info - info += "Columns and types:\n" - for col, dtype in df.dtypes.items(): - null_count = df[col].isnull().sum() - null_pct = null_count / len(df) * 100 if len(df) > 0 else 0 - info += f" {col}: {dtype} ({null_count} nulls, {null_pct:.1f}%)\n" - - info += f"\nFirst {preview_rows} rows:\n" - info += df.head(preview_rows).to_string(max_cols=10, max_colwidth=50) - - if len(df) > preview_rows: - info += f"\n... ({len(df) - preview_rows} more rows)" - - return info - + data = ctx.deps.get(dataset) + return f"Dataset {dataset} preview:\n{data.head(rows).to_string()}" except Exception as e: return f"Display error: {str(e)}" -def profile_data(ctx: RunContext[AnalystAgentDeps], dataset_ref: str) -> str: - """Generate statistical profile of a dataset. - - Provides comprehensive statistics about the dataset including - distributions, correlations, and data quality metrics. +def describe(ctx: RunContext[AnalystAgentDeps], dataset: str) -> str: + """Get statistical summary of a dataset. Args: - ctx: Agent context with dependencies - dataset_ref: Reference to dataset to profile - - Returns: - Detailed statistical profile as string + ctx: Pydantic AI agent RunContext + dataset: reference to the DataFrame """ - deps = ctx.deps try: - df = deps.get_dataset(dataset_ref) - if df is None: - return f"Error: Dataset '{dataset_ref}' not found. Available: {deps.list_datasets()}" + data = ctx.deps.get(dataset) - if len(df) == 0: - return f"Cannot profile empty dataset {dataset_ref}." + # Basic info + info = [ + f"Dataset {dataset} comprehensive summary:", + f"Shape: {data.shape[0]:,} rows ร— {data.shape[1]} columns", + f"Memory usage: {data.memory_usage(deep=True).sum() / 1024**2:.2f} MB", + "", + ] - profile = f"Statistical Profile for {dataset_ref}\n" - profile += "=" * 50 + "\n\n" + # Data types and missing info + info.append("Column Information:") + for col in data.columns: + dtype = str(data[col].dtype) + null_count = data[col].isnull().sum() + null_pct = (null_count / len(data)) * 100 + unique_count = data[col].nunique() - # Basic info - profile += f"Shape: {df.shape[0]:,} rows ร— {df.shape[1]} columns\n" - profile += f"Memory usage: {df.memory_usage(deep=True).sum() / 1024**2:.2f} MB\n\n" - - # Numeric columns summary - numeric_cols = df.select_dtypes(include=["number"]).columns.tolist() - if numeric_cols: - profile += "Numeric Columns:\n" - desc = df[numeric_cols].describe() - profile += desc.to_string() + "\n\n" - - # Categorical columns summary - cat_cols = df.select_dtypes( - include=["object", "category"] - ).columns.tolist() - if cat_cols: - profile += "Categorical Columns:\n" - for col in cat_cols[ - :5 - ]: # Limit to first 5 to avoid too much output - unique_count = df[col].nunique() - null_count = df[col].isnull().sum() - most_common = df[col].value_counts().head(3) - - profile += f" {col}:\n" - profile += f" Unique values: {unique_count}\n" - profile += f" Null values: {null_count}\n" - profile += f" Most common: {dict(most_common)}\n\n" - - # Missing data analysis - missing_data = df.isnull().sum() - if missing_data.sum() > 0: - profile += "Missing Data:\n" - for col, missing in missing_data[missing_data > 0].items(): - pct = missing / len(df) * 100 - profile += f" {col}: {missing} ({pct:.1f}%)\n" - profile += "\n" - - # Correlation for numeric columns (if more than 1 numeric column) - if len(numeric_cols) > 1: - corr_matrix = df[numeric_cols].corr() - # Find high correlations - high_corrs = [] - for i, col1 in enumerate(numeric_cols): - for j, col2 in enumerate(numeric_cols[i + 1 :], i + 1): - corr_val = corr_matrix.iloc[i, j] - if abs(corr_val) > 0.7: # High correlation threshold - high_corrs.append((col1, col2, corr_val)) - - if high_corrs: - profile += "High Correlations (>0.7):\n" - for col1, col2, corr in high_corrs: - profile += f" {col1} โ†” {col2}: {corr:.3f}\n" - profile += "\n" - - return profile + info.append( + f" {col}: {dtype} | {null_count} nulls ({null_pct:.1f}%) | {unique_count} unique" + ) + + info.append("\nStatistical Summary:") + info.append(data.describe(include="all").to_string()) + return "\n".join(info) except Exception as e: - return f"Profile error: {str(e)}" + return f"Describe error: {str(e)}" -def save_table_as_csv( - ctx: RunContext[AnalystAgentDeps], dataset_ref: str, filename: str = None +def analyze_correlations( + ctx: RunContext[AnalystAgentDeps], dataset: str ) -> str: - """Save a dataset as CSV file. - - Optional tool for exporting analysis results. Saves to temporary - directory and returns the file path. + """Analyze correlations between numeric variables. Args: - ctx: Agent context with dependencies - dataset_ref: Reference to dataset to save - filename: Optional filename (auto-generated if not provided) - - Returns: - Path to saved CSV file or error message + ctx: Pydantic AI agent RunContext + dataset: reference to the DataFrame """ - deps = ctx.deps try: - df = deps.get_dataset(dataset_ref) - if df is None: - return f"Error: Dataset '{dataset_ref}' not found. Available: {deps.list_datasets()}" - - # Generate filename if not provided - if filename is None: - filename = f"eda_export_{dataset_ref.replace('[', '').replace(']', '')}.csv" - - # Ensure .csv extension - if not filename.endswith(".csv"): - filename += ".csv" - - # Save to temporary directory - temp_dir = Path(tempfile.gettempdir()) / "zenml_eda_exports" - temp_dir.mkdir(exist_ok=True) - - file_path = temp_dir / filename - df.to_csv(file_path, index=False) - - return f"Saved {len(df)} rows to: {file_path}" + data = ctx.deps.get(dataset) + numeric_data = data.select_dtypes(include=["number"]) - except Exception as e: - return f"Save error: {str(e)}" + if len(numeric_data.columns) < 2: + return "Need at least 2 numeric columns for correlation analysis." + corr_matrix = numeric_data.corr() -def _is_safe_sql(sql: str) -> bool: - """Check if SQL query is safe (read-only operations). + # Store correlation matrix + corr_ref = ctx.deps.store(corr_matrix) - Blocks potentially harmful SQL operations to ensure the agent - can only perform analysis, not modify data or system state. - """ - sql_upper = sql.upper().strip() - - # Allow only SELECT statements - if not sql_upper.startswith("SELECT") and not sql_upper.startswith("WITH"): - return False - - # Block dangerous keywords - prohibited_keywords = [ - "DROP", - "DELETE", - "INSERT", - "UPDATE", - "ALTER", - "CREATE", - "TRUNCATE", - "REPLACE", - "MERGE", - "EXEC", - "EXECUTE", - "ATTACH", - "DETACH", - "PRAGMA", - "COPY", - "IMPORT", - "EXPORT", - "LOAD", - "INSTALL", - "SET GLOBAL", - "SET PERSIST", - ] - - for keyword in prohibited_keywords: - # Use word boundaries to avoid false positives - pattern = r"\b" + re.escape(keyword) + r"\b" - if re.search(pattern, sql_upper): - return False - - return True - - -def _maybe_add_limit( - sql: str, df: pd.DataFrame, default_limit: int = 1000 -) -> str: - """Add LIMIT clause to queries that might return large results. + # Find strong correlations + strong_corrs = [] + for i, col1 in enumerate(numeric_data.columns): + for j, col2 in enumerate(numeric_data.columns[i + 1 :], i + 1): + corr_val = corr_matrix.iloc[i, j] + if abs(corr_val) > 0.7: + strong_corrs.append(f"{col1} โ†” {col2}: {corr_val:.3f}") - Prevents the agent from accidentally creating huge result sets - that could cause memory issues or slow performance. - """ - sql_upper = sql.upper().strip() - - # If LIMIT already present, don't modify - if "LIMIT" in sql_upper: - return sql - - # If dataset is small, no need to limit - if len(df) <= default_limit: - return sql - - # Check if query might return large results - # (GROUP BY, aggregation functions usually return smaller results) - has_aggregation = any( - keyword in sql_upper - for keyword in [ - "GROUP BY", - "COUNT(", - "SUM(", - "AVG(", - "MAX(", - "MIN(", - "DISTINCT", + result = [ + f"Correlation analysis for {len(numeric_data.columns)} numeric columns:", + f"Correlation matrix stored as {corr_ref}", + "", ] - ) - - if has_aggregation: - return sql - # Add LIMIT clause - modified_sql = sql.rstrip() - if modified_sql.endswith(";"): - modified_sql = modified_sql[:-1] + if strong_corrs: + result.append("Strong correlations (|r| > 0.7):") + result.extend(f" {corr}" for corr in strong_corrs) + else: + result.append("No strong correlations (|r| > 0.7) found.") - return f"{modified_sql} LIMIT {default_limit}" + return "\n".join(result) + except Exception as e: + return f"Correlation analysis error: {str(e)}" -# Registry of available tools for the agent -AGENT_TOOLS = { - "run_duckdb": run_duckdb_query, - "display": display_data, - "profile": profile_data, - "save_csv": save_table_as_csv, -} +# Enhanced tool registry +AGENT_TOOLS = [run_sql, display, describe, analyze_correlations] diff --git a/examples/pydantic_ai_eda/steps/eda_agent.py b/examples/pydantic_ai_eda/steps/eda_agent.py index e540b6d8f40..fa21e24754c 100644 --- a/examples/pydantic_ai_eda/steps/eda_agent.py +++ b/examples/pydantic_ai_eda/steps/eda_agent.py @@ -1,16 +1,12 @@ -"""EDA agent step using Pydantic AI for automated data analysis. +"""Simple EDA agent step using Pydantic AI.""" -This step implements the core EDA agent that uses Pydantic AI to perform -intelligent exploratory data analysis through SQL queries and structured reporting. -""" - -import logging -import time from typing import Annotated, Any, Dict, List, Tuple import pandas as pd +from pydantic_ai import Agent +from pydantic_ai.settings import ModelSettings -from zenml import log_metadata, step +from zenml import step from zenml.types import MarkdownString # Logfire for observability @@ -21,72 +17,9 @@ except ImportError: LOGFIRE_AVAILABLE = False -from models import AgentConfig, DataQualityFix, EDAReport +from models import AgentConfig, EDAReport from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps -logger = logging.getLogger(__name__) - -# (HTML visualization removed as per simplification decision) - - -# EDA analysis prompt template -EDA_SYSTEM_PROMPT = """You are an expert data analyst. Your task is to perform a focused exploratory data analysis (EDA) and produce a final EDAReport within 15 tool calls maximum. - -## CRITICAL: Follow this exact workflow: -1. **Initial Assessment** (2-3 tool calls): - - display('dataset') to see the data structure - - profile('dataset') for basic statistics - - run_duckdb('dataset', 'SELECT COUNT(*), COUNT(DISTINCT *) as unique_rows FROM dataset') for size/duplicates - -2. **Data Quality Analysis** (3-4 tool calls): - - Check missing data: SELECT column_name, COUNT(*) - COUNT(column_name) as nulls FROM (SELECT * FROM dataset LIMIT 1) CROSS JOIN (SELECT column_name FROM information_schema.columns WHERE table_name = 'dataset') - - Identify duplicates if any found in step 1 - - Check for obvious outliers in numeric columns (use percentiles) - -3. **Key Insights** (2-3 tool calls): - - Calculate correlations between numeric columns if >1 numeric column exists - - Analyze categorical distributions for top categories - - Identify most important patterns or issues - -4. **STOP and Generate Report** (1 tool call): - - Produce the final EDAReport with all required fields - - Do NOT continue exploring after generating the report - -## Available Tools: -- run_duckdb(dataset_ref, sql): Execute SQL against 'dataset' table (read-only) -- display(dataset_ref): Show first 10 rows -- profile(dataset_ref): Get statistical summary -- save_csv(dataset_ref, filename): Save query results - -## SQL Guidelines: -- Table name is 'dataset' -- Use efficient aggregations, avoid SELECT * -- Limit large result sets with LIMIT clause -- Focus on summary statistics, not raw data exploration - -## REQUIRED Output Format: -You MUST produce an EDAReport with: -- headline: Executive summary (1 sentence) -- key_findings: 3-5 critical discoveries -- risks: Data quality issues found -- fixes: Specific DataQualityFix objects for issues -- data_quality_score: 0-100 score based on: - * Missing data: 0-15%=good(30pts), 16-30%=fair(20pts), >30%=poor(10pts) - * Duplicates: 0-5%=good(25pts), 6-15%=fair(15pts), >15%=poor(5pts) - * Schema quality: All columns have data=good(25pts), some empty=fair(15pts) - * Consistency: Clean data=good(20pts), issues found=poor(10pts) -- markdown: Summary report for humans -- column_profiles: Per-column statistics from profiling -- correlation_insights: Key relationships found -- missing_data_analysis: Missing data summary - -## EFFICIENCY RULES: -- Maximum 15 tool calls total -- Stop analysis once you have enough information for the report -- Focus on critical issues, not exhaustive exploration -- Generate the final report as soon as you have sufficient insights -- Do NOT keep exploring after finding basic patterns""" - @step def run_eda_agent( @@ -99,779 +32,95 @@ def run_eda_agent( Annotated[List[Dict[str, str]], "sql_execution_log"], Annotated[Dict[str, pd.DataFrame], "analysis_tables"], ]: - """Run Pydantic AI agent for EDA analysis. - - Executes an AI agent that performs comprehensive exploratory data analysis - on the provided dataset using SQL queries and statistical analysis. - - Args: - dataset_df: The dataset to analyze - dataset_metadata: Metadata about the dataset - agent_config: Configuration for agent behavior - - Returns: - Tuple of (report_markdown, report_json, sql_log, analysis_tables) - containing all artifacts generated during the EDA analysis - """ - start_time = time.time() + """Run simple Pydantic AI agent for EDA analysis.""" + if agent_config is None: + agent_config = AgentConfig() - # Configure Logfire with explicit token + # Configure Logfire for observability if LOGFIRE_AVAILABLE: try: logfire.configure() logfire.instrument_pydantic_ai() logfire.info("EDA agent starting", dataset_shape=dataset_df.shape) except Exception as e: - logger.warning(f"Failed to configure Logfire: {e}") - - if agent_config is None: - agent_config = AgentConfig() - - logger.info(f"Starting EDA analysis with {agent_config.model_name}") - logger.info(f"Dataset shape: {dataset_df.shape}") - - try: - # Initialize agent dependencies - deps = AnalystAgentDeps() - - # Store the main dataset as Out[1] - main_ref = deps.store_dataset(dataset_df, "Out[1]") - logger.info(f"Stored main dataset as {main_ref}") - - # Initialize and run the Pydantic AI agent - try: - import asyncio - - import nest_asyncio - from pydantic_ai import Agent - from pydantic_ai.models.anthropic import AnthropicModel - from pydantic_ai.models.openai import OpenAIModel - - # Select model based on configuration - if agent_config.model_name.startswith("gpt"): - model = OpenAIModel(agent_config.model_name) - elif agent_config.model_name.startswith("claude"): - model = AnthropicModel(agent_config.model_name) - else: - # Default to OpenAI - model = OpenAIModel("gpt-5") - logger.warning( - f"Unknown model {agent_config.model_name}, defaulting to gpt-5" - ) - - # Create the agent with tools and stricter limits - agent = Agent( - model=model, - system_prompt=EDA_SYSTEM_PROMPT, - deps_type=AnalystAgentDeps, - result_type=EDAReport, - ) - - # Set strict tool call limits - if hasattr(agent, "max_tool_calls"): - agent.max_tool_calls = 15 - elif hasattr(agent, "_max_tool_calls"): - agent._max_tool_calls = 15 - - # Register tools - for tool_func in AGENT_TOOLS.values(): - agent.tool(tool_func) - - # Prepare initial context with clear instructions - initial_prompt = f"""ANALYZE THIS DATASET EFFICIENTLY - Maximum 15 tool calls. - -Dataset Information: -- Shape: {dataset_df.shape} -- Columns: {list(dataset_df.columns)} -- Source: {dataset_metadata.get("source_type", "unknown")} - -WORKFLOW (stick to this exactly): -1. display('dataset') - see the data -2. profile('dataset') - get basic stats -3. run_duckdb('dataset', 'SELECT COUNT(*), COUNT(DISTINCT *) FROM dataset') - check duplicates -4. Check missing data patterns with SQL -5. Analyze key relationships/distributions (max 3 queries) -6. STOP and generate final EDAReport - -Do NOT over-analyze. Focus on critical issues only. Generate your report once you have the essential insights.""" - - # Run the agent with timeout - try: - - async def run_agent(): - result = await agent.run(initial_prompt, deps=deps) - return result - - # Run with timeout - result = asyncio.wait_for( - run_agent(), timeout=agent_config.timeout_seconds - ) - - # If we're in a sync context, run it - try: - result = asyncio.run(result) - except RuntimeError: - # Already in an event loop - nest_asyncio.apply() - result = asyncio.run(result) + print(f"Warning: Failed to configure Logfire: {e}") + + # Initialize agent dependencies and store the dataset + deps = AnalystAgentDeps() + main_ref = deps.store(dataset_df) + + # Create the EDA analyst agent with focused system prompt + system_prompt = """You are a data analyst. Perform quick but insightful EDA. + +FOCUS ON: +- Data quality score (0-100) based on missing data and duplicates +- Key patterns and distributions +- Notable correlations or anomalies +- 2-3 actionable recommendations + +Be concise but specific with numbers. Aim for quality insights, not exhaustive analysis.""" + + analyst_agent = Agent( + f"openai:{agent_config.model_name}", + deps_type=AnalystAgentDeps, + output_type=EDAReport, + output_retries=3, # Allow more retries for result validation + system_prompt=system_prompt, + model_settings=ModelSettings( + parallel_tool_calls=True, + ), + ) - except ImportError: - # Fallback to mock analysis if pydantic-ai not available - logger.warning( - "Pydantic AI not available, running fallback analysis" - ) - result = _run_fallback_analysis(dataset_df, deps) + # Register tools + for tool in AGENT_TOOLS: + analyst_agent.tool(tool) - except ImportError: - logger.warning( - "Pydantic AI not available, running fallback analysis" - ) - result = _run_fallback_analysis(dataset_df, deps) + # Run focused analysis + user_prompt = f"""Quick EDA analysis for dataset '{main_ref}' ({dataset_df.shape[0]} rows, {dataset_df.shape[1]} cols). - # Extract results - if hasattr(result, "data"): - eda_report = result.data - else: - eda_report = result +STEPS (keep it fast): +1. display('{main_ref}') - check data structure +2. describe('{main_ref}') - get key stats +3. run_sql('{main_ref}', 'SELECT COUNT(*) as total, COUNT(DISTINCT *) as unique FROM dataset') - check duplicates +4. If multiple numeric columns: analyze_correlations('{main_ref}') - processing_time_ms = int((time.time() - start_time) * 1000) +Generate EDAReport with data quality score and 2-3 key insights.""" - # Prepare return artifacts - report_markdown = eda_report.markdown + try: + result = analyst_agent.run_sync(user_prompt, deps=deps) + eda_report = result.output + except Exception as e: + # Simple fallback - create basic report + eda_report = EDAReport( + headline=f"Basic analysis of {dataset_df.shape[0]} rows, {dataset_df.shape[1]} columns", + key_findings=[ + f"Dataset contains {len(dataset_df)} rows and {len(dataset_df.columns)} columns" + ], + risks=["Analysis failed - using basic fallback"], + fixes=[], + data_quality_score=50.0, + markdown=f"# EDA Report\n\nBasic analysis failed: {str(e)}\n\nDataset shape: {dataset_df.shape}", + column_profiles={}, + correlation_insights=[], + missing_data_analysis={}, + ) - report_json = { + # Return results + return ( + MarkdownString(eda_report.markdown), + { "headline": eda_report.headline, "key_findings": eda_report.key_findings, - "risks": eda_report.risks, - "fixes": [fix.model_dump() for fix in eda_report.fixes], "data_quality_score": eda_report.data_quality_score, - "column_profiles": eda_report.column_profiles, - "correlation_insights": eda_report.correlation_insights, - "missing_data_analysis": eda_report.missing_data_analysis, - "processing_time_ms": processing_time_ms, "agent_metadata": { "model": agent_config.model_name, "tool_calls": len(deps.query_history), - "datasets_created": len(deps.datasets), }, - } - - # Get SQL execution log - sql_log = deps.query_history.copy() - - # Filter analysis tables (keep only reasonably sized ones) - analysis_tables = {} - for ref, df in deps.datasets.items(): - if ( - ref != "Out[1]" and len(df) <= 10000 - ): # Don't return huge tables - analysis_tables[ref] = df - - logger.info(f"EDA analysis completed in {processing_time_ms}ms") - logger.info(f"Generated {len(analysis_tables)} analysis tables") - logger.info(f"Data quality score: {eda_report.data_quality_score}") - - # Log enhanced metadata for ZenML dashboard - _log_agent_metadata( - agent_config=agent_config, - eda_report=eda_report, - processing_time_ms=processing_time_ms, - sql_log=sql_log, - analysis_tables=analysis_tables, - result=result if "result" in locals() else None, - ) - - # Determine tool names dynamically when possible - tool_names: List[str] = list(AGENT_TOOLS.keys()) - try: - if "agent" in locals(): - tools_attr = getattr(agent, "tools", None) - if tools_attr is None: - tools_attr = getattr(agent, "_tools", None) - if tools_attr is not None: - if isinstance(tools_attr, dict): - tool_names = list(tools_attr.keys()) - else: - possible_keys = getattr(tools_attr, "keys", None) - if callable(possible_keys): - tool_names = list(possible_keys()) - except Exception: - # Best-effort fallback - pass - - # Convert markdown to MarkdownString for proper rendering - markdown_artifact = MarkdownString(report_markdown) - - return markdown_artifact, report_json, sql_log, analysis_tables - - except Exception as e: - logger.error(f"EDA agent failed: {e}") - - # Run fallback analysis to preserve any existing context and generate basic results - logger.info("Running fallback analysis after agent failure") - fallback_report = _run_fallback_analysis(dataset_df, deps) - processing_time_ms = int((time.time() - start_time) * 1000) - - # Prepare return artifacts with fallback data - report_markdown = ( - fallback_report.markdown - + f"\n\n**Note:** Analysis completed in fallback mode after agent error: {str(e)}" - ) - - report_json = { - "headline": f"EDA analysis completed (fallback mode after error: {str(e)})", - "key_findings": fallback_report.key_findings, - "risks": fallback_report.risks - + [f"Original agent failed: {str(e)}"], - "fixes": [fix.model_dump() for fix in fallback_report.fixes], - "data_quality_score": fallback_report.data_quality_score, - "column_profiles": fallback_report.column_profiles, - "correlation_insights": fallback_report.correlation_insights, - "missing_data_analysis": fallback_report.missing_data_analysis, - "processing_time_ms": processing_time_ms, - "agent_metadata": { - "model": "fallback_after_error", - "tool_calls": len(deps.query_history), - "datasets_created": len(deps.datasets), - "error": str(e), - }, - } - - # Get SQL execution log (preserving any queries that were executed before failure) - sql_log = deps.query_history.copy() - - # Filter analysis tables (preserving any tables created before failure) - analysis_tables = {} - for ref, df in deps.datasets.items(): - if ref != "Out[1]" and ref != "main_dataset" and len(df) <= 10000: - analysis_tables[ref] = df - - # Convert markdown to MarkdownString - error_markdown_artifact = MarkdownString(report_markdown) - - logger.info( - f"Fallback analysis preserved {len(sql_log)} SQL queries and {len(analysis_tables)} analysis tables" - ) - - # Log error metadata - log_metadata( - { - "ai_agent_execution": { - "model_name": agent_config.model_name - if agent_config - else "unknown", - "processing_time_ms": processing_time_ms, - "success": False, - "error_message": str(e), - "tool_calls_made": len(sql_log), - "datasets_created": len(analysis_tables), - "fallback_mode": True, - }, - "data_quality_assessment": { - "quality_score": fallback_report.data_quality_score, - "issues_found": len(fallback_report.fixes), - "risk_count": len(fallback_report.risks), - "key_findings_count": len(fallback_report.key_findings), - }, - } - ) - - return ( - error_markdown_artifact, - report_json, - sql_log, - analysis_tables, - ) - - -def _run_fallback_analysis( - dataset_df: pd.DataFrame, deps: AnalystAgentDeps -) -> EDAReport: - """Fallback analysis when Pydantic AI is not available. - - Performs basic statistical analysis and generates a simple report - using pandas operations instead of AI-driven analysis. Also simulates - some SQL queries to populate logs and analysis tables. - """ - logger.info("Running fallback EDA analysis") - - # Store the main dataset - deps.store_dataset(dataset_df, "main_dataset") - - # Run some basic SQL-like analysis to populate logs and tables - _run_fallback_sql_analysis(dataset_df, deps) - - # Basic statistics - numeric_cols = dataset_df.select_dtypes( - include=["number"] - ).columns.tolist() - categorical_cols = dataset_df.select_dtypes( - include=["object", "category"] - ).columns.tolist() - - # Calculate missing data - missing_counts = dataset_df.isnull().sum() - missing_pct = (missing_counts / len(dataset_df) * 100).round(2) - - # Calculate quality score - quality_factors = [] - - # Missing data factor (0-40 points) - overall_missing_pct = ( - dataset_df.isnull().sum().sum() - / (len(dataset_df) * len(dataset_df.columns)) - * 100 - ) - if overall_missing_pct <= 5: - quality_factors.append(40) - elif overall_missing_pct <= 15: - quality_factors.append(30) - elif overall_missing_pct <= 30: - quality_factors.append(20) - else: - quality_factors.append(10) - - # Duplicate factor (0-30 points) - duplicate_count = dataset_df.duplicated().sum() - duplicate_pct = duplicate_count / len(dataset_df) * 100 - if duplicate_pct <= 1: - quality_factors.append(30) - elif duplicate_pct <= 5: - quality_factors.append(25) - elif duplicate_pct <= 15: - quality_factors.append(15) - else: - quality_factors.append(5) - - # Schema completeness factor (0-30 points) - if len(dataset_df.columns) > 0: - non_empty_cols = (dataset_df.notna().any()).sum() - schema_completeness = non_empty_cols / len(dataset_df.columns) * 30 - quality_factors.append(int(schema_completeness)) - else: - quality_factors.append(0) - - data_quality_score = sum(quality_factors) - - # Generate key findings - key_findings = [] - key_findings.append( - f"Dataset contains {len(dataset_df):,} rows and {len(dataset_df.columns)} columns" - ) - - if numeric_cols: - key_findings.append( - f"Found {len(numeric_cols)} numeric columns for quantitative analysis" - ) - - if categorical_cols: - key_findings.append( - f"Found {len(categorical_cols)} categorical columns for segmentation analysis" - ) - - if overall_missing_pct > 10: - key_findings.append( - f"Missing data is {overall_missing_pct:.1f}% overall, requiring attention" - ) - - if duplicate_count > 0: - key_findings.append( - f"Found {duplicate_count:,} duplicate rows ({duplicate_pct:.1f}%)" - ) - - # Generate risks - risks = [] - if overall_missing_pct > 20: - risks.append( - "High percentage of missing data may impact analysis quality" - ) - - if duplicate_pct > 10: - risks.append( - "Significant duplicate data may skew statistical analysis" - ) - - if len(numeric_cols) == 0: - risks.append("No numeric columns found for quantitative analysis") - - # Generate fixes - fixes = [] - - high_missing_cols = missing_pct[missing_pct > 30].index.tolist() - if high_missing_cols: - fixes.append( - DataQualityFix( - title="Address high missing data in key columns", - rationale=f"Columns {high_missing_cols} have >30% missing values", - severity="high", - code_snippet="df.dropna(subset=['high_missing_col']) or df.fillna(method='forward')", - affected_columns=high_missing_cols, - estimated_impact=0.3, - ) - ) - - if duplicate_count > 0: - fixes.append( - DataQualityFix( - title="Remove duplicate records", - rationale=f"Found {duplicate_count:,} duplicate rows affecting data integrity", - severity="medium" if duplicate_pct < 10 else "high", - code_snippet="df.drop_duplicates(inplace=True)", - affected_columns=list(dataset_df.columns), - estimated_impact=duplicate_pct / 100, - ) - ) - - # Column profiles - column_profiles = {} - for col in dataset_df.columns: - profile = { - "dtype": str(dataset_df[col].dtype), - "null_count": int(missing_counts[col]), - "null_percentage": float(missing_pct[col]), - "unique_count": int(dataset_df[col].nunique()), - } - - if col in numeric_cols and dataset_df[col].notna().sum() > 0: - profile.update( - { - "mean": float(dataset_df[col].mean()), - "std": float(dataset_df[col].std()), - "min": float(dataset_df[col].min()), - "max": float(dataset_df[col].max()), - "median": float(dataset_df[col].median()), - } - ) - elif col in categorical_cols and dataset_df[col].notna().sum() > 0: - value_counts = dataset_df[col].value_counts().head(5) - profile["top_values"] = value_counts.to_dict() - - column_profiles[col] = profile - - # Correlation insights - correlation_insights = [] - if len(numeric_cols) > 1: - corr_matrix = dataset_df[numeric_cols].corr() - high_corrs = [] - for i, col1 in enumerate(numeric_cols): - for j, col2 in enumerate(numeric_cols[i + 1 :], i + 1): - corr_val = corr_matrix.iloc[i, j] - if abs(corr_val) > 0.7: - high_corrs.append((col1, col2, corr_val)) - - if high_corrs: - correlation_insights.append( - f"Found {len(high_corrs)} high correlations (>0.7)" - ) - for col1, col2, corr in high_corrs[:3]: # Show top 3 - correlation_insights.append( - f"{col1} and {col2} are strongly correlated ({corr:.3f})" - ) - else: - correlation_insights.append( - "No strong correlations (>0.7) detected between numeric variables" - ) - - # Missing data analysis - missing_data_analysis = { - "total_missing_cells": int(missing_counts.sum()), - "missing_percentage": float(overall_missing_pct), - "columns_with_missing": missing_counts[missing_counts > 0].to_dict(), - "completely_missing_columns": missing_counts[ - missing_counts == len(dataset_df) - ].index.tolist(), - } - - # Generate markdown report - markdown_report = f"""# EDA Report (Fallback Analysis) - -## Executive Summary -{key_findings[0] if key_findings else "Basic dataset analysis completed"} - -**Data Quality Score: {data_quality_score}/100** - -## Key Findings -{chr(10).join(f"- {finding}" for finding in key_findings)} - -## Data Quality Issues -{chr(10).join(f"- {risk}" for risk in risks) if risks else "No major quality issues detected"} - -## Column Overview -| Column | Type | Missing | Unique | -|--------|------|---------|--------| -{chr(10).join(f"| {col} | {profile['dtype']} | {profile['null_percentage']:.1f}% | {profile['unique_count']} |" for col, profile in column_profiles.items())} - -## Recommendations -{chr(10).join(f"- {fix.title}: {fix.rationale}" for fix in fixes) if fixes else "- Dataset appears to be in good condition"} - -*Report generated by ZenML EDA Pipeline (fallback mode)* -""" - - return EDAReport( - headline=key_findings[0] - if key_findings - else "Basic EDA analysis completed", - key_findings=key_findings, - risks=risks, - fixes=fixes, - data_quality_score=data_quality_score, - markdown=markdown_report, - column_profiles=column_profiles, - correlation_insights=correlation_insights, - missing_data_analysis=missing_data_analysis, - ) - - -def _run_fallback_sql_analysis( - dataset_df: pd.DataFrame, deps: AnalystAgentDeps -): - """Run basic SQL-like analysis to populate logs and analysis tables for fallback mode.""" - import time - - import duckdb - - try: - # Create DuckDB connection - conn = duckdb.connect(":memory:") - conn.register("dataset", dataset_df) - - # Run some basic SQL queries to simulate what the AI agent would do - fallback_queries = [ - ("SELECT COUNT(*) as row_count FROM dataset", "basic_stats"), - ( - "SELECT COUNT(*) as col_count FROM (SELECT * FROM dataset LIMIT 1)", - "column_count", - ), - ] - - # Add column-specific queries - for col in dataset_df.columns[:5]: # Limit to first 5 columns - # Escape column names with spaces - escaped_col = f'"{col}"' if " " in col else col - - # Count nulls - fallback_queries.append( - ( - f"SELECT COUNT(*) - COUNT({escaped_col}) as null_count FROM dataset", - f"nulls_{col.replace(' ', '_')}", - ) - ) - - # Get unique counts for non-numeric columns - if dataset_df[col].dtype == "object" or dataset_df[ - col - ].dtype.name.startswith("str"): - fallback_queries.append( - ( - f"SELECT COUNT(DISTINCT {escaped_col}) as unique_count FROM dataset WHERE {escaped_col} IS NOT NULL", - f"unique_{col.replace(' ', '_')}", - ) - ) - else: - # Basic stats for numeric columns - fallback_queries.append( - ( - f"SELECT AVG({escaped_col}) as avg_val, MIN({escaped_col}) as min_val, MAX({escaped_col}) as max_val FROM dataset WHERE {escaped_col} IS NOT NULL", - f"stats_{col.replace(' ', '_')}", - ) - ) - - # Execute queries and log them - for sql, description in fallback_queries: - try: - start_time = time.time() - result_df = conn.execute(sql).fetchdf() - execution_time = (time.time() - start_time) * 1000 - - # Store result table - result_ref = deps.store_dataset( - result_df, f"fallback_{description}" - ) - - # Log the query - deps.log_query( - sql=sql, - result_ref=result_ref, - rows_returned=len(result_df), - execution_time_ms=execution_time, - ) - - except Exception as e: - logger.warning(f"Fallback query failed: {sql} - {e}") - - conn.close() - logger.info( - f"Fallback analysis executed {len(fallback_queries)} SQL queries" - ) - - except Exception as e: - logger.warning(f"Could not run fallback SQL analysis: {e}") - - -def _log_agent_metadata( - agent_config: AgentConfig, - eda_report: EDAReport, - processing_time_ms: int, - sql_log: List[Dict[str, str]], - analysis_tables: Dict[str, pd.DataFrame], - result: Any = None, -) -> None: - """Log enhanced metadata about AI agent execution for ZenML dashboard.""" - - # Calculate token usage if available - token_usage = {} - cost_estimate = None - if result and hasattr(result, "usage"): - token_usage = { - "input_tokens": getattr(result.usage, "prompt_tokens", 0), - "output_tokens": getattr(result.usage, "completion_tokens", 0), - "total_tokens": getattr(result.usage, "total_tokens", 0), - } - cost_estimate = _estimate_cost(result.usage, agent_config.model_name) - - # Calculate SQL execution metrics - sql_metrics = {} - if sql_log: - execution_times = [q.get("execution_time_ms", 0) for q in sql_log] - sql_metrics = { - "total_queries": len(sql_log), - "avg_execution_time_ms": sum(execution_times) - / len(execution_times) - if execution_times - else 0, - "max_execution_time_ms": max(execution_times) - if execution_times - else 0, - "total_rows_processed": sum( - q.get("rows_returned", 0) for q in sql_log - ), - "query_types": _analyze_query_types(sql_log), - } - - # Analyze data quality issues by severity - severity_breakdown = {"low": 0, "medium": 0, "high": 0, "critical": 0} - for fix in eda_report.fixes: - severity = fix.severity.lower() - if severity in severity_breakdown: - severity_breakdown[severity] += 1 - - # Log metadata to ZenML - log_metadata( + }, + deps.query_history, { - "ai_agent_execution": { - "model_name": agent_config.model_name, - "processing_time_ms": processing_time_ms, - "success": True, - "tool_calls_made": len(sql_log), - "datasets_created": len(analysis_tables), - }, - "token_usage": token_usage, - "cost_estimate_usd": cost_estimate, - "data_quality_assessment": { - "quality_score": eda_report.data_quality_score, - "issues_found": len(eda_report.fixes), - "severity_breakdown": severity_breakdown, - "risk_count": len(eda_report.risks), - "key_findings_count": len(eda_report.key_findings), - }, - "sql_execution_metrics": sql_metrics, - "analysis_summary": { - "headline": eda_report.headline, - "columns_analyzed": len(eda_report.column_profiles), - "correlation_insights": len(eda_report.correlation_insights), - "missing_data_pct": eda_report.missing_data_analysis.get( - "missing_percentage", 0 - ), - }, - } - ) - - -def _estimate_cost(usage, model_name: str) -> float: - """Estimate cost based on token usage and model pricing.""" - if not hasattr(usage, "total_tokens") or usage.total_tokens == 0: - return 0.0 - - # Rough pricing estimates (per 1M tokens) - pricing = { - "gpt-4": {"input": 30, "output": 60}, - "gpt-4o": {"input": 5, "output": 15}, - "gpt-4o-mini": {"input": 0.15, "output": 0.6}, - "claude-3-5-sonnet": {"input": 3, "output": 15}, - "claude-3-haiku": {"input": 0.25, "output": 1.25}, - } - - # Find matching pricing - model_pricing = None - for model_key, prices in pricing.items(): - if model_key in model_name.lower(): - model_pricing = prices - break - - if not model_pricing: - # Default rough estimate - return usage.total_tokens * 0.00001 - - input_cost = ( - getattr(usage, "prompt_tokens", 0) * model_pricing["input"] / 1_000_000 + ref: df + for ref, df in deps.output.items() + if ref != main_ref and len(df) <= 1000 + }, ) - output_cost = ( - getattr(usage, "completion_tokens", 0) - * model_pricing["output"] - / 1_000_000 - ) - - return input_cost + output_cost - - -def _analyze_query_types(sql_log: List[Dict[str, str]]) -> Dict[str, int]: - """Analyze the types of SQL queries executed.""" - query_types = {} - - for query in sql_log: - sql = query.get("sql", "").upper().strip() - - # Determine query type - if sql.startswith("SELECT COUNT"): - query_type = "count" - elif sql.startswith("SELECT DISTINCT") or "DISTINCT" in sql: - query_type = "distinct" - elif "GROUP BY" in sql: - query_type = "group_by" - elif "ORDER BY" in sql: - query_type = "ordered" - elif "WHERE" in sql: - query_type = "filtered" - elif sql.startswith("SELECT"): - query_type = "basic_select" - elif sql.startswith("WITH"): - query_type = "cte" - else: - query_type = "other" - - query_types[query_type] = query_types.get(query_type, 0) + 1 - - return query_types - - -def _create_error_report( - error_msg: str, dataset_df: pd.DataFrame -) -> Dict[str, Any]: - """Create an error report when analysis fails.""" - return { - "headline": "EDA analysis failed due to technical error", - "key_findings": [ - f"Analysis failed: {error_msg}", - f"Dataset has {len(dataset_df)} rows and {len(dataset_df.columns)} columns", - ], - "risks": [ - "Analysis could not be completed", - "Manual inspection required", - ], - "fixes": [], - "data_quality_score": 0.0, - "markdown": f"# EDA Analysis Failed\n\n**Error:** {error_msg}\n\nPlease check the dataset and configuration.", - "column_profiles": {}, - "correlation_insights": [], - "missing_data_analysis": {}, - } diff --git a/examples/pydantic_ai_eda/steps/quality_gate.py b/examples/pydantic_ai_eda/steps/quality_gate.py index 5f6a6729a47..b8e3eb7c3b7 100644 --- a/examples/pydantic_ai_eda/steps/quality_gate.py +++ b/examples/pydantic_ai_eda/steps/quality_gate.py @@ -5,7 +5,7 @@ """ import logging -from typing import Annotated, Any, Dict +from typing import Annotated, Any, Dict, Tuple from models import QualityGateDecision @@ -14,7 +14,6 @@ logger = logging.getLogger(__name__) -@step def evaluate_quality_gate( report_json: Dict[str, Any], min_quality_score: float = 70.0, @@ -129,6 +128,10 @@ def evaluate_quality_gate( decision_factors.append( f"Target column '{target_column}': Present โœ…" ) + elif require_target_column and not target_column: + # If target column is required but not specified + blocking_issues.append("Target column required but not specified") + decision_factors.append("Target column: Not specified โŒ") # Generate recommendations based on findings recommendations = _generate_recommendations( @@ -142,6 +145,23 @@ def evaluate_quality_gate( # Make final decision passed = len(blocking_issues) == 0 + # Debug logging + logger.info(f"Quality check details:") + logger.info( + f" - Quality score: {quality_score:.1f} >= {min_quality_score} = {quality_score >= min_quality_score}" + ) + logger.info( + f" - High severity issues: {len(high_severity_fixes)} (block_on_high_severity={block_on_high_severity})" + ) + logger.info( + f" - Missing data: {overall_missing_pct:.1f}% <= {max_missing_data_pct}% = {overall_missing_pct <= max_missing_data_pct}" + ) + logger.info( + f" - Require target: {require_target_column}, Target column: {target_column}" + ) + logger.info(f" - Blocking issues count: {len(blocking_issues)}") + logger.info(f" - Decision factors: {decision_factors}") + if passed: decision_reason = ( f"All quality checks passed. {', '.join(decision_factors)}" @@ -150,6 +170,7 @@ def evaluate_quality_gate( else: decision_reason = f"Quality gate failed. Blocking issues: {'; '.join(blocking_issues)}" logger.warning("โŒ Quality gate FAILED") + logger.warning(f"Blocking issues: {blocking_issues}") # Log decision details logger.info(f"Quality score: {quality_score:.1f}") @@ -299,27 +320,39 @@ def _generate_recommendations( @step -def route_based_on_quality( - decision: QualityGateDecision, - on_pass_message: str = "Proceed to downstream processing", - on_fail_message: str = "Data quality insufficient - halt pipeline", -) -> Annotated[str, "routing_decision"]: - """Route pipeline execution based on quality gate decision. - - Simple routing step that can be used to conditionally execute - downstream steps based on quality gate results. - - Args: - decision: Quality gate decision result - on_pass_message: Message when quality gate passes - on_fail_message: Message when quality gate fails +def evaluate_quality_gate_with_routing( + report_json: Dict[str, Any], + min_quality_score: float = 70.0, + block_on_high_severity: bool = True, + max_missing_data_pct: float = 30.0, + require_target_column: bool = False, + target_column: str = None, +) -> Tuple[ + Annotated[QualityGateDecision, "quality_gate_decision"], + Annotated[str, "routing_message"], +]: + """Combined quality gate evaluation and routing decision. - Returns: - Routing message indicating next steps + Returns both the detailed quality decision and routing message. """ + # Use existing logic to evaluate quality + decision = evaluate_quality_gate( + report_json, + min_quality_score, + block_on_high_severity, + max_missing_data_pct, + require_target_column, + target_column, + ) + + # Generate routing message if decision.passed: - logger.info(f"๐Ÿš€ {on_pass_message}") - return on_pass_message + routing_message = ( + "๐Ÿš€ Data quality passed - proceed to downstream processing" + ) + logger.info(routing_message) else: - logger.warning(f"๐Ÿ›‘ {on_fail_message}") - return on_fail_message + routing_message = "๐Ÿ›‘ Data quality insufficient - review and improve data before proceeding" + logger.warning(routing_message) + + return decision, routing_message From a665a7a186661112f8fde9e0352cb2746bb75601 Mon Sep 17 00:00:00 2001 From: Hamza Tahir Date: Thu, 21 Aug 2025 22:46:20 +0200 Subject: [PATCH 03/14] Refactor quality gate evaluation function --- .../pydantic_ai_eda/steps/quality_gate.py | 312 +++++++----------- 1 file changed, 116 insertions(+), 196 deletions(-) diff --git a/examples/pydantic_ai_eda/steps/quality_gate.py b/examples/pydantic_ai_eda/steps/quality_gate.py index b8e3eb7c3b7..7540e5062b1 100644 --- a/examples/pydantic_ai_eda/steps/quality_gate.py +++ b/examples/pydantic_ai_eda/steps/quality_gate.py @@ -1,11 +1,7 @@ -"""Quality gate step for data quality assessment and pipeline routing. - -Evaluates EDA results against quality thresholds to make pass/fail decisions -for downstream processing or model training workflows. -""" +"""Quality gate step for data quality assessment and pipeline routing.""" import logging -from typing import Annotated, Any, Dict, Tuple +from typing import Annotated, Any, Dict, List, Tuple from models import QualityGateDecision @@ -14,6 +10,21 @@ logger = logging.getLogger(__name__) +def _check_quality_metric( + condition: bool, + pass_msg: str, + fail_msg: str, + blocking_issues: List[str], + decision_factors: List[str], +) -> None: + """Helper to standardize quality check pattern.""" + if condition: + blocking_issues.append(fail_msg) + decision_factors.append(f"{pass_msg.split(':')[0]}: โŒ") + else: + decision_factors.append(f"{pass_msg}: โœ…") + + def evaluate_quality_gate( report_json: Dict[str, Any], min_quality_score: float = 70.0, @@ -22,102 +33,75 @@ def evaluate_quality_gate( require_target_column: bool = False, target_column: str = None, ) -> Annotated[QualityGateDecision, "quality_gate_decision"]: - """Evaluate data quality and make gate pass/fail decision. - - Analyzes the EDA report results against configurable quality thresholds - to determine if the data meets requirements for downstream processing. - - Args: - report_json: EDA report JSON containing quality metrics and findings - min_quality_score: Minimum data quality score required (0-100) - block_on_high_severity: Whether to fail on high-severity issues - max_missing_data_pct: Maximum allowable missing data percentage - require_target_column: Whether a target column is required - target_column: Expected target column name (if required) - - Returns: - QualityGateDecision with pass/fail result and recommendations - """ + """Evaluate data quality and make gate pass/fail decision.""" logger.info("Evaluating data quality gate") try: - # Extract key metrics from report + # Extract metrics quality_score = report_json.get("data_quality_score", 0.0) fixes = report_json.get("fixes", []) missing_data_analysis = report_json.get("missing_data_analysis", {}) column_profiles = report_json.get("column_profiles", {}) - # Initialize decision components blocking_issues = [] - recommendations = [] decision_factors = [] - # Check 1: Overall quality score - if quality_score < min_quality_score: - blocking_issues.append( - f"Data quality score ({quality_score:.1f}) below minimum threshold ({min_quality_score})" - ) - decision_factors.append( - f"Quality score: {quality_score:.1f}/{min_quality_score} โŒ" - ) - else: - decision_factors.append( - f"Quality score: {quality_score:.1f}/{min_quality_score} โœ…" - ) + # Quality score check + _check_quality_metric( + quality_score < min_quality_score, + f"Quality score: {quality_score:.1f}/{min_quality_score}", + f"Data quality score ({quality_score:.1f}) below minimum threshold ({min_quality_score})", + blocking_issues, + decision_factors, + ) - # Check 2: High severity issues - high_severity_fixes = [ - fix for fix in fixes if fix.get("severity") == "high" - ] + # High severity issues check + high_severity_fixes = [f for f in fixes if f.get("severity") == "high"] if block_on_high_severity and high_severity_fixes: - high_severity_titles = [ - fix.get("title", "Unknown issue") - for fix in high_severity_fixes - ] - blocking_issues.append( - f"High severity issues found: {', '.join(high_severity_titles)}" - ) - decision_factors.append( - f"High severity issues: {len(high_severity_fixes)} โŒ" + titles = [f.get("title", "Unknown") for f in high_severity_fixes] + _check_quality_metric( + True, + f"High severity issues: {len(high_severity_fixes)}", + f"High severity issues found: {', '.join(titles)}", + blocking_issues, + decision_factors, ) else: decision_factors.append( f"High severity issues: {len(high_severity_fixes)} โœ…" ) - # Check 3: Missing data threshold + # Missing data check overall_missing_pct = missing_data_analysis.get( "missing_percentage", 0.0 ) - if overall_missing_pct > max_missing_data_pct: - blocking_issues.append( - f"Missing data percentage ({overall_missing_pct:.1f}%) exceeds threshold ({max_missing_data_pct}%)" - ) - decision_factors.append( - f"Missing data: {overall_missing_pct:.1f}%/{max_missing_data_pct}% โŒ" - ) - else: - decision_factors.append( - f"Missing data: {overall_missing_pct:.1f}%/{max_missing_data_pct}% โœ…" - ) + _check_quality_metric( + overall_missing_pct > max_missing_data_pct, + f"Missing data: {overall_missing_pct:.1f}%/{max_missing_data_pct}%", + f"Missing data percentage ({overall_missing_pct:.1f}%) exceeds threshold ({max_missing_data_pct}%)", + blocking_issues, + decision_factors, + ) - # Check 4: Target column requirement - if require_target_column and target_column: - if target_column not in column_profiles: + # Target column check + if require_target_column: + if not target_column: blocking_issues.append( - f"Required target column '{target_column}' not found in dataset" + "Target column required but not specified" + ) + decision_factors.append("Target column: Not specified โŒ") + elif target_column not in column_profiles: + blocking_issues.append( + f"Required target column '{target_column}' not found" ) decision_factors.append( f"Target column '{target_column}': Missing โŒ" ) else: - # Check target column quality - target_profile = column_profiles[target_column] - target_missing_pct = target_profile.get("null_percentage", 0.0) - - if ( - target_missing_pct > 50 - ): # Target column shouldn't be mostly empty + target_missing_pct = column_profiles[target_column].get( + "null_percentage", 0.0 + ) + if target_missing_pct > 50: blocking_issues.append( f"Target column '{target_column}' has {target_missing_pct:.1f}% missing values" ) @@ -128,55 +112,26 @@ def evaluate_quality_gate( decision_factors.append( f"Target column '{target_column}': Present โœ…" ) - elif require_target_column and not target_column: - # If target column is required but not specified - blocking_issues.append("Target column required but not specified") - decision_factors.append("Target column: Not specified โŒ") - # Generate recommendations based on findings + # Generate recommendations recommendations = _generate_recommendations( - quality_score, - fixes, - missing_data_analysis, - column_profiles, - min_quality_score, + quality_score, fixes, overall_missing_pct, min_quality_score ) - # Make final decision + # Make decision passed = len(blocking_issues) == 0 - - # Debug logging - logger.info(f"Quality check details:") - logger.info( - f" - Quality score: {quality_score:.1f} >= {min_quality_score} = {quality_score >= min_quality_score}" - ) - logger.info( - f" - High severity issues: {len(high_severity_fixes)} (block_on_high_severity={block_on_high_severity})" - ) - logger.info( - f" - Missing data: {overall_missing_pct:.1f}% <= {max_missing_data_pct}% = {overall_missing_pct <= max_missing_data_pct}" + decision_reason = ( + f"All quality checks passed. {', '.join(decision_factors)}" + if passed + else f"Quality gate failed. Issues: {'; '.join(blocking_issues)}" ) + logger.info( - f" - Require target: {require_target_column}, Target column: {target_column}" + "โœ… Quality gate PASSED" if passed else "โŒ Quality gate FAILED" ) - logger.info(f" - Blocking issues count: {len(blocking_issues)}") - logger.info(f" - Decision factors: {decision_factors}") - - if passed: - decision_reason = ( - f"All quality checks passed. {', '.join(decision_factors)}" - ) - logger.info("โœ… Quality gate PASSED") - else: - decision_reason = f"Quality gate failed. Blocking issues: {'; '.join(blocking_issues)}" - logger.warning("โŒ Quality gate FAILED") + if not passed: logger.warning(f"Blocking issues: {blocking_issues}") - # Log decision details - logger.info(f"Quality score: {quality_score:.1f}") - logger.info(f"Missing data: {overall_missing_pct:.1f}%") - logger.info(f"High severity issues: {len(high_severity_fixes)}") - return QualityGateDecision( passed=passed, quality_score=quality_score, @@ -202,13 +157,11 @@ def evaluate_quality_gate( except Exception as e: logger.error(f"Quality gate evaluation failed: {e}") - - # Return failure decision return QualityGateDecision( passed=False, quality_score=0.0, decision_reason=f"Quality gate evaluation failed: {str(e)}", - blocking_issues=[f"Technical error during evaluation: {str(e)}"], + blocking_issues=[f"Technical error: {str(e)}"], recommendations=[ "Review EDA report format and quality gate configuration" ], @@ -219,101 +172,73 @@ def evaluate_quality_gate( def _generate_recommendations( quality_score: float, fixes: list, - missing_data_analysis: dict, - column_profiles: dict, + overall_missing_pct: float, min_threshold: float, ) -> list: """Generate actionable recommendations based on quality assessment.""" recommendations = [] - # Score-based recommendations - if quality_score < min_threshold: - score_gap = min_threshold - quality_score - if score_gap > 30: - recommendations.append( - "Consider data cleaning or alternative data sources due to significant quality issues" - ) - elif score_gap > 15: - recommendations.append( - "Address high-priority data quality issues before proceeding" - ) - else: - recommendations.append( - "Minor quality improvements recommended but data is usable" - ) + # Quality score recommendations + score_gap = min_threshold - quality_score + if score_gap > 30: + recommendations.append( + "Consider alternative data sources due to significant quality issues" + ) + elif score_gap > 15: + recommendations.append( + "Address high-priority data quality issues before proceeding" + ) + elif score_gap > 0: + recommendations.append("Minor quality improvements recommended") - # Fix-based recommendations + # Critical fixes critical_fixes = [ - fix for fix in fixes if fix.get("severity") in ["high", "critical"] + f for f in fixes if f.get("severity") in ["high", "critical"] ] if critical_fixes: recommendations.append( f"Implement {len(critical_fixes)} critical data quality fixes" ) - # Add specific recommendations for common issues - for fix in critical_fixes[:3]: # Top 3 critical fixes - if "missing" in fix.get("title", "").lower(): - recommendations.append( - "Consider imputation strategies for missing data" - ) - elif "duplicate" in fix.get("title", "").lower(): - recommendations.append( - "Remove or consolidate duplicate records" - ) - elif "outlier" in fix.get("title", "").lower(): - recommendations.append("Investigate and handle outlier values") + # Add specific fix types + fix_types = { + "missing": "Consider imputation strategies for missing data", + "duplicate": "Remove or consolidate duplicate records", + "outlier": "Investigate and handle outlier values", + } + + for fix in critical_fixes[:3]: + title = fix.get("title", "").lower() + for keyword, recommendation in fix_types.items(): + if keyword in title and recommendation not in recommendations: + recommendations.append(recommendation) + break # Missing data recommendations - overall_missing_pct = missing_data_analysis.get("missing_percentage", 0.0) if overall_missing_pct > 20: recommendations.append( - "High missing data detected - consider data imputation or collection improvements" + "High missing data detected - consider imputation or collection improvements" ) elif overall_missing_pct > 10: recommendations.append( "Moderate missing data - review imputation strategies" ) - # Column-specific recommendations - if column_profiles: - high_missing_cols = [ - col - for col, profile in column_profiles.items() - if profile.get("null_percentage", 0) > 50 - ] - - if high_missing_cols: - if len(high_missing_cols) == 1: - recommendations.append( - f"Column '{high_missing_cols[0]}' has excessive missing data - consider removal or targeted collection" - ) - else: - recommendations.append( - f"{len(high_missing_cols)} columns have >50% missing data - review data collection process" - ) - - # Pipeline-specific recommendations - if quality_score >= min_threshold and len(critical_fixes) == 0: - recommendations.append( - "Data quality is acceptable for downstream processing" - ) - recommendations.append( - "Consider implementing monitoring for quality regression" - ) - elif quality_score >= min_threshold * 0.8: # Close to passing - recommendations.append( - "Data quality is borderline - implement fixes and re-evaluate" + # Pipeline recommendations + if quality_score >= min_threshold and not critical_fixes: + recommendations.extend( + [ + "Data quality acceptable for downstream processing", + "Consider implementing quality monitoring", + ] ) + elif score_gap <= min_threshold * 0.2: # Close to passing recommendations.append( - "Consider A/B testing with and without quality improvements" + "Data quality borderline - implement fixes and re-evaluate" ) else: recommendations.append( - "Significant data quality issues require attention before production use" - ) - recommendations.append( - "Consider data pipeline improvements or alternative data sources" + "Significant quality issues require attention before production use" ) return recommendations @@ -331,11 +256,7 @@ def evaluate_quality_gate_with_routing( Annotated[QualityGateDecision, "quality_gate_decision"], Annotated[str, "routing_message"], ]: - """Combined quality gate evaluation and routing decision. - - Returns both the detailed quality decision and routing message. - """ - # Use existing logic to evaluate quality + """Combined quality gate evaluation and routing decision.""" decision = evaluate_quality_gate( report_json, min_quality_score, @@ -345,14 +266,13 @@ def evaluate_quality_gate_with_routing( target_column, ) - # Generate routing message - if decision.passed: - routing_message = ( - "๐Ÿš€ Data quality passed - proceed to downstream processing" - ) - logger.info(routing_message) - else: - routing_message = "๐Ÿ›‘ Data quality insufficient - review and improve data before proceeding" - logger.warning(routing_message) + routing_message = ( + "๐Ÿš€ Data quality passed - proceed to downstream processing" + if decision.passed + else "๐Ÿ›‘ Data quality insufficient - review and improve data before proceeding" + ) + logger.info(routing_message) if decision.passed else logger.warning( + routing_message + ) return decision, routing_message From f7477388b1327405aebf45a741fa3182fba7dbc9 Mon Sep 17 00:00:00 2001 From: Hamza Tahir Date: Fri, 22 Aug 2025 09:29:14 +0200 Subject: [PATCH 04/14] Update pipeline with prompt experimentation for Pydantic AI agents --- .../pydantic_ai_eda/pipelines/__init__.py | 9 +- .../pipelines/prompt_experiment_pipeline.py | 55 ++ .../pydantic_ai_eda/run_prompt_experiment.py | 111 +++ examples/pydantic_ai_eda/steps/__init__.py | 3 + .../steps/prompt_experiment.py | 635 ++++++++++++++++++ examples/pydantic_ai_eda/test_cases.json | 47 ++ 6 files changed, 857 insertions(+), 3 deletions(-) create mode 100644 examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py create mode 100644 examples/pydantic_ai_eda/run_prompt_experiment.py create mode 100644 examples/pydantic_ai_eda/steps/prompt_experiment.py create mode 100644 examples/pydantic_ai_eda/test_cases.json diff --git a/examples/pydantic_ai_eda/pipelines/__init__.py b/examples/pydantic_ai_eda/pipelines/__init__.py index 8806b28c264..11952e656bc 100644 --- a/examples/pydantic_ai_eda/pipelines/__init__.py +++ b/examples/pydantic_ai_eda/pipelines/__init__.py @@ -1,12 +1,15 @@ -"""ZenML pipeline for Pydantic AI EDA workflow. +"""ZenML pipelines for Pydantic AI EDA workflows. -This module contains the pipeline definition for the EDA workflow: +This module contains pipeline definitions: -- eda_pipeline.py: Complete EDA pipeline with AI analysis and quality gates +- eda_pipeline.py: Complete EDA pipeline with AI analysis and quality gates +- prompt_experiment_pipeline.py: A/B testing pipeline for agent prompt optimization """ from .eda_pipeline import eda_pipeline +from .prompt_experiment_pipeline import prompt_experiment_pipeline __all__ = [ "eda_pipeline", + "prompt_experiment_pipeline", ] \ No newline at end of file diff --git a/examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py b/examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py new file mode 100644 index 00000000000..dc1ba0c0b76 --- /dev/null +++ b/examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py @@ -0,0 +1,55 @@ +"""Pipeline for experimenting with Pydantic AI agent prompts.""" + +from typing import Any, Dict, List, Optional + +from models import AgentConfig, DataSourceConfig +from steps import compare_agent_prompts, ingest_data + +from zenml import pipeline +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@pipeline +def prompt_experiment_pipeline( + source_config: DataSourceConfig, + prompt_variants: List[str], + agent_config: Optional[AgentConfig] = None, +) -> Dict[str, Any]: + """Pipeline for A/B testing Pydantic AI agent prompts. + + This pipeline helps developers optimize their agent prompts by testing + multiple variants on the same dataset and comparing performance metrics. + + Args: + source_config: Data source to test prompts against + prompt_variants: List of system prompts to compare + agent_config: Configuration for agent behavior during testing + + Returns: + Comprehensive comparison results with recommendations + """ + logger.info(f"๐Ÿงช Starting prompt experiment with {len(prompt_variants)} variants") + + # Step 1: Load the test dataset + dataset_df, ingestion_metadata = ingest_data(source_config=source_config) + + # Step 2: Run prompt comparison experiment + experiment_results = compare_agent_prompts( + dataset_df=dataset_df, + prompt_variants=prompt_variants, + agent_config=agent_config, + ) + + logger.info("โœ… Prompt experiment completed - check results for best performing variant") + + return { + "experiment_results": experiment_results, + "dataset_metadata": ingestion_metadata, + "test_config": { + "source": f"{source_config.source_type}:{source_config.source_path}", + "prompt_count": len(prompt_variants), + "agent_config": agent_config.model_dump() if agent_config else None, + }, + } \ No newline at end of file diff --git a/examples/pydantic_ai_eda/run_prompt_experiment.py b/examples/pydantic_ai_eda/run_prompt_experiment.py new file mode 100644 index 00000000000..2e40eaf2d15 --- /dev/null +++ b/examples/pydantic_ai_eda/run_prompt_experiment.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +"""Run prompt experimentation pipeline to optimize Pydantic AI agents.""" + +import os +from models import AgentConfig, DataSourceConfig +from pipelines.prompt_experiment_pipeline import prompt_experiment_pipeline + + +def main(): + """Run prompt A/B testing to find the best system prompt.""" + print("๐Ÿงช Pydantic AI Prompt Experimentation") + print("=" * 40) + + # Check for API keys + has_openai = bool(os.getenv("OPENAI_API_KEY")) + has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY")) + + if not (has_openai or has_anthropic): + print("โŒ No API keys found!") + print("Set OPENAI_API_KEY or ANTHROPIC_API_KEY environment variable") + return + + model_name = "gpt-4o-mini" if has_openai else "claude-3-haiku-20240307" + print(f"๐Ÿค– Using model: {model_name}") + + # Dataset configuration + source_config = DataSourceConfig( + source_type="hf", + source_path="scikit-learn/iris", + target_column="target", + ) + + agent_config = AgentConfig( + model_name=model_name, + max_tool_calls=4, # Reduce for faster testing + timeout_seconds=30, # Shorter timeout to avoid stalling + ) + + # Define prompt variants to test (simplified for speed) + prompt_variants = [ + # Variant 1: Concise + """You are a data analyst. Analyze the dataset quickly - focus on data quality score and key findings. Be concise.""", + + # Variant 2: Quality-focused + """You are a data quality specialist. Calculate data quality score, identify missing data and duplicates. Provide specific recommendations.""", + + # Variant 3: Business-oriented + """You are a business analyst. Is this data ready for ML? Provide go/no-go recommendation with quality score and business impact.""" + ] + + print(f"๐Ÿ“Š Testing {len(prompt_variants)} prompt variants on: {source_config.source_path}") + print("This will help identify the best performing prompt for your use case.\n") + + try: + pipeline_run = prompt_experiment_pipeline( + source_config=source_config, + prompt_variants=prompt_variants, + agent_config=agent_config, + ) + + # Extract results from ZenML pipeline artifacts + print("๐Ÿ“ˆ EXPERIMENT RESULTS") + print("=" * 25) + print("โœ… Pipeline completed successfully!") + + # Get the artifact from the pipeline run + run_metadata = pipeline_run.dict() + print(f"๐Ÿ” Pipeline run ID: {pipeline_run.id}") + print(f"๐Ÿ“Š Check ZenML dashboard for detailed experiment results") + print(f"๐Ÿ† Results are stored as pipeline artifacts") + + # Try to access the step outputs + try: + step_names = list(pipeline_run.steps.keys()) + print(f"๐Ÿ“‹ Pipeline steps: {step_names}") + + if "compare_agent_prompts" in step_names: + step_output = pipeline_run.steps["compare_agent_prompts"] + print(f"๐ŸŽฏ Experiment data available in step outputs") + + # Try to load the actual results + outputs = step_output.outputs + if "prompt_comparison_results" in outputs: + experiment_data = outputs["prompt_comparison_results"].load() + summary = experiment_data["experiment_summary"] + + print(f"โœ… Successful runs: {summary['successful_runs']}/{summary['total_prompts_tested']}") + print(f"๐Ÿ† Best prompt: {summary['best_prompt_variant']}") + print(f"โฑ๏ธ Average time: {summary['avg_execution_time']}s") + + print("\n๐Ÿ’ก RECOMMENDATIONS:") + for rec in experiment_data["recommendations"]: + print(f" โ€ข {rec}") + + except Exception as e: + print(f"โš ๏ธ Could not extract detailed results: {e}") + print("Check ZenML dashboard for full experiment analysis") + + print(f"\nโœ… Prompt experiment completed! Check ZenML dashboard for detailed results.") + return pipeline_run + + except Exception as e: + print(f"โŒ Experiment failed: {e}") + print("\nTroubleshooting:") + print("- Check your API key is valid") + print("- Ensure ZenML is initialized: zenml init") + print("- Install requirements: pip install -r requirements.txt") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/pydantic_ai_eda/steps/__init__.py b/examples/pydantic_ai_eda/steps/__init__.py index f90b91ef614..7be92fe9ac5 100644 --- a/examples/pydantic_ai_eda/steps/__init__.py +++ b/examples/pydantic_ai_eda/steps/__init__.py @@ -11,11 +11,14 @@ from .eda_agent import run_eda_agent from .ingest import ingest_data +from .prompt_experiment import compare_agent_prompts, evaluate_prompts_with_test_cases from .quality_gate import evaluate_quality_gate, evaluate_quality_gate_with_routing __all__ = [ "ingest_data", "run_eda_agent", + "compare_agent_prompts", + "evaluate_prompts_with_test_cases", "evaluate_quality_gate", "evaluate_quality_gate_with_routing", ] \ No newline at end of file diff --git a/examples/pydantic_ai_eda/steps/prompt_experiment.py b/examples/pydantic_ai_eda/steps/prompt_experiment.py new file mode 100644 index 00000000000..a4aeee5c267 --- /dev/null +++ b/examples/pydantic_ai_eda/steps/prompt_experiment.py @@ -0,0 +1,635 @@ +"""Advanced prompt experimentation step for Pydantic AI agent development.""" + +import json +import time +from pathlib import Path +from typing import Annotated, Any, Dict, List, Optional + +import pandas as pd +from pydantic_ai import Agent +from pydantic_ai.settings import ModelSettings + +from zenml import step +from zenml.logger import get_logger + +from models import AgentConfig, EDAReport +from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps + +logger = get_logger(__name__) + + +@step +def evaluate_prompts_with_test_cases( + dataset_df: pd.DataFrame, + prompt_variants: List[str], + test_cases_path: str = "test_cases.json", + agent_config: AgentConfig = None, + use_llm_judge: bool = True, +) -> Annotated[Dict[str, Any], "comprehensive_prompt_evaluation"]: + """Advanced prompt evaluation using structured test cases and LLM judge. + + This step implements best practices from AI evaluation methodology: + - Structured test cases with categories + - LLM judge evaluation for quality assessment + - Tool usage tracking and analysis + - Statistical significance testing + + Args: + dataset_df: Dataset to test prompts against + prompt_variants: List of system prompts to compare + test_cases_path: Path to JSON file with structured test cases + agent_config: Base configuration for all agents + use_llm_judge: Whether to use LLM for response quality evaluation + + Returns: + Comprehensive evaluation results with quality scores and recommendations + """ + if agent_config is None: + agent_config = AgentConfig() + + # Load structured test cases + test_cases = _load_test_cases(test_cases_path) + if not test_cases: + # Fallback to simple evaluation if no test cases + return compare_agent_prompts(dataset_df, prompt_variants, agent_config) + + logger.info(f"๐Ÿงช Running comprehensive evaluation: {len(prompt_variants)} prompts ร— {len(test_cases)} test cases") + + # Initialize LLM judge if requested + llm_judge = None + if use_llm_judge: + llm_judge = _create_llm_judge(agent_config.model_name) + + all_results = [] + + # Test each prompt variant + for prompt_idx, system_prompt in enumerate(prompt_variants): + prompt_id = f"variant_{prompt_idx + 1}" + logger.info(f"Testing {prompt_id}/{len(prompt_variants)}") + + prompt_results = [] + + # Run each test case for this prompt + for test_case in test_cases: + case_result = _run_single_test_case( + dataset_df, system_prompt, test_case, agent_config, llm_judge, prompt_id + ) + prompt_results.append(case_result) + + # Aggregate results for this prompt + prompt_summary = _analyze_prompt_performance(prompt_results, prompt_id, system_prompt) + all_results.append(prompt_summary) + + # Generate comprehensive comparison + final_analysis = _generate_comprehensive_analysis(all_results, test_cases) + + return final_analysis + + +@step +def compare_agent_prompts( + dataset_df: pd.DataFrame, + prompt_variants: List[str], + agent_config: AgentConfig = None, +) -> Annotated[Dict[str, Any], "prompt_comparison_results"]: + """Test multiple system prompts on the same dataset for agent optimization. + + This step helps developers iterate on agent prompts by running A/B tests + and comparing quality, performance, and output characteristics. + + Args: + dataset_df: Dataset to test prompts against + prompt_variants: List of system prompts to compare + agent_config: Base configuration for all agents + + Returns: + Comparison report with metrics for each prompt variant + """ + if agent_config is None: + agent_config = AgentConfig() + + logger.info(f"๐Ÿงช Testing {len(prompt_variants)} prompt variants on dataset with {len(dataset_df)} rows") + + results = [] + + for i, system_prompt in enumerate(prompt_variants): + prompt_id = f"variant_{i+1}" + logger.info(f"๐Ÿ”„ Testing prompt variant {i+1}/{len(prompt_variants)}") + print(f"๐Ÿ”„ Testing prompt variant {i+1}/{len(prompt_variants)}") + + start_time = time.time() + + try: + # Initialize fresh dependencies for each test + deps = AnalystAgentDeps() + main_ref = deps.store(dataset_df) + + print(f" ๐Ÿค– Creating agent for variant {i+1}...") + # Create agent with this prompt variant + test_agent = Agent( + f"openai:{agent_config.model_name}", + deps_type=AnalystAgentDeps, + output_type=EDAReport, + output_retries=3, + system_prompt=system_prompt, + model_settings=ModelSettings(parallel_tool_calls=False), # Disable parallel to avoid hanging + ) + + # Register tools + print(f" ๐Ÿ”ง Registering {len(AGENT_TOOLS)} tools...") + for tool in AGENT_TOOLS: + test_agent.tool(tool) + + # Run analysis with consistent user prompt + user_prompt = f"""Analyze dataset '{main_ref}' ({dataset_df.shape[0]} rows, {dataset_df.shape[1]} cols). + +Focus on data quality, key patterns, and actionable insights.""" + + print(f" โšก Running analysis for variant {i+1}...") + result = test_agent.run_sync(user_prompt, deps=deps) + eda_report = result.output + + execution_time = time.time() - start_time + + # Collect metrics for comparison + result_metrics = { + "prompt_id": prompt_id, + "system_prompt": system_prompt, + "success": True, + "execution_time_seconds": round(execution_time, 2), + "tool_calls_made": len(deps.query_history), + "data_quality_score": eda_report.data_quality_score, + "key_findings_count": len(eda_report.key_findings), + "risks_identified": len(eda_report.risks), + "fixes_suggested": len(eda_report.fixes), + "correlation_insights": len(eda_report.correlation_insights), + "headline_length": len(eda_report.headline), + "markdown_length": len(eda_report.markdown), + "tables_generated": len(deps.output) - 1, # Exclude main dataset + "error": None, + } + + logger.info(f"โœ… Variant {i+1}: Score={eda_report.data_quality_score:.1f}, Tools={len(deps.query_history)}, Time={execution_time:.1f}s") + + except Exception as e: + execution_time = time.time() - start_time + logger.warning(f"โŒ Variant {i+1} failed: {str(e)}") + + result_metrics = { + "prompt_id": prompt_id, + "system_prompt": system_prompt, + "success": False, + "execution_time_seconds": round(execution_time, 2), + "tool_calls_made": 0, + "data_quality_score": 0.0, + "key_findings_count": 0, + "risks_identified": 0, + "fixes_suggested": 0, + "correlation_insights": 0, + "headline_length": 0, + "markdown_length": 0, + "tables_generated": 0, + "error": str(e), + } + + results.append(result_metrics) + + # Analyze results and determine best variant + successful_results = [r for r in results if r["success"]] + + if successful_results: + # Rank by composite score (quality + speed + thoroughness) + for result in successful_results: + # Composite score: 60% quality + 20% speed + 20% thoroughness + speed_score = max(0, 100 - result["execution_time_seconds"] * 10) # Penalty for slowness + thoroughness_score = (result["key_findings_count"] * 10 + + result["risks_identified"] * 5 + + result["tables_generated"] * 5) + + result["composite_score"] = ( + result["data_quality_score"] * 0.6 + + min(speed_score, 100) * 0.2 + + min(thoroughness_score, 100) * 0.2 + ) + + # Find best performer + best_result = max(successful_results, key=lambda x: x["composite_score"]) + best_prompt_id = best_result["prompt_id"] + + logger.info(f"๐Ÿ† Best performing prompt: {best_prompt_id} (score: {best_result['composite_score']:.1f})") + else: + best_prompt_id = "none" + logger.warning("โŒ No prompts succeeded") + + # Generate summary comparison + summary = { + "experiment_summary": { + "total_prompts_tested": len(prompt_variants), + "successful_runs": len(successful_results), + "failed_runs": len(results) - len(successful_results), + "best_prompt_variant": best_prompt_id, + "avg_execution_time": round(sum(r["execution_time_seconds"] for r in results) / len(results), 2) if results else 0, + "dataset_info": { + "rows": len(dataset_df), + "columns": len(dataset_df.columns), + "column_names": list(dataset_df.columns), + } + }, + "detailed_results": results, + "recommendations": _generate_prompt_recommendations(results), + } + + return summary + + +def _generate_prompt_recommendations(results: List[Dict[str, Any]]) -> List[str]: + """Generate recommendations based on prompt experiment results.""" + recommendations = [] + + successful_results = [r for r in results if r["success"]] + + if not successful_results: + return ["All prompts failed - check agent configuration and error messages"] + + # Performance analysis + avg_time = sum(r["execution_time_seconds"] for r in successful_results) / len(successful_results) + avg_quality = sum(r["data_quality_score"] for r in successful_results) / len(successful_results) + + if avg_time > 30: + recommendations.append("Consider shorter, more focused prompts to reduce execution time") + + if avg_quality < 70: + recommendations.append("Consider more specific instructions for data quality assessment") + + # Consistency analysis + quality_scores = [r["data_quality_score"] for r in successful_results] + quality_variance = max(quality_scores) - min(quality_scores) + + if quality_variance > 20: + recommendations.append("High variance in quality scores - consider more consistent prompt structure") + else: + recommendations.append("Quality scores are consistent across prompts - good stability") + + # Tool usage analysis + tool_calls = [r["tool_calls_made"] for r in successful_results] + if max(tool_calls) - min(tool_calls) > 3: + recommendations.append("Variable tool usage detected - consider standardizing analysis steps") + + # Success rate analysis + success_rate = len(successful_results) / len(results) * 100 + if success_rate < 80: + recommendations.append("Low success rate - review prompt complexity and error handling") + else: + recommendations.append(f"Good success rate ({success_rate:.0f}%) - prompts are robust") + + return recommendations + + +def _load_test_cases(test_cases_path: str) -> List[Dict[str, Any]]: + """Load structured test cases from JSON file.""" + try: + test_cases_file = Path(test_cases_path) + if not test_cases_file.exists(): + logger.warning(f"Test cases file not found: {test_cases_path}") + return [] + + with open(test_cases_file, 'r') as f: + test_cases = json.load(f) + + logger.info(f"Loaded {len(test_cases)} test cases from {test_cases_path}") + return test_cases + + except Exception as e: + logger.error(f"Failed to load test cases: {e}") + return [] + + +def _create_llm_judge(base_model: str) -> Optional[Agent]: + """Create an LLM judge agent for evaluating responses.""" + try: + # Use a stronger model for evaluation if available + judge_model = "gpt-4o" if "gpt" in base_model else "claude-3-5-sonnet-20241022" + + judge_system_prompt = """You are an expert AI evaluator specializing in data analysis responses. + +Your job is to assess EDA (Exploratory Data Analysis) responses based on: + +1. **Accuracy** (1-5): Factual correctness and valid statistical insights +2. **Relevance** (1-5): How well the response addresses the specific query +3. **Completeness** (1-5): Whether all aspects of the query are covered +4. **Tool Usage** (1-5): Appropriate use of available analysis tools +5. **Actionability** (1-5): Quality of recommendations and insights + +Score each criterion from 1 (poor) to 5 (excellent). +Provide scores in JSON format: {"accuracy": X, "relevance": X, "completeness": X, "tool_usage": X, "actionability": X, "overall": X, "reasoning": "brief explanation"} + +Be objective and consistent in your evaluations.""" + + return Agent( + f"openai:{judge_model}" if "gpt" in base_model else f"anthropic:{judge_model}", + system_prompt=judge_system_prompt, + ) + + except Exception as e: + logger.warning(f"Failed to create LLM judge: {e}") + return None + + +def _run_single_test_case( + dataset_df: pd.DataFrame, + system_prompt: str, + test_case: Dict[str, Any], + agent_config: AgentConfig, + llm_judge: Optional[Agent], + prompt_id: str, +) -> Dict[str, Any]: + """Run a single test case and collect comprehensive metrics.""" + start_time = time.time() + + try: + # Initialize agent for this test + deps = AnalystAgentDeps() + main_ref = deps.store(dataset_df) + + test_agent = Agent( + f"openai:{agent_config.model_name}", + deps_type=AnalystAgentDeps, + output_type=EDAReport, + output_retries=3, + system_prompt=system_prompt, + model_settings=ModelSettings(parallel_tool_calls=True), + ) + + # Register tools + for tool in AGENT_TOOLS: + test_agent.tool(tool) + + # Run the test case query + full_query = f"Dataset reference: {main_ref}\n\n{test_case['query']}" + result = test_agent.run_sync(full_query, deps=deps) + eda_report = result.output + + execution_time = time.time() - start_time + + # Collect basic metrics + case_result = { + "test_id": test_case["id"], + "category": test_case.get("category", "general"), + "prompt_id": prompt_id, + "query": test_case["query"], + "success": True, + "execution_time": execution_time, + "tool_calls_made": len(deps.query_history), + "response": str(eda_report.markdown), + "data_quality_score": eda_report.data_quality_score, + "findings_count": len(eda_report.key_findings), + "risks_count": len(eda_report.risks), + "recommendations_count": len(eda_report.fixes), + "error": None, + } + + # Evaluate with LLM judge if available + if llm_judge: + judge_evaluation = _get_llm_judge_scores(llm_judge, test_case, case_result) + case_result.update(judge_evaluation) + + # Check against expected metrics if available + expected_metrics = test_case.get("expected_metrics", {}) + case_result["meets_expectations"] = _check_expectations(case_result, expected_metrics) + + return case_result + + except Exception as e: + execution_time = time.time() - start_time + return { + "test_id": test_case["id"], + "category": test_case.get("category", "general"), + "prompt_id": prompt_id, + "query": test_case["query"], + "success": False, + "execution_time": execution_time, + "error": str(e), + "meets_expectations": False, + } + + +def _get_llm_judge_scores(llm_judge: Agent, test_case: Dict, case_result: Dict) -> Dict: + """Get quality scores from LLM judge.""" + try: + eval_prompt = f""" +Query: {test_case['query']} +Category: {test_case.get('category', 'general')} + +Response to evaluate: +{case_result['response'][:2000]}... + +Data Quality Score Provided: {case_result['data_quality_score']} +Tool Calls Made: {case_result['tool_calls_made']} +Findings Count: {case_result['findings_count']} + +Please evaluate this EDA response and provide scores in the requested JSON format.""" + + judge_response = llm_judge.run_sync(eval_prompt) + + # Parse JSON response + try: + scores = json.loads(str(judge_response.output)) + return { + "judge_accuracy": scores.get("accuracy", 0), + "judge_relevance": scores.get("relevance", 0), + "judge_completeness": scores.get("completeness", 0), + "judge_tool_usage": scores.get("tool_usage", 0), + "judge_actionability": scores.get("actionability", 0), + "judge_overall": scores.get("overall", 0), + "judge_reasoning": scores.get("reasoning", ""), + } + except json.JSONDecodeError: + logger.warning("LLM judge response was not valid JSON") + return {"judge_overall": 3, "judge_reasoning": "Failed to parse judge response"} + + except Exception as e: + logger.warning(f"LLM judge evaluation failed: {e}") + return {"judge_overall": 3, "judge_reasoning": f"Evaluation error: {e}"} + + +def _check_expectations(case_result: Dict, expected_metrics: Dict) -> bool: + """Check if results meet expected criteria.""" + if not expected_metrics: + return True + + checks = [] + + # Check minimum quality score + min_quality = expected_metrics.get("quality_score_min") + if min_quality: + checks.append(case_result.get("data_quality_score", 0) >= min_quality) + + # Check minimum recommendations + min_recs = expected_metrics.get("recommendations_min") + if min_recs: + checks.append(case_result.get("recommendations_count", 0) >= min_recs) + + # Check expected tool calls + expected_tools = expected_metrics.get("tool_calls_expected", []) + if expected_tools: + # This would need actual tool tracking - simplified for now + checks.append(case_result.get("tool_calls_made", 0) > 0) + + return all(checks) if checks else True + + +def _analyze_prompt_performance(results: List[Dict], prompt_id: str, system_prompt: str) -> Dict: + """Analyze performance across all test cases for a single prompt.""" + successful_results = [r for r in results if r["success"]] + + if not successful_results: + return { + "prompt_id": prompt_id, + "system_prompt": system_prompt, + "success_rate": 0, + "avg_scores": {}, + "category_performance": {}, + "overall_rating": "failed" + } + + # Calculate averages + avg_scores = { + "execution_time": sum(r["execution_time"] for r in successful_results) / len(successful_results), + "tool_calls": sum(r["tool_calls_made"] for r in successful_results) / len(successful_results), + "data_quality_score": sum(r["data_quality_score"] for r in successful_results) / len(successful_results), + "findings_count": sum(r["findings_count"] for r in successful_results) / len(successful_results), + } + + # Add LLM judge scores if available + judge_scores = [r for r in successful_results if "judge_overall" in r] + if judge_scores: + avg_scores.update({ + "judge_accuracy": sum(r["judge_accuracy"] for r in judge_scores) / len(judge_scores), + "judge_relevance": sum(r["judge_relevance"] for r in judge_scores) / len(judge_scores), + "judge_completeness": sum(r["judge_completeness"] for r in judge_scores) / len(judge_scores), + "judge_overall": sum(r["judge_overall"] for r in judge_scores) / len(judge_scores), + }) + + # Category-wise performance + categories = {} + for result in successful_results: + cat = result["category"] + if cat not in categories: + categories[cat] = {"count": 0, "avg_score": 0, "meets_expectations": 0} + categories[cat]["count"] += 1 + categories[cat]["avg_score"] += result["data_quality_score"] + if result.get("meets_expectations", False): + categories[cat]["meets_expectations"] += 1 + + for cat in categories: + categories[cat]["avg_score"] /= categories[cat]["count"] + categories[cat]["success_rate"] = categories[cat]["meets_expectations"] / categories[cat]["count"] + + # Overall rating + success_rate = len(successful_results) / len(results) + avg_judge_score = avg_scores.get("judge_overall", avg_scores.get("data_quality_score", 50) / 20) + + if success_rate >= 0.8 and avg_judge_score >= 4: + rating = "excellent" + elif success_rate >= 0.6 and avg_judge_score >= 3: + rating = "good" + elif success_rate >= 0.4: + rating = "acceptable" + else: + rating = "poor" + + return { + "prompt_id": prompt_id, + "system_prompt": system_prompt, + "success_rate": success_rate, + "avg_scores": avg_scores, + "category_performance": categories, + "overall_rating": rating, + "detailed_results": results, + } + + +def _generate_comprehensive_analysis(all_results: List[Dict], test_cases: List[Dict]) -> Dict: + """Generate final comprehensive analysis of all prompt variants.""" + # Rank prompts by overall performance + ranked_prompts = sorted(all_results, key=lambda x: ( + x["success_rate"] * 0.4 + + x["avg_scores"].get("judge_overall", 3) * 0.3 + + (100 - x["avg_scores"]["execution_time"]) / 100 * 0.2 + + x["avg_scores"]["data_quality_score"] / 100 * 0.1 + ), reverse=True) + + best_prompt = ranked_prompts[0] if ranked_prompts else None + + # Category analysis + category_insights = {} + for test_case in test_cases: + cat = test_case.get("category", "general") + if cat not in category_insights: + category_insights[cat] = {"prompt_performance": []} + + for prompt_result in all_results: + cat_perf = prompt_result["category_performance"].get(cat, {}) + if cat_perf: + category_insights[cat]["prompt_performance"].append({ + "prompt_id": prompt_result["prompt_id"], + "avg_score": cat_perf["avg_score"], + "success_rate": cat_perf["success_rate"] + }) + + return { + "evaluation_summary": { + "total_prompts_tested": len(all_results), + "total_test_cases": len(test_cases), + "best_prompt": best_prompt["prompt_id"] if best_prompt else "none", + "best_prompt_rating": best_prompt["overall_rating"] if best_prompt else "none", + "categories_tested": list(category_insights.keys()), + }, + "prompt_rankings": ranked_prompts, + "category_analysis": category_insights, + "detailed_results": all_results, + "recommendations": _generate_advanced_recommendations(all_results, category_insights), + } + + +def _generate_advanced_recommendations(all_results: List[Dict], category_insights: Dict) -> List[str]: + """Generate advanced recommendations based on comprehensive analysis.""" + recommendations = [] + + if not all_results: + return ["No successful evaluations - check agent configuration"] + + # Success rate analysis + avg_success_rate = sum(r["success_rate"] for r in all_results) / len(all_results) + if avg_success_rate < 0.6: + recommendations.append("Low overall success rate - consider simplifying prompts or checking tool integration") + + # Performance consistency + execution_times = [r["avg_scores"]["execution_time"] for r in all_results] + time_variance = max(execution_times) - min(execution_times) + if time_variance > 30: + recommendations.append("High variance in execution time - optimize slower prompts for efficiency") + + # Quality assessment + if any("judge_overall" in r["avg_scores"] for r in all_results): + judge_scores = [r["avg_scores"]["judge_overall"] for r in all_results if "judge_overall" in r["avg_scores"]] + avg_judge_score = sum(judge_scores) / len(judge_scores) + + if avg_judge_score < 3: + recommendations.append("LLM judge scores are low - review prompt clarity and specificity") + elif avg_judge_score > 4: + recommendations.append("Excellent LLM judge scores - prompts are producing high-quality responses") + + # Category-specific insights + for category, data in category_insights.items(): + if data["prompt_performance"]: + cat_scores = [p["avg_score"] for p in data["prompt_performance"]] + if min(cat_scores) < 60: + recommendations.append(f"'{category}' category shows low scores - consider specialized prompts for this use case") + + # Best practices + best_prompt = max(all_results, key=lambda x: x["success_rate"]) + if best_prompt["success_rate"] > 0.8: + recommendations.append(f"'{best_prompt['prompt_id']}' shows strong performance - consider it as your baseline") + + return recommendations \ No newline at end of file diff --git a/examples/pydantic_ai_eda/test_cases.json b/examples/pydantic_ai_eda/test_cases.json new file mode 100644 index 00000000000..27fae0b4c3b --- /dev/null +++ b/examples/pydantic_ai_eda/test_cases.json @@ -0,0 +1,47 @@ +[ + { + "id": "iris-quality-1", + "query": "Analyze the iris dataset for data quality issues, missing values, and overall readiness for machine learning.", + "category": "data_quality", + "expected_metrics": { + "should_identify": ["no missing values", "clean dataset", "high quality score"], + "quality_score_min": 85 + } + }, + { + "id": "iris-distribution-1", + "query": "Examine the distribution of features in this dataset and identify any patterns or anomalies.", + "category": "distribution_analysis", + "expected_metrics": { + "should_identify": ["sepal/petal measurements", "species distribution", "correlation patterns"], + "tool_calls_expected": ["describe", "analyze_correlations"] + } + }, + { + "id": "iris-ml-readiness-1", + "query": "Is this dataset ready for training a classification model? What preprocessing steps are needed?", + "category": "ml_readiness", + "expected_metrics": { + "should_identify": ["classification target", "feature scaling", "data preparation"], + "quality_score_min": 80 + } + }, + { + "id": "iris-edge-case-1", + "query": "Find any outliers, duplicates, or data inconsistencies that could impact model performance.", + "category": "edge_cases", + "expected_metrics": { + "should_identify": ["outlier analysis", "duplicate check", "consistency validation"], + "tool_calls_expected": ["run_sql"] + } + }, + { + "id": "iris-business-1", + "query": "From a business perspective, what insights can you extract from this data and what are the key recommendations?", + "category": "business_insights", + "expected_metrics": { + "should_identify": ["actionable insights", "business recommendations", "data value assessment"], + "recommendations_min": 2 + } + } +] \ No newline at end of file From 008bd178c4d46af6ba23e4ffe972e43b090e41b6 Mon Sep 17 00:00:00 2001 From: Hamza Tahir Date: Fri, 22 Aug 2025 10:52:00 +0200 Subject: [PATCH 05/14] Update Pydantic AI EDA pipeline README --- examples/pydantic_ai_eda/README.md | 343 ++---------- .../pipelines/prompt_experiment_pipeline.py | 30 +- .../pydantic_ai_eda/run_prompt_experiment.py | 57 +- .../steps/prompt_experiment.py | 501 +++++++++++------- 4 files changed, 407 insertions(+), 524 deletions(-) diff --git a/examples/pydantic_ai_eda/README.md b/examples/pydantic_ai_eda/README.md index 824be3fce23..59835b1f02a 100644 --- a/examples/pydantic_ai_eda/README.md +++ b/examples/pydantic_ai_eda/README.md @@ -1,343 +1,80 @@ # Pydantic AI EDA Pipeline -This example demonstrates how to build an AI-powered Exploratory Data Analysis (EDA) pipeline using **ZenML** and **Pydantic AI**. The pipeline automatically analyzes datasets, generates comprehensive reports, and makes data quality decisions for downstream processing. +AI-powered Exploratory Data Analysis pipeline using **ZenML** and **Pydantic AI**. Automatically analyzes datasets, generates reports, and makes quality decisions for downstream processing. ## Architecture ``` -ingest โ†’ eda_agent โ†’ quality_gate โ†’ routing +ingest_data โ†’ run_eda_agent โ†’ evaluate_quality_gate_with_routing ``` -## Key Features +## Features -- **๐Ÿค– AI-Powered Analysis**: Uses Pydantic AI with GPT-4 or Claude for intelligent data exploration -- **๐Ÿ“Š SQL-Based EDA**: Agent performs analysis through DuckDB SQL queries with safety guards -- **โœ… Quality Gates**: Automated data quality assessment with configurable thresholds -- **๐ŸŒ Multiple Data Sources**: Support for HuggingFace, local files, and data warehouses -- **๐Ÿ“ˆ Comprehensive Reporting**: Structured JSON reports and human-readable markdown - -## What's Included - -### Pipeline Steps -- **`ingest_data`**: Load data from HuggingFace, local files, or warehouses -- **`run_eda_agent`**: AI agent performs comprehensive EDA using SQL analysis -- **`evaluate_quality_gate`**: Assess data quality against configurable thresholds -- **`route_based_on_quality`**: Make pipeline routing decisions based on quality - -### AI Agent Capabilities -- Statistical analysis and profiling -- Missing data pattern detection -- Correlation analysis -- Outlier identification -- Data quality scoring (0-100) -- Actionable remediation recommendations -- SQL query logging for reproducibility - -### CLI Interface -- **Command-line Runner**: Easy execution with various configuration options -- **Quality Assessment**: Quick quality checks without full analysis -- **Multiple Output Formats**: JSON, CSV, and text reporting +- ๐Ÿค– **AI-Powered Analysis** with GPT-4/Claude +- ๐Ÿ“Š **SQL-Based EDA** through DuckDB with safety guards +- โœ… **Quality Gates** with configurable thresholds +- ๐ŸŒ **Multiple Data Sources** (HuggingFace, local files, warehouses) +- ๐Ÿ“ˆ **Comprehensive Reporting** (JSON/markdown) ## Quick Start -### Prerequisites - -```bash -pip install "zenml[server]" -zenml init -``` - -### Install Dependencies - -```bash -git clone https://github.com/zenml-io/zenml.git -cd zenml/examples/pydantic_ai_eda -pip install -r requirements.txt -``` - -### Set API Keys - -```bash -# For OpenAI (recommended) -export OPENAI_API_KEY="your-openai-key" - -# Or for Anthropic -export ANTHROPIC_API_KEY="your-anthropic-key" -``` - -### Quick Example - ```bash -# Run simple example -python example.py -``` - -### CLI Usage - -```bash -# Analyze HuggingFace dataset -python run_pipeline.py --source-type hf --source-path "scikit-learn/adult-census-income" --target-column "class" +# Install +pip install "zenml[server]" && zenml init +cd zenml/examples/pydantic_ai_eda && pip install -r requirements.txt -# Analyze local file -python run_pipeline.py --source-type local --source-path "/path/to/data.csv" --target-column "target" +# Set API key +export OPENAI_API_KEY="your-openai-key" # or ANTHROPIC_API_KEY -# Quality-only assessment -python run_quality_check.py --source-path "/path/to/data.csv" --min-quality-score 80 +# Run examples +python run.py +python run_prompt_experiment.py ``` -## Example Usage - -### Python API +## Usage ```python from models import DataSourceConfig, AgentConfig from pipelines.eda_pipeline import eda_pipeline -# Configure data source -source_config = DataSourceConfig( - source_type="hf", - source_path="scikit-learn/adult-census-income", - target_column="class", - sample_size=10000 -) - -# Configure AI agent -agent_config = AgentConfig( - model_name="gpt-5", - max_tool_calls=50, - sql_guard_enabled=True -) - -# Run pipeline -results = eda_pipeline( - source_config=source_config, - agent_config=agent_config, - min_quality_score=70.0 -) - -print(f"Quality Score: {results['quality_decision'].quality_score}") -print(f"Quality Gate: {'PASSED' if results['quality_decision'].passed else 'FAILED'}") -``` - -## Pipeline Configuration - -### Data Sources - -**HuggingFace Datasets:** -```python -source_config = DataSourceConfig( - source_type="hf", - source_path="scikit-learn/adult-census-income", - sampling_strategy="random", - sample_size=50000 -) -``` - -**Local Files:** -```python +# EDA Analysis source_config = DataSourceConfig( source_type="local", - source_path="/path/to/data.csv", - target_column="target" + source_path="iris_dataset.csv", + target_column="species" ) -``` -**Data Warehouses:** -```python -source_config = DataSourceConfig( - source_type="warehouse", - source_path="SELECT * FROM customer_data LIMIT 100000", - warehouse_config={ - "type": "bigquery", - "project_id": "my-project" - } -) -``` +results = eda_pipeline(source_config=source_config) -### AI Agent Configuration +# Prompt Experimentation +from pipelines.prompt_experiment_pipeline import prompt_experiment_pipeline -```python -agent_config = AgentConfig( - model_name="gpt-5", # or "claude-4" - max_tool_calls=100, - sql_guard_enabled=True, - preview_limit=20, - timeout_seconds=600 +prompts = ["Analyze this data", "Provide detailed insights"] +experiment = prompt_experiment_pipeline( + source_config=source_config, + prompt_variants=prompts ) ``` -### Quality Gate Thresholds +## Output -```python -quality_decision = evaluate_quality_gate( - report_json=report, - min_quality_score=75.0, - block_on_high_severity=True, - max_missing_data_pct=25.0, - require_target_column=True -) -``` - -## Analysis Outputs - -### EDA Report Structure -```json -{ - "headline": "Dataset contains 32,561 rows with moderate data quality issues", - "key_findings": [ - "Found 6 numeric columns suitable for quantitative analysis", - "Missing data is 7.3% overall, within acceptable range", - "Strong correlation detected between age and hours-per-week (0.89)" - ], - "risks": ["Potential class imbalance in target variable"], - "fixes": [ - { - "title": "Address missing values in workclass column", - "severity": "medium", - "code_snippet": "df['workclass'].fillna(df['workclass'].mode()[0])", - "estimated_impact": 0.15 - } - ], - "data_quality_score": 78.5, - "correlation_insights": [...], - "missing_data_analysis": {...} -} -``` - -### Quality Gate Decision -```json -{ - "passed": true, - "quality_score": 78.5, - "decision_reason": "All quality checks passed", - "blocking_issues": [], - "recommendations": [ - "Data quality is acceptable for downstream processing", - "Consider implementing monitoring for quality regression" - ] -} -``` - -## Data Security - -### Quality Configuration -```python -# Configure quality thresholds -results = eda_pipeline( - source_config=source_config, - min_quality_score=80.0, - max_missing_data_pct=15.0 -) -``` - -### SQL Safety Guards -- Only `SELECT` and `WITH` statements allowed -- Prohibited operations: `DROP`, `DELETE`, `INSERT`, `UPDATE` -- Auto-injection of `LIMIT` clauses for large result sets -- Query logging for full auditability - -## Production Deployment - -### Remote Orchestration -```python -# Configure ZenML stack for cloud deployment -zenml stack register remote_stack \ - --orchestrator=kubernetes \ - --artifact_store=s3 \ - --container_registry=ecr - -# Run with remote stack -zenml stack set remote_stack -python run_pipeline.py --source-path "s3://my-bucket/data.parquet" -``` - -### Monitoring & Alerts -- Pipeline execution tracking via ZenML dashboard -- Quality gate failure notifications -- Data drift detection capabilities -- Token usage and cost monitoring - -## Examples Gallery - -### Customer Segmentation Analysis -```bash -python run_pipeline.py \ - --source-type hf \ - --source-path "scikit-learn/adult-census-income" \ - --target-column "class" \ - --min-quality-score 80 -``` +The pipeline generates: +- **EDA Report**: Statistical analysis, correlations, missing data patterns, quality score (0-100) +- **Quality Gate**: Pass/fail decision with recommendations +- **Remediation**: Actionable code snippets for data issues -### Financial Risk Assessment -```bash -python run_pipeline.py \ - --source-type local \ - --source-path "financial_data.csv" \ - --min-quality-score 90 \ - --require-target-column \ - --target-column "risk_score" -``` - -### Time Series Data Quality Check -```bash -python run_quality_check.py \ - --source-path "time_series.parquet" \ - --max-missing-data-pct 10 \ - --require-target-column \ - --target-column "value" -``` - -## Advanced Features +## Security & Production -### Custom Data Warehouses -Support for BigQuery, Snowflake, Redshift, and generic SQL connections. - -### Multi-Model Analysis -Switch between OpenAI GPT-4, Anthropic Claude, and other providers. - -### Pipeline Caching -Automatic caching of expensive operations for faster iterations. - -### Artifact Lineage -Full traceability of data transformations and analysis steps. +- **SQL Safety**: Only SELECT/WITH queries allowed, auto-LIMIT injection +- **Remote Orchestration**: Kubernetes, S3, ECR support via ZenML stacks +- **Monitoring**: Pipeline tracking, quality alerts, cost monitoring ## Troubleshooting -### Common Issues - -**Missing API Keys:** ```bash +# Common fixes +pip install duckdb>=1.0.0 pydantic-ai>=0.0.13 export OPENAI_API_KEY="your-key" -# or -export ANTHROPIC_API_KEY="your-key" -``` - -**DuckDB Import Errors:** -```bash -pip install duckdb>=1.0.0 -``` - -**Pydantic AI Installation:** -```bash -pip install pydantic-ai>=0.0.13 -``` - -**Large Dataset Memory Issues:** -- Reduce `sample_size` in DataSourceConfig -- Use `enable_masking=True` to reduce memory footprint -- Consider using `quality_only_pipeline` for quick checks - -### Performance Optimization - -- Use `gpt-4o-mini` instead of `gpt-5` for faster analysis -- Limit `max_tool_calls` for time-constrained scenarios -- Enable snapshot caching for repeated analysis -- Use stratified sampling for large datasets - -## Contributing - -This example demonstrates the integration patterns between ZenML and Pydantic AI. Contributions for additional data sources, quality checks, and analysis capabilities are welcome. - -## License -This example is part of the ZenML project and follows the Apache 2.0 license. \ No newline at end of file +# Performance: use gpt-4o-mini, reduce sample_size for large datasets +``` \ No newline at end of file diff --git a/examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py b/examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py index dc1ba0c0b76..73ec28558c2 100644 --- a/examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py +++ b/examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py @@ -18,38 +18,44 @@ def prompt_experiment_pipeline( agent_config: Optional[AgentConfig] = None, ) -> Dict[str, Any]: """Pipeline for A/B testing Pydantic AI agent prompts. - + This pipeline helps developers optimize their agent prompts by testing multiple variants on the same dataset and comparing performance metrics. - + Args: - source_config: Data source to test prompts against + source_config: Data source to test prompts against prompt_variants: List of system prompts to compare agent_config: Configuration for agent behavior during testing - + Returns: Comprehensive comparison results with recommendations """ - logger.info(f"๐Ÿงช Starting prompt experiment with {len(prompt_variants)} variants") - + logger.info( + f"๐Ÿงช Starting prompt experiment with {len(prompt_variants)} variants" + ) + # Step 1: Load the test dataset dataset_df, ingestion_metadata = ingest_data(source_config=source_config) - + # Step 2: Run prompt comparison experiment experiment_results = compare_agent_prompts( dataset_df=dataset_df, prompt_variants=prompt_variants, agent_config=agent_config, ) - - logger.info("โœ… Prompt experiment completed - check results for best performing variant") - + + logger.info( + "โœ… Prompt experiment completed - check results for best performing variant" + ) + return { "experiment_results": experiment_results, "dataset_metadata": ingestion_metadata, "test_config": { "source": f"{source_config.source_type}:{source_config.source_path}", "prompt_count": len(prompt_variants), - "agent_config": agent_config.model_dump() if agent_config else None, + "agent_config": agent_config.model_dump() + if agent_config + else None, }, - } \ No newline at end of file + } diff --git a/examples/pydantic_ai_eda/run_prompt_experiment.py b/examples/pydantic_ai_eda/run_prompt_experiment.py index 2e40eaf2d15..c724f4e2182 100644 --- a/examples/pydantic_ai_eda/run_prompt_experiment.py +++ b/examples/pydantic_ai_eda/run_prompt_experiment.py @@ -2,6 +2,7 @@ """Run prompt experimentation pipeline to optimize Pydantic AI agents.""" import os + from models import AgentConfig, DataSourceConfig from pipelines.prompt_experiment_pipeline import prompt_experiment_pipeline @@ -10,16 +11,16 @@ def main(): """Run prompt A/B testing to find the best system prompt.""" print("๐Ÿงช Pydantic AI Prompt Experimentation") print("=" * 40) - + # Check for API keys has_openai = bool(os.getenv("OPENAI_API_KEY")) has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY")) - + if not (has_openai or has_anthropic): print("โŒ No API keys found!") print("Set OPENAI_API_KEY or ANTHROPIC_API_KEY environment variable") return - + model_name = "gpt-4o-mini" if has_openai else "claude-3-haiku-20240307" print(f"๐Ÿค– Using model: {model_name}") @@ -40,16 +41,18 @@ def main(): prompt_variants = [ # Variant 1: Concise """You are a data analyst. Analyze the dataset quickly - focus on data quality score and key findings. Be concise.""", - - # Variant 2: Quality-focused + # Variant 2: Quality-focused """You are a data quality specialist. Calculate data quality score, identify missing data and duplicates. Provide specific recommendations.""", - # Variant 3: Business-oriented - """You are a business analyst. Is this data ready for ML? Provide go/no-go recommendation with quality score and business impact.""" + """You are a business analyst. Is this data ready for ML? Provide go/no-go recommendation with quality score and business impact.""", ] - print(f"๐Ÿ“Š Testing {len(prompt_variants)} prompt variants on: {source_config.source_path}") - print("This will help identify the best performing prompt for your use case.\n") + print( + f"๐Ÿ“Š Testing {len(prompt_variants)} prompt variants on: {source_config.source_path}" + ) + print( + "This will help identify the best performing prompt for your use case.\n" + ) try: pipeline_run = prompt_experiment_pipeline( @@ -57,55 +60,61 @@ def main(): prompt_variants=prompt_variants, agent_config=agent_config, ) - + # Extract results from ZenML pipeline artifacts print("๐Ÿ“ˆ EXPERIMENT RESULTS") print("=" * 25) print("โœ… Pipeline completed successfully!") - + # Get the artifact from the pipeline run run_metadata = pipeline_run.dict() print(f"๐Ÿ” Pipeline run ID: {pipeline_run.id}") print(f"๐Ÿ“Š Check ZenML dashboard for detailed experiment results") print(f"๐Ÿ† Results are stored as pipeline artifacts") - + # Try to access the step outputs try: step_names = list(pipeline_run.steps.keys()) print(f"๐Ÿ“‹ Pipeline steps: {step_names}") - + if "compare_agent_prompts" in step_names: step_output = pipeline_run.steps["compare_agent_prompts"] print(f"๐ŸŽฏ Experiment data available in step outputs") - + # Try to load the actual results outputs = step_output.outputs if "prompt_comparison_results" in outputs: - experiment_data = outputs["prompt_comparison_results"].load() + experiment_data = outputs[ + "prompt_comparison_results" + ].load() summary = experiment_data["experiment_summary"] - - print(f"โœ… Successful runs: {summary['successful_runs']}/{summary['total_prompts_tested']}") + + print( + f"โœ… Successful runs: {summary['successful_runs']}/{summary['total_prompts_tested']}" + ) print(f"๐Ÿ† Best prompt: {summary['best_prompt_variant']}") print(f"โฑ๏ธ Average time: {summary['avg_execution_time']}s") - + print("\n๐Ÿ’ก RECOMMENDATIONS:") for rec in experiment_data["recommendations"]: print(f" โ€ข {rec}") - + except Exception as e: print(f"โš ๏ธ Could not extract detailed results: {e}") print("Check ZenML dashboard for full experiment analysis") - - print(f"\nโœ… Prompt experiment completed! Check ZenML dashboard for detailed results.") + + print( + f"\nโœ… Prompt experiment completed! Check ZenML dashboard for detailed results." + ) return pipeline_run - + except Exception as e: print(f"โŒ Experiment failed: {e}") print("\nTroubleshooting:") print("- Check your API key is valid") - print("- Ensure ZenML is initialized: zenml init") + print("- Ensure ZenML is initialized: zenml init") print("- Install requirements: pip install -r requirements.txt") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/pydantic_ai_eda/steps/prompt_experiment.py b/examples/pydantic_ai_eda/steps/prompt_experiment.py index a4aeee5c267..d66ae9b9e86 100644 --- a/examples/pydantic_ai_eda/steps/prompt_experiment.py +++ b/examples/pydantic_ai_eda/steps/prompt_experiment.py @@ -6,19 +6,18 @@ from typing import Annotated, Any, Dict, List, Optional import pandas as pd +from models import AgentConfig, EDAReport from pydantic_ai import Agent from pydantic_ai.settings import ModelSettings +from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps from zenml import step from zenml.logger import get_logger -from models import AgentConfig, EDAReport -from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps - logger = get_logger(__name__) -@step +@step def evaluate_prompts_with_test_cases( dataset_df: pd.DataFrame, prompt_variants: List[str], @@ -27,62 +26,71 @@ def evaluate_prompts_with_test_cases( use_llm_judge: bool = True, ) -> Annotated[Dict[str, Any], "comprehensive_prompt_evaluation"]: """Advanced prompt evaluation using structured test cases and LLM judge. - + This step implements best practices from AI evaluation methodology: - Structured test cases with categories - LLM judge evaluation for quality assessment - Tool usage tracking and analysis - Statistical significance testing - + Args: dataset_df: Dataset to test prompts against prompt_variants: List of system prompts to compare test_cases_path: Path to JSON file with structured test cases agent_config: Base configuration for all agents use_llm_judge: Whether to use LLM for response quality evaluation - + Returns: Comprehensive evaluation results with quality scores and recommendations """ if agent_config is None: agent_config = AgentConfig() - + # Load structured test cases test_cases = _load_test_cases(test_cases_path) if not test_cases: # Fallback to simple evaluation if no test cases return compare_agent_prompts(dataset_df, prompt_variants, agent_config) - - logger.info(f"๐Ÿงช Running comprehensive evaluation: {len(prompt_variants)} prompts ร— {len(test_cases)} test cases") - + + logger.info( + f"๐Ÿงช Running comprehensive evaluation: {len(prompt_variants)} prompts ร— {len(test_cases)} test cases" + ) + # Initialize LLM judge if requested llm_judge = None if use_llm_judge: llm_judge = _create_llm_judge(agent_config.model_name) - + all_results = [] - + # Test each prompt variant for prompt_idx, system_prompt in enumerate(prompt_variants): prompt_id = f"variant_{prompt_idx + 1}" logger.info(f"Testing {prompt_id}/{len(prompt_variants)}") - + prompt_results = [] - + # Run each test case for this prompt for test_case in test_cases: case_result = _run_single_test_case( - dataset_df, system_prompt, test_case, agent_config, llm_judge, prompt_id + dataset_df, + system_prompt, + test_case, + agent_config, + llm_judge, + prompt_id, ) prompt_results.append(case_result) - + # Aggregate results for this prompt - prompt_summary = _analyze_prompt_performance(prompt_results, prompt_id, system_prompt) + prompt_summary = _analyze_prompt_performance( + prompt_results, prompt_id, system_prompt + ) all_results.append(prompt_summary) - + # Generate comprehensive comparison final_analysis = _generate_comprehensive_analysis(all_results, test_cases) - + return final_analysis @@ -93,38 +101,42 @@ def compare_agent_prompts( agent_config: AgentConfig = None, ) -> Annotated[Dict[str, Any], "prompt_comparison_results"]: """Test multiple system prompts on the same dataset for agent optimization. - + This step helps developers iterate on agent prompts by running A/B tests and comparing quality, performance, and output characteristics. - + Args: dataset_df: Dataset to test prompts against prompt_variants: List of system prompts to compare agent_config: Base configuration for all agents - + Returns: Comparison report with metrics for each prompt variant """ if agent_config is None: agent_config = AgentConfig() - - logger.info(f"๐Ÿงช Testing {len(prompt_variants)} prompt variants on dataset with {len(dataset_df)} rows") - + + logger.info( + f"๐Ÿงช Testing {len(prompt_variants)} prompt variants on dataset with {len(dataset_df)} rows" + ) + results = [] - + for i, system_prompt in enumerate(prompt_variants): - prompt_id = f"variant_{i+1}" - logger.info(f"๐Ÿ”„ Testing prompt variant {i+1}/{len(prompt_variants)}") - print(f"๐Ÿ”„ Testing prompt variant {i+1}/{len(prompt_variants)}") - + prompt_id = f"variant_{i + 1}" + logger.info( + f"๐Ÿ”„ Testing prompt variant {i + 1}/{len(prompt_variants)}" + ) + print(f"๐Ÿ”„ Testing prompt variant {i + 1}/{len(prompt_variants)}") + start_time = time.time() - + try: # Initialize fresh dependencies for each test - deps = AnalystAgentDeps() + deps = AnalystAgentDeps() main_ref = deps.store(dataset_df) - - print(f" ๐Ÿค– Creating agent for variant {i+1}...") + + print(f" ๐Ÿค– Creating agent for variant {i + 1}...") # Create agent with this prompt variant test_agent = Agent( f"openai:{agent_config.model_name}", @@ -132,25 +144,27 @@ def compare_agent_prompts( output_type=EDAReport, output_retries=3, system_prompt=system_prompt, - model_settings=ModelSettings(parallel_tool_calls=False), # Disable parallel to avoid hanging + model_settings=ModelSettings( + parallel_tool_calls=False + ), # Disable parallel to avoid hanging ) - + # Register tools print(f" ๐Ÿ”ง Registering {len(AGENT_TOOLS)} tools...") for tool in AGENT_TOOLS: test_agent.tool(tool) - + # Run analysis with consistent user prompt user_prompt = f"""Analyze dataset '{main_ref}' ({dataset_df.shape[0]} rows, {dataset_df.shape[1]} cols). Focus on data quality, key patterns, and actionable insights.""" - - print(f" โšก Running analysis for variant {i+1}...") + + print(f" โšก Running analysis for variant {i + 1}...") result = test_agent.run_sync(user_prompt, deps=deps) eda_report = result.output - + execution_time = time.time() - start_time - + # Collect metrics for comparison result_metrics = { "prompt_id": prompt_id, @@ -165,16 +179,19 @@ def compare_agent_prompts( "correlation_insights": len(eda_report.correlation_insights), "headline_length": len(eda_report.headline), "markdown_length": len(eda_report.markdown), - "tables_generated": len(deps.output) - 1, # Exclude main dataset + "tables_generated": len(deps.output) + - 1, # Exclude main dataset "error": None, } - - logger.info(f"โœ… Variant {i+1}: Score={eda_report.data_quality_score:.1f}, Tools={len(deps.query_history)}, Time={execution_time:.1f}s") - + + logger.info( + f"โœ… Variant {i + 1}: Score={eda_report.data_quality_score:.1f}, Tools={len(deps.query_history)}, Time={execution_time:.1f}s" + ) + except Exception as e: execution_time = time.time() - start_time - logger.warning(f"โŒ Variant {i+1} failed: {str(e)}") - + logger.warning(f"โŒ Variant {i + 1} failed: {str(e)}") + result_metrics = { "prompt_id": prompt_id, "system_prompt": system_prompt, @@ -191,36 +208,44 @@ def compare_agent_prompts( "tables_generated": 0, "error": str(e), } - + results.append(result_metrics) - + # Analyze results and determine best variant successful_results = [r for r in results if r["success"]] - + if successful_results: # Rank by composite score (quality + speed + thoroughness) for result in successful_results: # Composite score: 60% quality + 20% speed + 20% thoroughness - speed_score = max(0, 100 - result["execution_time_seconds"] * 10) # Penalty for slowness - thoroughness_score = (result["key_findings_count"] * 10 + - result["risks_identified"] * 5 + - result["tables_generated"] * 5) - + speed_score = max( + 0, 100 - result["execution_time_seconds"] * 10 + ) # Penalty for slowness + thoroughness_score = ( + result["key_findings_count"] * 10 + + result["risks_identified"] * 5 + + result["tables_generated"] * 5 + ) + result["composite_score"] = ( - result["data_quality_score"] * 0.6 + - min(speed_score, 100) * 0.2 + - min(thoroughness_score, 100) * 0.2 + result["data_quality_score"] * 0.6 + + min(speed_score, 100) * 0.2 + + min(thoroughness_score, 100) * 0.2 ) - + # Find best performer - best_result = max(successful_results, key=lambda x: x["composite_score"]) + best_result = max( + successful_results, key=lambda x: x["composite_score"] + ) best_prompt_id = best_result["prompt_id"] - - logger.info(f"๐Ÿ† Best performing prompt: {best_prompt_id} (score: {best_result['composite_score']:.1f})") + + logger.info( + f"๐Ÿ† Best performing prompt: {best_prompt_id} (score: {best_result['composite_score']:.1f})" + ) else: best_prompt_id = "none" logger.warning("โŒ No prompts succeeded") - + # Generate summary comparison summary = { "experiment_summary": { @@ -228,60 +253,88 @@ def compare_agent_prompts( "successful_runs": len(successful_results), "failed_runs": len(results) - len(successful_results), "best_prompt_variant": best_prompt_id, - "avg_execution_time": round(sum(r["execution_time_seconds"] for r in results) / len(results), 2) if results else 0, + "avg_execution_time": round( + sum(r["execution_time_seconds"] for r in results) + / len(results), + 2, + ) + if results + else 0, "dataset_info": { "rows": len(dataset_df), "columns": len(dataset_df.columns), "column_names": list(dataset_df.columns), - } + }, }, "detailed_results": results, "recommendations": _generate_prompt_recommendations(results), } - + return summary -def _generate_prompt_recommendations(results: List[Dict[str, Any]]) -> List[str]: +def _generate_prompt_recommendations( + results: List[Dict[str, Any]], +) -> List[str]: """Generate recommendations based on prompt experiment results.""" recommendations = [] - + successful_results = [r for r in results if r["success"]] - + if not successful_results: - return ["All prompts failed - check agent configuration and error messages"] - + return [ + "All prompts failed - check agent configuration and error messages" + ] + # Performance analysis - avg_time = sum(r["execution_time_seconds"] for r in successful_results) / len(successful_results) - avg_quality = sum(r["data_quality_score"] for r in successful_results) / len(successful_results) - + avg_time = sum( + r["execution_time_seconds"] for r in successful_results + ) / len(successful_results) + avg_quality = sum( + r["data_quality_score"] for r in successful_results + ) / len(successful_results) + if avg_time > 30: - recommendations.append("Consider shorter, more focused prompts to reduce execution time") - + recommendations.append( + "Consider shorter, more focused prompts to reduce execution time" + ) + if avg_quality < 70: - recommendations.append("Consider more specific instructions for data quality assessment") - + recommendations.append( + "Consider more specific instructions for data quality assessment" + ) + # Consistency analysis quality_scores = [r["data_quality_score"] for r in successful_results] quality_variance = max(quality_scores) - min(quality_scores) - + if quality_variance > 20: - recommendations.append("High variance in quality scores - consider more consistent prompt structure") + recommendations.append( + "High variance in quality scores - consider more consistent prompt structure" + ) else: - recommendations.append("Quality scores are consistent across prompts - good stability") - + recommendations.append( + "Quality scores are consistent across prompts - good stability" + ) + # Tool usage analysis tool_calls = [r["tool_calls_made"] for r in successful_results] if max(tool_calls) - min(tool_calls) > 3: - recommendations.append("Variable tool usage detected - consider standardizing analysis steps") - + recommendations.append( + "Variable tool usage detected - consider standardizing analysis steps" + ) + # Success rate analysis success_rate = len(successful_results) / len(results) * 100 if success_rate < 80: - recommendations.append("Low success rate - review prompt complexity and error handling") + recommendations.append( + "Low success rate - review prompt complexity and error handling" + ) else: - recommendations.append(f"Good success rate ({success_rate:.0f}%) - prompts are robust") - + recommendations.append( + f"Good success rate ({success_rate:.0f}%) - prompts are robust" + ) + return recommendations @@ -292,13 +345,15 @@ def _load_test_cases(test_cases_path: str) -> List[Dict[str, Any]]: if not test_cases_file.exists(): logger.warning(f"Test cases file not found: {test_cases_path}") return [] - - with open(test_cases_file, 'r') as f: + + with open(test_cases_file, "r") as f: test_cases = json.load(f) - - logger.info(f"Loaded {len(test_cases)} test cases from {test_cases_path}") + + logger.info( + f"Loaded {len(test_cases)} test cases from {test_cases_path}" + ) return test_cases - + except Exception as e: logger.error(f"Failed to load test cases: {e}") return [] @@ -308,8 +363,10 @@ def _create_llm_judge(base_model: str) -> Optional[Agent]: """Create an LLM judge agent for evaluating responses.""" try: # Use a stronger model for evaluation if available - judge_model = "gpt-4o" if "gpt" in base_model else "claude-3-5-sonnet-20241022" - + judge_model = ( + "gpt-4o" if "gpt" in base_model else "claude-3-5-sonnet-20241022" + ) + judge_system_prompt = """You are an expert AI evaluator specializing in data analysis responses. Your job is to assess EDA (Exploratory Data Analysis) responses based on: @@ -326,10 +383,12 @@ def _create_llm_judge(base_model: str) -> Optional[Agent]: Be objective and consistent in your evaluations.""" return Agent( - f"openai:{judge_model}" if "gpt" in base_model else f"anthropic:{judge_model}", + f"openai:{judge_model}" + if "gpt" in base_model + else f"anthropic:{judge_model}", system_prompt=judge_system_prompt, ) - + except Exception as e: logger.warning(f"Failed to create LLM judge: {e}") return None @@ -345,12 +404,12 @@ def _run_single_test_case( ) -> Dict[str, Any]: """Run a single test case and collect comprehensive metrics.""" start_time = time.time() - + try: # Initialize agent for this test deps = AnalystAgentDeps() main_ref = deps.store(dataset_df) - + test_agent = Agent( f"openai:{agent_config.model_name}", deps_type=AnalystAgentDeps, @@ -359,18 +418,18 @@ def _run_single_test_case( system_prompt=system_prompt, model_settings=ModelSettings(parallel_tool_calls=True), ) - + # Register tools for tool in AGENT_TOOLS: test_agent.tool(tool) - + # Run the test case query full_query = f"Dataset reference: {main_ref}\n\n{test_case['query']}" result = test_agent.run_sync(full_query, deps=deps) eda_report = result.output - + execution_time = time.time() - start_time - + # Collect basic metrics case_result = { "test_id": test_case["id"], @@ -387,18 +446,22 @@ def _run_single_test_case( "recommendations_count": len(eda_report.fixes), "error": None, } - + # Evaluate with LLM judge if available if llm_judge: - judge_evaluation = _get_llm_judge_scores(llm_judge, test_case, case_result) + judge_evaluation = _get_llm_judge_scores( + llm_judge, test_case, case_result + ) case_result.update(judge_evaluation) - + # Check against expected metrics if available expected_metrics = test_case.get("expected_metrics", {}) - case_result["meets_expectations"] = _check_expectations(case_result, expected_metrics) - + case_result["meets_expectations"] = _check_expectations( + case_result, expected_metrics + ) + return case_result - + except Exception as e: execution_time = time.time() - start_time return { @@ -413,24 +476,26 @@ def _run_single_test_case( } -def _get_llm_judge_scores(llm_judge: Agent, test_case: Dict, case_result: Dict) -> Dict: +def _get_llm_judge_scores( + llm_judge: Agent, test_case: Dict, case_result: Dict +) -> Dict: """Get quality scores from LLM judge.""" try: eval_prompt = f""" -Query: {test_case['query']} -Category: {test_case.get('category', 'general')} +Query: {test_case["query"]} +Category: {test_case.get("category", "general")} Response to evaluate: -{case_result['response'][:2000]}... +{case_result["response"][:2000]}... -Data Quality Score Provided: {case_result['data_quality_score']} -Tool Calls Made: {case_result['tool_calls_made']} -Findings Count: {case_result['findings_count']} +Data Quality Score Provided: {case_result["data_quality_score"]} +Tool Calls Made: {case_result["tool_calls_made"]} +Findings Count: {case_result["findings_count"]} Please evaluate this EDA response and provide scores in the requested JSON format.""" judge_response = llm_judge.run_sync(eval_prompt) - + # Parse JSON response try: scores = json.loads(str(judge_response.output)) @@ -445,43 +510,51 @@ def _get_llm_judge_scores(llm_judge: Agent, test_case: Dict, case_result: Dict) } except json.JSONDecodeError: logger.warning("LLM judge response was not valid JSON") - return {"judge_overall": 3, "judge_reasoning": "Failed to parse judge response"} - + return { + "judge_overall": 3, + "judge_reasoning": "Failed to parse judge response", + } + except Exception as e: logger.warning(f"LLM judge evaluation failed: {e}") - return {"judge_overall": 3, "judge_reasoning": f"Evaluation error: {e}"} + return { + "judge_overall": 3, + "judge_reasoning": f"Evaluation error: {e}", + } def _check_expectations(case_result: Dict, expected_metrics: Dict) -> bool: """Check if results meet expected criteria.""" if not expected_metrics: return True - + checks = [] - + # Check minimum quality score min_quality = expected_metrics.get("quality_score_min") if min_quality: checks.append(case_result.get("data_quality_score", 0) >= min_quality) - + # Check minimum recommendations min_recs = expected_metrics.get("recommendations_min") if min_recs: checks.append(case_result.get("recommendations_count", 0) >= min_recs) - + # Check expected tool calls expected_tools = expected_metrics.get("tool_calls_expected", []) if expected_tools: # This would need actual tool tracking - simplified for now checks.append(case_result.get("tool_calls_made", 0) > 0) - + return all(checks) if checks else True -def _analyze_prompt_performance(results: List[Dict], prompt_id: str, system_prompt: str) -> Dict: +def _analyze_prompt_performance( + results: List[Dict], prompt_id: str, system_prompt: str +) -> Dict: """Analyze performance across all test cases for a single prompt.""" successful_results = [r for r in results if r["success"]] - + if not successful_results: return { "prompt_id": prompt_id, @@ -489,46 +562,72 @@ def _analyze_prompt_performance(results: List[Dict], prompt_id: str, system_prom "success_rate": 0, "avg_scores": {}, "category_performance": {}, - "overall_rating": "failed" + "overall_rating": "failed", } - + # Calculate averages avg_scores = { - "execution_time": sum(r["execution_time"] for r in successful_results) / len(successful_results), - "tool_calls": sum(r["tool_calls_made"] for r in successful_results) / len(successful_results), - "data_quality_score": sum(r["data_quality_score"] for r in successful_results) / len(successful_results), - "findings_count": sum(r["findings_count"] for r in successful_results) / len(successful_results), + "execution_time": sum(r["execution_time"] for r in successful_results) + / len(successful_results), + "tool_calls": sum(r["tool_calls_made"] for r in successful_results) + / len(successful_results), + "data_quality_score": sum( + r["data_quality_score"] for r in successful_results + ) + / len(successful_results), + "findings_count": sum(r["findings_count"] for r in successful_results) + / len(successful_results), } - + # Add LLM judge scores if available judge_scores = [r for r in successful_results if "judge_overall" in r] if judge_scores: - avg_scores.update({ - "judge_accuracy": sum(r["judge_accuracy"] for r in judge_scores) / len(judge_scores), - "judge_relevance": sum(r["judge_relevance"] for r in judge_scores) / len(judge_scores), - "judge_completeness": sum(r["judge_completeness"] for r in judge_scores) / len(judge_scores), - "judge_overall": sum(r["judge_overall"] for r in judge_scores) / len(judge_scores), - }) - + avg_scores.update( + { + "judge_accuracy": sum( + r["judge_accuracy"] for r in judge_scores + ) + / len(judge_scores), + "judge_relevance": sum( + r["judge_relevance"] for r in judge_scores + ) + / len(judge_scores), + "judge_completeness": sum( + r["judge_completeness"] for r in judge_scores + ) + / len(judge_scores), + "judge_overall": sum(r["judge_overall"] for r in judge_scores) + / len(judge_scores), + } + ) + # Category-wise performance categories = {} for result in successful_results: cat = result["category"] if cat not in categories: - categories[cat] = {"count": 0, "avg_score": 0, "meets_expectations": 0} + categories[cat] = { + "count": 0, + "avg_score": 0, + "meets_expectations": 0, + } categories[cat]["count"] += 1 categories[cat]["avg_score"] += result["data_quality_score"] if result.get("meets_expectations", False): categories[cat]["meets_expectations"] += 1 - + for cat in categories: categories[cat]["avg_score"] /= categories[cat]["count"] - categories[cat]["success_rate"] = categories[cat]["meets_expectations"] / categories[cat]["count"] - + categories[cat]["success_rate"] = ( + categories[cat]["meets_expectations"] / categories[cat]["count"] + ) + # Overall rating success_rate = len(successful_results) / len(results) - avg_judge_score = avg_scores.get("judge_overall", avg_scores.get("data_quality_score", 50) / 20) - + avg_judge_score = avg_scores.get( + "judge_overall", avg_scores.get("data_quality_score", 50) / 20 + ) + if success_rate >= 0.8 and avg_judge_score >= 4: rating = "excellent" elif success_rate >= 0.6 and avg_judge_score >= 3: @@ -537,7 +636,7 @@ def _analyze_prompt_performance(results: List[Dict], prompt_id: str, system_prom rating = "acceptable" else: rating = "poor" - + return { "prompt_id": prompt_id, "system_prompt": system_prompt, @@ -549,87 +648,119 @@ def _analyze_prompt_performance(results: List[Dict], prompt_id: str, system_prom } -def _generate_comprehensive_analysis(all_results: List[Dict], test_cases: List[Dict]) -> Dict: +def _generate_comprehensive_analysis( + all_results: List[Dict], test_cases: List[Dict] +) -> Dict: """Generate final comprehensive analysis of all prompt variants.""" # Rank prompts by overall performance - ranked_prompts = sorted(all_results, key=lambda x: ( - x["success_rate"] * 0.4 + - x["avg_scores"].get("judge_overall", 3) * 0.3 + - (100 - x["avg_scores"]["execution_time"]) / 100 * 0.2 + - x["avg_scores"]["data_quality_score"] / 100 * 0.1 - ), reverse=True) - + ranked_prompts = sorted( + all_results, + key=lambda x: ( + x["success_rate"] * 0.4 + + x["avg_scores"].get("judge_overall", 3) * 0.3 + + (100 - x["avg_scores"]["execution_time"]) / 100 * 0.2 + + x["avg_scores"]["data_quality_score"] / 100 * 0.1 + ), + reverse=True, + ) + best_prompt = ranked_prompts[0] if ranked_prompts else None - + # Category analysis category_insights = {} for test_case in test_cases: - cat = test_case.get("category", "general") + cat = test_case.get("category", "general") if cat not in category_insights: category_insights[cat] = {"prompt_performance": []} - + for prompt_result in all_results: cat_perf = prompt_result["category_performance"].get(cat, {}) if cat_perf: - category_insights[cat]["prompt_performance"].append({ - "prompt_id": prompt_result["prompt_id"], - "avg_score": cat_perf["avg_score"], - "success_rate": cat_perf["success_rate"] - }) - + category_insights[cat]["prompt_performance"].append( + { + "prompt_id": prompt_result["prompt_id"], + "avg_score": cat_perf["avg_score"], + "success_rate": cat_perf["success_rate"], + } + ) + return { "evaluation_summary": { "total_prompts_tested": len(all_results), "total_test_cases": len(test_cases), "best_prompt": best_prompt["prompt_id"] if best_prompt else "none", - "best_prompt_rating": best_prompt["overall_rating"] if best_prompt else "none", + "best_prompt_rating": best_prompt["overall_rating"] + if best_prompt + else "none", "categories_tested": list(category_insights.keys()), }, "prompt_rankings": ranked_prompts, "category_analysis": category_insights, "detailed_results": all_results, - "recommendations": _generate_advanced_recommendations(all_results, category_insights), + "recommendations": _generate_advanced_recommendations( + all_results, category_insights + ), } -def _generate_advanced_recommendations(all_results: List[Dict], category_insights: Dict) -> List[str]: +def _generate_advanced_recommendations( + all_results: List[Dict], category_insights: Dict +) -> List[str]: """Generate advanced recommendations based on comprehensive analysis.""" recommendations = [] - + if not all_results: return ["No successful evaluations - check agent configuration"] - + # Success rate analysis - avg_success_rate = sum(r["success_rate"] for r in all_results) / len(all_results) + avg_success_rate = sum(r["success_rate"] for r in all_results) / len( + all_results + ) if avg_success_rate < 0.6: - recommendations.append("Low overall success rate - consider simplifying prompts or checking tool integration") - + recommendations.append( + "Low overall success rate - consider simplifying prompts or checking tool integration" + ) + # Performance consistency execution_times = [r["avg_scores"]["execution_time"] for r in all_results] time_variance = max(execution_times) - min(execution_times) if time_variance > 30: - recommendations.append("High variance in execution time - optimize slower prompts for efficiency") - + recommendations.append( + "High variance in execution time - optimize slower prompts for efficiency" + ) + # Quality assessment if any("judge_overall" in r["avg_scores"] for r in all_results): - judge_scores = [r["avg_scores"]["judge_overall"] for r in all_results if "judge_overall" in r["avg_scores"]] + judge_scores = [ + r["avg_scores"]["judge_overall"] + for r in all_results + if "judge_overall" in r["avg_scores"] + ] avg_judge_score = sum(judge_scores) / len(judge_scores) - + if avg_judge_score < 3: - recommendations.append("LLM judge scores are low - review prompt clarity and specificity") + recommendations.append( + "LLM judge scores are low - review prompt clarity and specificity" + ) elif avg_judge_score > 4: - recommendations.append("Excellent LLM judge scores - prompts are producing high-quality responses") - + recommendations.append( + "Excellent LLM judge scores - prompts are producing high-quality responses" + ) + # Category-specific insights for category, data in category_insights.items(): if data["prompt_performance"]: cat_scores = [p["avg_score"] for p in data["prompt_performance"]] if min(cat_scores) < 60: - recommendations.append(f"'{category}' category shows low scores - consider specialized prompts for this use case") - + recommendations.append( + f"'{category}' category shows low scores - consider specialized prompts for this use case" + ) + # Best practices best_prompt = max(all_results, key=lambda x: x["success_rate"]) if best_prompt["success_rate"] > 0.8: - recommendations.append(f"'{best_prompt['prompt_id']}' shows strong performance - consider it as your baseline") - - return recommendations \ No newline at end of file + recommendations.append( + f"'{best_prompt['prompt_id']}' shows strong performance - consider it as your baseline" + ) + + return recommendations From d669d9291e518822b5adccb08f7ff97b94221e99 Mon Sep 17 00:00:00 2001 From: Hamza Tahir Date: Fri, 22 Aug 2025 11:10:06 +0200 Subject: [PATCH 06/14] Update pydantic version to be more flexible --- examples/pydantic_ai_eda/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pydantic_ai_eda/requirements.txt b/examples/pydantic_ai_eda/requirements.txt index 93b1bee7e3b..2e210ed5bfb 100644 --- a/examples/pydantic_ai_eda/requirements.txt +++ b/examples/pydantic_ai_eda/requirements.txt @@ -7,7 +7,7 @@ numpy>=1.24.0,<2.0.0 duckdb>=1.0.0,<2.0.0 # Pydantic compatibility - CRITICAL: Must be <2.10 for ZenML compatibility -pydantic>=2.0.0,<2.10.0 +pydantic # AI/ML frameworks pydantic-ai[logfire]>=0.4.0 From 42864f9c8d13c535d4c674ff75bc36e6e359c5d6 Mon Sep 17 00:00:00 2001 From: Hamza Tahir Date: Mon, 25 Aug 2025 23:56:26 +0200 Subject: [PATCH 07/14] Formattingg --- examples/prompt_optimization/README.md | 110 +++ .../__init__.py | 0 examples/prompt_optimization/models.py | 31 + .../prompt_optimization/pipelines/__init__.py | 15 + .../pipelines/production_eda_pipeline.py | 58 ++ .../pipelines/prompt_optimization_pipeline.py | 54 ++ .../requirements.txt | 0 examples/prompt_optimization/run.py | 110 +++ .../prompt_optimization/steps/__init__.py | 18 + .../steps/agent_tools.py | 0 .../steps/eda_agent.py | 27 +- examples/prompt_optimization/steps/ingest.py | 93 +++ .../steps/prompt_optimization.py | 170 ++++ examples/pydantic_ai_eda/README.md | 80 -- examples/pydantic_ai_eda/models.py | 137 ---- .../pydantic_ai_eda/pipelines/__init__.py | 15 - .../pydantic_ai_eda/pipelines/eda_pipeline.py | 98 --- .../pipelines/prompt_experiment_pipeline.py | 61 -- examples/pydantic_ai_eda/run.py | 59 -- .../pydantic_ai_eda/run_prompt_experiment.py | 120 --- examples/pydantic_ai_eda/steps/__init__.py | 24 - examples/pydantic_ai_eda/steps/ingest.py | 374 --------- .../steps/prompt_experiment.py | 766 ------------------ .../pydantic_ai_eda/steps/quality_gate.py | 278 ------- examples/pydantic_ai_eda/test_cases.json | 47 -- 25 files changed, 683 insertions(+), 2062 deletions(-) create mode 100644 examples/prompt_optimization/README.md rename examples/{pydantic_ai_eda => prompt_optimization}/__init__.py (100%) create mode 100644 examples/prompt_optimization/models.py create mode 100644 examples/prompt_optimization/pipelines/__init__.py create mode 100644 examples/prompt_optimization/pipelines/production_eda_pipeline.py create mode 100644 examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py rename examples/{pydantic_ai_eda => prompt_optimization}/requirements.txt (100%) create mode 100644 examples/prompt_optimization/run.py create mode 100644 examples/prompt_optimization/steps/__init__.py rename examples/{pydantic_ai_eda => prompt_optimization}/steps/agent_tools.py (100%) rename examples/{pydantic_ai_eda => prompt_optimization}/steps/eda_agent.py (81%) create mode 100644 examples/prompt_optimization/steps/ingest.py create mode 100644 examples/prompt_optimization/steps/prompt_optimization.py delete mode 100644 examples/pydantic_ai_eda/README.md delete mode 100644 examples/pydantic_ai_eda/models.py delete mode 100644 examples/pydantic_ai_eda/pipelines/__init__.py delete mode 100644 examples/pydantic_ai_eda/pipelines/eda_pipeline.py delete mode 100644 examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py delete mode 100644 examples/pydantic_ai_eda/run.py delete mode 100644 examples/pydantic_ai_eda/run_prompt_experiment.py delete mode 100644 examples/pydantic_ai_eda/steps/__init__.py delete mode 100644 examples/pydantic_ai_eda/steps/ingest.py delete mode 100644 examples/pydantic_ai_eda/steps/prompt_experiment.py delete mode 100644 examples/pydantic_ai_eda/steps/quality_gate.py delete mode 100644 examples/pydantic_ai_eda/test_cases.json diff --git a/examples/prompt_optimization/README.md b/examples/prompt_optimization/README.md new file mode 100644 index 00000000000..cfa8579ae68 --- /dev/null +++ b/examples/prompt_optimization/README.md @@ -0,0 +1,110 @@ +# Prompt Optimization with ZenML + +This example demonstrates **ZenML's artifact management** through a two-stage AI prompt optimization workflow using **Pydantic AI** for exploratory data analysis. + +## What This Example Shows + +**Stage 1: Prompt Optimization** +- Tests multiple prompt variants against sample data +- Compares performance using data quality scores and execution time +- Tags the best-performing prompt with an **exclusive ZenML tag** +- Stores the optimized prompt in ZenML's artifact registry + +**Stage 2: Production Analysis** +- Automatically retrieves the tagged optimal prompt from the registry +- Runs production EDA analysis using the best prompt +- Demonstrates real artifact sharing between pipeline runs + +This showcases how ZenML enables **reproducible ML workflows** where optimization results automatically flow into production systems. + +## Quick Start + +### Prerequisites +```bash +# Install ZenML and initialize +pip install "zenml[server]" +zenml init + +# Install example dependencies +pip install -r requirements.txt + +# Set your API key (OpenAI or Anthropic) +export OPENAI_API_KEY="your-openai-key" +# OR +export ANTHROPIC_API_KEY="your-anthropic-key" +``` + +### Run the Example +```bash +# Complete two-stage workflow (default behavior) +python run.py + +# Run individual stages +python run.py --optimization-pipeline # Stage 1: Find best prompt +python run.py --production-pipeline # Stage 2: Use optimized prompt +``` + +## Data Sources + +The example supports multiple data sources: + +```bash +# HuggingFace datasets (default) +python run.py --data-source "hf:scikit-learn/iris" +python run.py --data-source "hf:scikit-learn/wine" + +# Local CSV files +python run.py --data-source "local:./my_data.csv" + +# Specify target column +python run.py --data-source "local:sales.csv" --target-column "revenue" +``` + +## Key ZenML Concepts Demonstrated + +### Artifact Management +- **Exclusive Tagging**: Only one prompt can have the "optimized" tag at a time +- **Artifact Registry**: Centralized storage for ML artifacts with versioning +- **Cross-Pipeline Sharing**: Production pipeline automatically finds optimization results + +### Pipeline Orchestration +- **Multi-Stage Workflows**: Optimization โ†’ Production with artifact passing +- **Conditional Execution**: Production pipeline adapts based on available artifacts +- **Lineage Tracking**: Full traceability from prompt testing to production use + +### AI Integration +- **Model Flexibility**: Works with OpenAI GPT or Anthropic Claude models +- **Performance Testing**: Systematic comparison of prompt variants +- **Production Deployment**: Seamless transition from experimentation to production + +## Configuration Options + +```bash +# Model selection +python run.py --model-name "gpt-4o-mini" +python run.py --model-name "claude-3-haiku-20240307" + +# Performance tuning +python run.py --max-tool-calls 8 --timeout-seconds 120 + +# Development options +python run.py --no-cache # Disable caching for fresh runs +``` + +## Expected Output + +When you run the complete workflow, you'll see: + +1. **Optimization Stage**: Testing of 3 prompt variants with performance metrics +2. **Tagging**: Best prompt automatically tagged in ZenML registry +3. **Production Stage**: Retrieval and use of optimized prompt +4. **Results**: EDA analysis with data quality scores and recommendations + +The ZenML dashboard will show the complete lineage from optimization to production use. + +## Next Steps + +- **View Results**: Check the ZenML dashboard for pipeline runs and artifacts +- **Customize Prompts**: Modify the prompt variants in `run.py` for your domain +- **Scale Up**: Deploy with remote orchestrators for production workloads +- **Integrate**: Connect to your existing data sources and ML pipelines \ No newline at end of file diff --git a/examples/pydantic_ai_eda/__init__.py b/examples/prompt_optimization/__init__.py similarity index 100% rename from examples/pydantic_ai_eda/__init__.py rename to examples/prompt_optimization/__init__.py diff --git a/examples/prompt_optimization/models.py b/examples/prompt_optimization/models.py new file mode 100644 index 00000000000..5b0c1ea7b33 --- /dev/null +++ b/examples/prompt_optimization/models.py @@ -0,0 +1,31 @@ +"""Simple data models for prompt optimization example.""" + +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class DataSourceConfig(BaseModel): + """Data source configuration.""" + + source_type: str = Field(description="Data source type: hf or local") + source_path: str = Field(description="Path or identifier for the data source") + target_column: Optional[str] = Field(None, description="Optional target column name") + sample_size: Optional[int] = Field(None, description="Number of rows to sample") + + +class AgentConfig(BaseModel): + """AI agent configuration.""" + + model_name: str = Field("gpt-4o-mini", description="Language model to use") + max_tool_calls: int = Field(6, description="Maximum tool calls allowed") + timeout_seconds: int = Field(60, description="Max execution time in seconds") + + +class EDAReport(BaseModel): + """Simple EDA analysis report from AI agent.""" + + headline: str = Field(description="Executive summary of key findings") + key_findings: List[str] = Field(default_factory=list, description="Important discoveries") + data_quality_score: float = Field(description="Overall quality score (0-100)") + markdown: str = Field(description="Full markdown report") \ No newline at end of file diff --git a/examples/prompt_optimization/pipelines/__init__.py b/examples/prompt_optimization/pipelines/__init__.py new file mode 100644 index 00000000000..e278ea485db --- /dev/null +++ b/examples/prompt_optimization/pipelines/__init__.py @@ -0,0 +1,15 @@ +"""ZenML pipelines for prompt optimization example. + +This module contains pipeline definitions: + +- prompt_optimization_pipeline.py: Test prompts and tag the best one +- production_eda_pipeline.py: Use optimized prompt for EDA analysis +""" + +from .production_eda_pipeline import production_eda_pipeline +from .prompt_optimization_pipeline import prompt_optimization_pipeline + +__all__ = [ + "prompt_optimization_pipeline", + "production_eda_pipeline", +] \ No newline at end of file diff --git a/examples/prompt_optimization/pipelines/production_eda_pipeline.py b/examples/prompt_optimization/pipelines/production_eda_pipeline.py new file mode 100644 index 00000000000..a8f02ae1084 --- /dev/null +++ b/examples/prompt_optimization/pipelines/production_eda_pipeline.py @@ -0,0 +1,58 @@ +"""Simple production EDA pipeline using optimized prompts.""" + +from typing import Any, Dict, Optional + +from models import AgentConfig, DataSourceConfig +from steps import get_optimized_prompt, ingest_data, run_eda_agent + +from zenml import pipeline +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@pipeline +def production_eda_pipeline( + source_config: DataSourceConfig, + agent_config: Optional[AgentConfig] = None, +) -> Dict[str, Any]: + """Production EDA pipeline using optimized prompts from the registry. + + This pipeline demonstrates ZenML's artifact retrieval by fetching + previously optimized prompts for production analysis. + + Args: + source_config: Data source configuration + agent_config: AI agent configuration + use_optimized_prompt: Whether to use optimized prompt from registry + + Returns: + EDA results and metadata + """ + logger.info("๐Ÿญ Starting production EDA pipeline") + + # Step 1: Get optimized prompt + optimized_prompt = get_optimized_prompt() + logger.info("๐ŸŽฏ Retrieved optimized prompt") + + # Step 2: Load data + dataset_df, metadata = ingest_data(source_config=source_config) + + # Step 3: Run EDA analysis with optimized prompt + report_markdown, report_json, sql_log, analysis_tables = run_eda_agent( + dataset_df=dataset_df, + dataset_metadata=metadata, + agent_config=agent_config, + custom_system_prompt=optimized_prompt, + ) + + logger.info("โœ… Production EDA pipeline completed") + + return { + "report_markdown": report_markdown, + "report_json": report_json, + "sql_log": sql_log, + "analysis_tables": analysis_tables, + "used_optimized_prompt": optimized_prompt is not None, + "metadata": metadata, + } diff --git a/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py b/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py new file mode 100644 index 00000000000..d756a6f875d --- /dev/null +++ b/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py @@ -0,0 +1,54 @@ +"""Simple prompt optimization pipeline for ZenML artifact management demo.""" + +from typing import Any, Dict, List, Optional + +from models import AgentConfig, DataSourceConfig +from steps import compare_prompts_and_tag_best, ingest_data + +from zenml import pipeline +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@pipeline +def prompt_optimization_pipeline( + source_config: DataSourceConfig, + prompt_variants: List[str], + agent_config: Optional[AgentConfig] = None, +) -> Dict[str, Any]: + """Optimize prompts and tag the best one for production use. + + This pipeline demonstrates ZenML's artifact management by testing + multiple prompt variants and tagging the best performer. + + Args: + source_config: Data source configuration + prompt_variants: List of prompt strings to test + agent_config: AI agent configuration + + Returns: + Pipeline results with best prompt and metadata + """ + logger.info("๐Ÿงช Starting prompt optimization pipeline") + + # Step 1: Load data + dataset_df, metadata = ingest_data(source_config=source_config) + + # Step 2: Test prompts and tag the best one + best_prompt = compare_prompts_and_tag_best( + dataset_df=dataset_df, + prompt_variants=prompt_variants, + agent_config=agent_config, + ) + + logger.info("โœ… Prompt optimization completed - best prompt tagged with 'optimized'") + + return { + "best_prompt": best_prompt, + "metadata": metadata, + "config": { + "source": f"{source_config.source_type}:{source_config.source_path}", + "variants_tested": len(prompt_variants), + }, + } \ No newline at end of file diff --git a/examples/pydantic_ai_eda/requirements.txt b/examples/prompt_optimization/requirements.txt similarity index 100% rename from examples/pydantic_ai_eda/requirements.txt rename to examples/prompt_optimization/requirements.txt diff --git a/examples/prompt_optimization/run.py b/examples/prompt_optimization/run.py new file mode 100644 index 00000000000..7ff348c79c0 --- /dev/null +++ b/examples/prompt_optimization/run.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +"""Two-Stage Prompt Optimization Example with ZenML.""" + +import os +from typing import Optional + +import click + +from models import AgentConfig, DataSourceConfig +from pipelines.production_eda_pipeline import production_eda_pipeline +from pipelines.prompt_optimization_pipeline import prompt_optimization_pipeline + + +@click.command(help="Run prompt optimization and/or production EDA pipelines (both by default)") +@click.option("--optimization-pipeline", is_flag=True, help="Run prompt optimization") +@click.option("--production-pipeline", is_flag=True, help="Run production EDA") +@click.option("--data-source", default="hf:scikit-learn/iris", help="Data source (type:path)") +@click.option("--target-column", default="target", help="Target column name") +@click.option("--model-name", help="Model name (auto-detected if not specified)") +@click.option("--max-tool-calls", default=6, type=int, help="Max tool calls") +@click.option("--timeout-seconds", default=60, type=int, help="Timeout seconds") +@click.option("--no-cache", is_flag=True, help="Disable caching") +def main( + optimization_pipeline: bool = False, + production_pipeline: bool = False, + data_source: str = "hf:scikit-learn/iris", + target_column: str = "target", + model_name: Optional[str] = None, + max_tool_calls: int = 6, + timeout_seconds: int = 60, + no_cache: bool = False, +): + """Run prompt optimization and/or production EDA pipelines.""" + # Default: run both pipelines if no flags specified + if not optimization_pipeline and not production_pipeline: + optimization_pipeline = True + production_pipeline = True + + # Check API keys + has_openai = bool(os.getenv("OPENAI_API_KEY")) + has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY")) + + if not (has_openai or has_anthropic): + click.echo("โŒ Set OPENAI_API_KEY or ANTHROPIC_API_KEY") + return + + # Auto-detect model + if model_name is None: + model_name = "gpt-4o-mini" if has_openai else "claude-3-haiku-20240307" + + # Parse data source + try: + source_type, source_path = data_source.split(":", 1) + except ValueError: + click.echo(f"โŒ Invalid data source: {data_source}") + return + + # Create configs + source_config = DataSourceConfig( + source_type=source_type, + source_path=source_path, + target_column=target_column, + ) + + agent_config = AgentConfig( + model_name=model_name, + max_tool_calls=max_tool_calls, + timeout_seconds=timeout_seconds, + ) + + pipeline_options = {"enable_cache": not no_cache} + + # Stage 1: Prompt optimization + if optimization_pipeline: + click.echo("๐Ÿงช Running prompt optimization...") + + prompt_variants = [ + "You are a data analyst. Quickly assess data quality (0-100 score) and identify key patterns. Be concise.", + "You are a data quality specialist. Calculate quality score (0-100), find missing data, duplicates, and correlations. Provide specific recommendations.", + "You are a business analyst. Assess ML readiness with quality score. Focus on business impact and actionable insights." + ] + + try: + prompt_optimization_pipeline.with_options(**pipeline_options)( + source_config=source_config, + prompt_variants=prompt_variants, + agent_config=agent_config, + ) + click.echo("โœ… Optimization completed - best prompt tagged") + + except Exception as e: + click.echo(f"โŒ Optimization failed: {e}") + + # Stage 2: Production analysis + if production_pipeline: + click.echo("๐Ÿญ Running production analysis...") + + try: + production_eda_pipeline.with_options(**pipeline_options)( + source_config=source_config, + agent_config=agent_config, + ) + click.echo("โœ… Production analysis completed") + + except Exception as e: + click.echo(f"โŒ Production analysis failed: {e}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/prompt_optimization/steps/__init__.py b/examples/prompt_optimization/steps/__init__.py new file mode 100644 index 00000000000..3f28fca444d --- /dev/null +++ b/examples/prompt_optimization/steps/__init__.py @@ -0,0 +1,18 @@ +"""ZenML steps for prompt optimization example. + +Simple, focused steps that demonstrate ZenML's core capabilities: +- Data ingestion +- AI-powered analysis +- Prompt optimization and tagging +""" + +from .eda_agent import run_eda_agent +from .ingest import ingest_data +from .prompt_optimization import compare_prompts_and_tag_best, get_optimized_prompt + +__all__ = [ + "ingest_data", + "run_eda_agent", + "compare_prompts_and_tag_best", + "get_optimized_prompt", +] \ No newline at end of file diff --git a/examples/pydantic_ai_eda/steps/agent_tools.py b/examples/prompt_optimization/steps/agent_tools.py similarity index 100% rename from examples/pydantic_ai_eda/steps/agent_tools.py rename to examples/prompt_optimization/steps/agent_tools.py diff --git a/examples/pydantic_ai_eda/steps/eda_agent.py b/examples/prompt_optimization/steps/eda_agent.py similarity index 81% rename from examples/pydantic_ai_eda/steps/eda_agent.py rename to examples/prompt_optimization/steps/eda_agent.py index fa21e24754c..5e95060568f 100644 --- a/examples/pydantic_ai_eda/steps/eda_agent.py +++ b/examples/prompt_optimization/steps/eda_agent.py @@ -9,6 +9,11 @@ from zenml import step from zenml.types import MarkdownString +from zenml.logger import get_logger + +logger = get_logger(__name__) + + # Logfire for observability try: import logfire @@ -26,13 +31,24 @@ def run_eda_agent( dataset_df: pd.DataFrame, dataset_metadata: Dict[str, Any], agent_config: AgentConfig = None, + custom_system_prompt: str = None, ) -> Tuple[ Annotated[MarkdownString, "eda_report_markdown"], Annotated[Dict[str, Any], "eda_report_json"], Annotated[List[Dict[str, str]], "sql_execution_log"], Annotated[Dict[str, pd.DataFrame], "analysis_tables"], ]: - """Run simple Pydantic AI agent for EDA analysis.""" + """Run Pydantic AI agent for EDA analysis with optional custom prompt. + + Args: + dataset_df: Dataset to analyze + dataset_metadata: Metadata about the dataset + agent_config: Configuration for the AI agent + custom_system_prompt: Optional custom system prompt (overrides default) + + Returns: + Tuple of EDA outputs: markdown report, JSON report, SQL log, analysis tables + """ if agent_config is None: agent_config = AgentConfig() @@ -49,8 +65,12 @@ def run_eda_agent( deps = AnalystAgentDeps() main_ref = deps.store(dataset_df) - # Create the EDA analyst agent with focused system prompt - system_prompt = """You are a data analyst. Perform quick but insightful EDA. + # Create the EDA analyst agent with system prompt (custom or default) + if custom_system_prompt: + system_prompt = custom_system_prompt + logger.info("๐ŸŽฏ Using custom optimized system prompt for analysis") + else: + system_prompt = """You are a data analyst. Perform quick but insightful EDA. FOCUS ON: - Data quality score (0-100) based on missing data and duplicates @@ -59,6 +79,7 @@ def run_eda_agent( - 2-3 actionable recommendations Be concise but specific with numbers. Aim for quality insights, not exhaustive analysis.""" + logger.info("๐Ÿ“ Using default system prompt for analysis") analyst_agent = Agent( f"openai:{agent_config.model_name}", diff --git a/examples/prompt_optimization/steps/ingest.py b/examples/prompt_optimization/steps/ingest.py new file mode 100644 index 00000000000..22cad0677cc --- /dev/null +++ b/examples/prompt_optimization/steps/ingest.py @@ -0,0 +1,93 @@ +"""Simple data ingestion step for prompt optimization example.""" + +from typing import Annotated, Any, Dict, Tuple + +import pandas as pd +from models import DataSourceConfig + +from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@step +def ingest_data( + source_config: DataSourceConfig, +) -> Tuple[ + Annotated[pd.DataFrame, "dataset"], + Annotated[Dict[str, Any], "ingestion_metadata"], +]: + """Simple data ingestion from HuggingFace or local files. + + Args: + source_config: Configuration specifying data source + + Returns: + Tuple of (dataframe, metadata) + """ + logger.info(f"Loading data from {source_config.source_type}:{source_config.source_path}") + + # Load data based on source type + if source_config.source_type == "hf": + df = _load_from_huggingface(source_config) + elif source_config.source_type == "local": + df = _load_from_local(source_config) + else: + raise ValueError(f"Unsupported source type: {source_config.source_type}") + + # Apply sampling if configured + if source_config.sample_size and len(df) > source_config.sample_size: + df = df.sample(n=source_config.sample_size, random_state=42) + logger.info(f"Sampled {source_config.sample_size} rows from {len(df)} total") + + # Generate simple metadata + metadata = { + "source_type": source_config.source_type, + "source_path": source_config.source_path, + "rows": len(df), + "columns": len(df.columns), + "column_names": df.columns.tolist(), + "target_column": source_config.target_column, + } + + logger.info(f"Loaded dataset: {len(df)} rows ร— {len(df.columns)} columns") + return df, metadata + + +def _load_from_huggingface(config: DataSourceConfig) -> pd.DataFrame: + """Load dataset from HuggingFace Hub.""" + try: + from datasets import load_dataset + + # Simple dataset loading + dataset = load_dataset(config.source_path, split="train") + df = dataset.to_pandas() + + logger.info(f"Loaded HuggingFace dataset: {config.source_path}") + return df + + except ImportError: + raise ImportError("datasets library required. Install with: pip install datasets") + except Exception as e: + raise RuntimeError(f"Failed to load dataset {config.source_path}: {e}") + + +def _load_from_local(config: DataSourceConfig) -> pd.DataFrame: + """Load dataset from local file.""" + try: + file_path = config.source_path + + if file_path.endswith(".csv"): + df = pd.read_csv(file_path) + elif file_path.endswith(".json"): + df = pd.read_json(file_path) + else: + # Try CSV as fallback + df = pd.read_csv(file_path) + + logger.info(f"Loaded local file: {file_path}") + return df + + except Exception as e: + raise RuntimeError(f"Failed to load file {config.source_path}: {e}") \ No newline at end of file diff --git a/examples/prompt_optimization/steps/prompt_optimization.py b/examples/prompt_optimization/steps/prompt_optimization.py new file mode 100644 index 00000000000..e946c677081 --- /dev/null +++ b/examples/prompt_optimization/steps/prompt_optimization.py @@ -0,0 +1,170 @@ +"""Simple prompt optimization step for demonstrating ZenML artifact management.""" + +import time +from typing import Annotated + +import pandas as pd +from models import AgentConfig, EDAReport +from pydantic_ai import Agent +from pydantic_ai.settings import ModelSettings +from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps + +from zenml import ArtifactConfig, ExternalArtifact, Tag, add_tags, step +from zenml.client import Client +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@step +def compare_prompts_and_tag_best( + dataset_df: pd.DataFrame, + prompt_variants: list[str], + agent_config: AgentConfig = None, +) -> Annotated[str, ArtifactConfig(name="best_prompt")]: + """Compare prompt variants and tag the best one with exclusive 'optimized' tag. + + This step demonstrates ZenML's artifact management by: + 1. Testing multiple prompt variants + 2. Finding the best performer + 3. Returning it as a tagged artifact that other pipelines can find + + The 'optimized' tag is exclusive, so only one prompt can be 'optimized' at a time. + + Args: + dataset_df: Dataset to test prompts against + prompt_variants: List of system prompts to compare + agent_config: Configuration for AI agents + + Returns: + The best performing prompt string with exclusive 'optimized' tag + """ + if agent_config is None: + agent_config = AgentConfig() + + logger.info(f"๐Ÿงช Testing {len(prompt_variants)} prompt variants") + + results = [] + + for i, system_prompt in enumerate(prompt_variants): + prompt_id = f"variant_{i + 1}" + logger.info(f"Testing {prompt_id}...") + + start_time = time.time() + + try: + # Create agent with this prompt + deps = AnalystAgentDeps() + main_ref = deps.store(dataset_df) + + agent = Agent( + f"openai:{agent_config.model_name}", + deps_type=AnalystAgentDeps, + output_type=EDAReport, + system_prompt=system_prompt, + model_settings=ModelSettings(parallel_tool_calls=False), + ) + + for tool in AGENT_TOOLS: + agent.tool(tool) + + # Run analysis + user_prompt = f"Analyze dataset '{main_ref}' - focus on data quality and key insights." + result = agent.run_sync(user_prompt, deps=deps) + eda_report = result.output + + execution_time = time.time() - start_time + + # Score this variant + score = ( + eda_report.data_quality_score * 0.7 + # Primary metric + (100 - min(execution_time * 2, 100)) * 0.2 + # Speed bonus + len(eda_report.key_findings) * 5 * 0.1 # Thoroughness + ) + + results.append({ + "prompt_id": prompt_id, + "prompt": system_prompt, + "score": score, + "quality_score": eda_report.data_quality_score, + "execution_time": execution_time, + "findings_count": len(eda_report.key_findings), + "success": True, + }) + + logger.info(f"โœ… {prompt_id}: score={score:.1f}, time={execution_time:.1f}s") + + except Exception as e: + logger.warning(f"โŒ {prompt_id} failed: {e}") + results.append({ + "prompt_id": prompt_id, + "prompt": system_prompt, + "score": 0, + "success": False, + "error": str(e), + }) + + # Find best performer + successful_results = [r for r in results if r["success"]] + + if not successful_results: + logger.warning("All prompts failed, using first as fallback") + best_prompt = prompt_variants[0] + else: + best_result = max(successful_results, key=lambda x: x["score"]) + best_prompt = best_result["prompt"] + logger.info(f"๐Ÿ† Best prompt: {best_result['prompt_id']} (score: {best_result['score']:.1f})") + + logger.info("๐Ÿ’พ Best prompt will be stored with exclusive 'optimized' tag") + + # Add exclusive tag to this step's output artifact + add_tags( + tags=[Tag(name="optimized", exclusive=True)], + infer_artifact=True + ) + + return best_prompt + + +def get_optimized_prompt(): + """Retrieve the optimized prompt using ZenML's tag-based filtering. + + This demonstrates ZenML's tag filtering by finding artifacts tagged with 'optimized'. + Since 'optimized' is an exclusive tag, there will be at most one such artifact. + + Returns: + The optimized prompt from the latest optimization run, or default if none found + """ + try: + client = Client() + + # Find artifacts tagged with 'optimized' (our exclusive tag) + artifacts = client.list_artifact_versions( + tags=["optimized"], + size=1 + ) + + if artifacts.items: + optimized_prompt = artifacts.items[0] + logger.info(f"๐ŸŽฏ Retrieved optimized prompt from artifact: {optimized_prompt.id}") + logger.info(f" Artifact created: {optimized_prompt.created}") + return optimized_prompt + else: + logger.info("๐Ÿ” No optimized prompt found (no artifacts with 'optimized' tag)") + + except Exception as e: + logger.warning(f"Failed to retrieve optimized prompt: {e}") + + # Fallback to default prompt if no optimization artifacts found + logger.info("๐Ÿ“ Using default prompt (run optimization pipeline first)") + default_prompt = """You are a data analyst. Perform comprehensive EDA analysis. + +FOCUS ON: +- Calculate data quality score (0-100) based on completeness and consistency +- Identify missing data patterns and duplicates +- Find key correlations and patterns +- Provide 3-5 actionable recommendations + +Be thorough but efficient.""" + + return ExternalArtifact(value=default_prompt) \ No newline at end of file diff --git a/examples/pydantic_ai_eda/README.md b/examples/pydantic_ai_eda/README.md deleted file mode 100644 index 59835b1f02a..00000000000 --- a/examples/pydantic_ai_eda/README.md +++ /dev/null @@ -1,80 +0,0 @@ -# Pydantic AI EDA Pipeline - -AI-powered Exploratory Data Analysis pipeline using **ZenML** and **Pydantic AI**. Automatically analyzes datasets, generates reports, and makes quality decisions for downstream processing. - -## Architecture - -``` -ingest_data โ†’ run_eda_agent โ†’ evaluate_quality_gate_with_routing -``` - -## Features - -- ๐Ÿค– **AI-Powered Analysis** with GPT-4/Claude -- ๐Ÿ“Š **SQL-Based EDA** through DuckDB with safety guards -- โœ… **Quality Gates** with configurable thresholds -- ๐ŸŒ **Multiple Data Sources** (HuggingFace, local files, warehouses) -- ๐Ÿ“ˆ **Comprehensive Reporting** (JSON/markdown) - -## Quick Start - -```bash -# Install -pip install "zenml[server]" && zenml init -cd zenml/examples/pydantic_ai_eda && pip install -r requirements.txt - -# Set API key -export OPENAI_API_KEY="your-openai-key" # or ANTHROPIC_API_KEY - -# Run examples -python run.py -python run_prompt_experiment.py -``` - -## Usage - -```python -from models import DataSourceConfig, AgentConfig -from pipelines.eda_pipeline import eda_pipeline - -# EDA Analysis -source_config = DataSourceConfig( - source_type="local", - source_path="iris_dataset.csv", - target_column="species" -) - -results = eda_pipeline(source_config=source_config) - -# Prompt Experimentation -from pipelines.prompt_experiment_pipeline import prompt_experiment_pipeline - -prompts = ["Analyze this data", "Provide detailed insights"] -experiment = prompt_experiment_pipeline( - source_config=source_config, - prompt_variants=prompts -) -``` - -## Output - -The pipeline generates: -- **EDA Report**: Statistical analysis, correlations, missing data patterns, quality score (0-100) -- **Quality Gate**: Pass/fail decision with recommendations -- **Remediation**: Actionable code snippets for data issues - -## Security & Production - -- **SQL Safety**: Only SELECT/WITH queries allowed, auto-LIMIT injection -- **Remote Orchestration**: Kubernetes, S3, ECR support via ZenML stacks -- **Monitoring**: Pipeline tracking, quality alerts, cost monitoring - -## Troubleshooting - -```bash -# Common fixes -pip install duckdb>=1.0.0 pydantic-ai>=0.0.13 -export OPENAI_API_KEY="your-key" - -# Performance: use gpt-4o-mini, reduce sample_size for large datasets -``` \ No newline at end of file diff --git a/examples/pydantic_ai_eda/models.py b/examples/pydantic_ai_eda/models.py deleted file mode 100644 index e17f0087d4f..00000000000 --- a/examples/pydantic_ai_eda/models.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Simple data models for Pydantic AI EDA pipeline.""" - -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel, Field - - -class DataSourceConfig(BaseModel): - """Simple data source configuration.""" - - source_type: str = Field( - description="Data source type: hf, local, or warehouse" - ) - source_path: str = Field( - description="Path or identifier for the data source" - ) - target_column: Optional[str] = Field( - None, description="Optional target column name" - ) - sample_size: Optional[int] = Field( - None, description="Number of rows to sample" - ) - - -class DataQualityFix(BaseModel): - """Model for data quality issues and suggested fixes. - - Represents a specific data quality problem identified during analysis - along with recommended remediation actions. - - Attributes: - title: Short description of the issue - rationale: Explanation of why this is a problem - severity: Impact level (low, medium, high, critical) - code_snippet: Optional code to address the issue - affected_columns: Columns affected by this issue - estimated_impact: Estimated impact on data quality (0-1) - """ - - title: str = Field( - description="Short description of the data quality issue" - ) - rationale: str = Field(description="Explanation of why this is a problem") - severity: str = Field( - description="Severity level: low, medium, high, critical" - ) - code_snippet: Optional[str] = Field( - None, description="Code to fix the issue" - ) - affected_columns: List[str] = Field( - description="Columns affected by this issue" - ) - estimated_impact: float = Field( - description="Estimated impact on data quality (0-1)" - ) - - -class EDAReport(BaseModel): - """Model for comprehensive EDA report results. - - Contains the structured output from Pydantic AI analysis including - findings, quality assessment, and recommendations. - - Attributes: - headline: Executive summary of key findings - key_findings: List of important discoveries about the data - risks: Potential data quality or analysis risks identified - fixes: Recommended fixes for data quality issues - data_quality_score: Overall data quality score (0-100) - markdown: Full markdown report for human consumption - column_profiles: Statistical profiles for each column - correlation_insights: Key correlation findings - missing_data_analysis: Analysis of missing data patterns - """ - - headline: str = Field(description="Executive summary of key findings") - key_findings: List[str] = Field( - default_factory=list, - description="Important discoveries about the data", - ) - risks: List[str] = Field( - default_factory=list, - description="Potential risks identified in the data", - ) - fixes: List[DataQualityFix] = Field( - default_factory=list, description="Recommended data quality fixes" - ) - data_quality_score: float = Field( - description="Overall quality score (0-100)" - ) - markdown: str = Field(description="Full markdown report") - column_profiles: Optional[Dict[str, Dict[str, Any]]] = Field( - default_factory=dict, description="Statistical profiles per column" - ) - correlation_insights: List[str] = Field( - default_factory=list, description="Key correlation findings" - ) - missing_data_analysis: Optional[Dict[str, Any]] = Field( - default_factory=dict, description="Missing data patterns" - ) - - -class QualityGateDecision(BaseModel): - """Model for quality gate decision results. - - Represents the outcome of evaluating whether data quality meets - requirements for downstream processing or model training. - - Attributes: - passed: Whether the quality gate check passed - quality_score: The computed data quality score - decision_reason: Explanation for the pass/fail decision - blocking_issues: Issues that caused failure (if failed) - recommendations: Suggested next steps - metadata: Additional decision metadata - """ - - passed: bool = Field(description="Whether quality gate passed") - quality_score: float = Field(description="Computed data quality score") - decision_reason: str = Field(description="Explanation for the decision") - blocking_issues: List[str] = Field( - description="Issues that caused failure" - ) - recommendations: List[str] = Field(description="Recommended next steps") - metadata: Dict[str, Any] = Field( - default_factory=dict, description="Additional metadata" - ) - - -class AgentConfig(BaseModel): - """Simple configuration for Pydantic AI agent.""" - - model_name: str = Field("gpt-4o-mini", description="Language model to use") - max_tool_calls: int = Field(6, description="Maximum tool calls allowed") - timeout_seconds: int = Field( - 60, description="Max execution time in seconds" - ) diff --git a/examples/pydantic_ai_eda/pipelines/__init__.py b/examples/pydantic_ai_eda/pipelines/__init__.py deleted file mode 100644 index 11952e656bc..00000000000 --- a/examples/pydantic_ai_eda/pipelines/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""ZenML pipelines for Pydantic AI EDA workflows. - -This module contains pipeline definitions: - -- eda_pipeline.py: Complete EDA pipeline with AI analysis and quality gates -- prompt_experiment_pipeline.py: A/B testing pipeline for agent prompt optimization -""" - -from .eda_pipeline import eda_pipeline -from .prompt_experiment_pipeline import prompt_experiment_pipeline - -__all__ = [ - "eda_pipeline", - "prompt_experiment_pipeline", -] \ No newline at end of file diff --git a/examples/pydantic_ai_eda/pipelines/eda_pipeline.py b/examples/pydantic_ai_eda/pipelines/eda_pipeline.py deleted file mode 100644 index bf18c99954a..00000000000 --- a/examples/pydantic_ai_eda/pipelines/eda_pipeline.py +++ /dev/null @@ -1,98 +0,0 @@ -"""EDA pipeline using Pydantic AI for automated data analysis. - -This pipeline orchestrates the complete EDA workflow: -1. Data ingestion from various sources -2. AI-powered EDA analysis with Pydantic AI -3. Quality gate evaluation for pipeline routing -""" - -from typing import Any, Dict, Optional - -from models import AgentConfig, DataSourceConfig -from steps import ( - evaluate_quality_gate_with_routing, - ingest_data, - run_eda_agent, -) - -from zenml import pipeline -from zenml.logger import get_logger - -logger = get_logger(__name__) - - -@pipeline -def eda_pipeline( - source_config: DataSourceConfig, - agent_config: Optional[AgentConfig] = None, - min_quality_score: float = 70.0, - block_on_high_severity: bool = True, - max_missing_data_pct: float = 30.0, - require_target_column: bool = False, -) -> Dict[str, Any]: - """Complete EDA pipeline with AI-powered analysis and quality gating. - - Performs end-to-end exploratory data analysis using Pydantic AI, - from data ingestion through quality assessment and routing decisions. - - Args: - source_config: Configuration for data source (HuggingFace/local/warehouse) - agent_config: Configuration for Pydantic AI agent behavior - min_quality_score: Minimum quality score for passing quality gate - block_on_high_severity: Whether high-severity issues block the pipeline - max_missing_data_pct: Maximum allowable missing data percentage - require_target_column: Whether to require a target column for analysis - - Returns: - Dictionary containing all pipeline outputs and routing decisions - """ - logger.info( - f"Starting EDA pipeline for {source_config.source_type}:{source_config.source_path}" - ) - - # Step 1: Ingest data from configured source - raw_df, ingestion_metadata = ingest_data(source_config=source_config) - - # Step 2: Run AI-powered EDA analysis - report_markdown, report_json, sql_log, analysis_tables = run_eda_agent( - dataset_df=raw_df, - dataset_metadata=ingestion_metadata, - agent_config=agent_config, - ) - - # Step 3: Evaluate data quality gate and get routing decision - quality_decision, routing_message = evaluate_quality_gate_with_routing( - report_json=report_json, - min_quality_score=min_quality_score, - block_on_high_severity=block_on_high_severity, - max_missing_data_pct=max_missing_data_pct, - require_target_column=require_target_column, - target_column=source_config.target_column, - ) - - # Log pipeline summary (note: artifacts are returned, actual values logged in steps) - logger.info("Pipeline steps completed successfully") - logger.info("Check step outputs for detailed analysis results") - - # Return comprehensive results - return { - # Core analysis outputs - "report_markdown": report_markdown, - "report_json": report_json, - "analysis_tables": analysis_tables, - "sql_log": sql_log, - # Graph visualization removed - # Quality assessment - "quality_decision": quality_decision, - "routing_message": routing_message, - # Pipeline metadata - "source_config": source_config, - "ingestion_metadata": ingestion_metadata, - "agent_config": agent_config, - # Summary metrics (basic info only, artifacts available separately) - "pipeline_summary": { - "data_source": f"{source_config.source_type}:{source_config.source_path}", - "target_column": source_config.target_column, - "timestamp": ingestion_metadata, # This will be the artifact - }, - } diff --git a/examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py b/examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py deleted file mode 100644 index 73ec28558c2..00000000000 --- a/examples/pydantic_ai_eda/pipelines/prompt_experiment_pipeline.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Pipeline for experimenting with Pydantic AI agent prompts.""" - -from typing import Any, Dict, List, Optional - -from models import AgentConfig, DataSourceConfig -from steps import compare_agent_prompts, ingest_data - -from zenml import pipeline -from zenml.logger import get_logger - -logger = get_logger(__name__) - - -@pipeline -def prompt_experiment_pipeline( - source_config: DataSourceConfig, - prompt_variants: List[str], - agent_config: Optional[AgentConfig] = None, -) -> Dict[str, Any]: - """Pipeline for A/B testing Pydantic AI agent prompts. - - This pipeline helps developers optimize their agent prompts by testing - multiple variants on the same dataset and comparing performance metrics. - - Args: - source_config: Data source to test prompts against - prompt_variants: List of system prompts to compare - agent_config: Configuration for agent behavior during testing - - Returns: - Comprehensive comparison results with recommendations - """ - logger.info( - f"๐Ÿงช Starting prompt experiment with {len(prompt_variants)} variants" - ) - - # Step 1: Load the test dataset - dataset_df, ingestion_metadata = ingest_data(source_config=source_config) - - # Step 2: Run prompt comparison experiment - experiment_results = compare_agent_prompts( - dataset_df=dataset_df, - prompt_variants=prompt_variants, - agent_config=agent_config, - ) - - logger.info( - "โœ… Prompt experiment completed - check results for best performing variant" - ) - - return { - "experiment_results": experiment_results, - "dataset_metadata": ingestion_metadata, - "test_config": { - "source": f"{source_config.source_type}:{source_config.source_path}", - "prompt_count": len(prompt_variants), - "agent_config": agent_config.model_dump() - if agent_config - else None, - }, - } diff --git a/examples/pydantic_ai_eda/run.py b/examples/pydantic_ai_eda/run.py deleted file mode 100644 index efec12111c3..00000000000 --- a/examples/pydantic_ai_eda/run.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 -"""Simple Pydantic AI EDA pipeline runner.""" - -import os - -from models import AgentConfig, DataSourceConfig -from pipelines.eda_pipeline import eda_pipeline - - -def main(): - """Run the EDA pipeline with simple configuration.""" - print("๐Ÿ” Pydantic AI EDA Pipeline") - print("=" * 30) - - # Check for API keys - has_openai = bool(os.getenv("OPENAI_API_KEY")) - has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY")) - - if not (has_openai or has_anthropic): - print("โŒ No API keys found!") - print("Set OPENAI_API_KEY or ANTHROPIC_API_KEY environment variable") - return - - model_name = "gpt-4o-mini" if has_openai else "claude-3-haiku-20240307" - print(f"๐Ÿค– Using model: {model_name}") - - # Simple configuration - source_config = DataSourceConfig( - source_type="hf", - source_path="scikit-learn/iris", - target_column="target", - ) - - agent_config = AgentConfig( - model_name=model_name, - max_tool_calls=6, # Keep it snappy - just the essentials - timeout_seconds=60, # Quick analysis - ) - - print(f"๐Ÿ“Š Analyzing: {source_config.source_path}") - - try: - results = eda_pipeline.with_options(enable_cache=False)( - source_config=source_config, - agent_config=agent_config, - min_quality_score=70.0, - ) - print("โœ… Pipeline completed! Check ZenML dashboard for results.") - return results - except Exception as e: - print(f"โŒ Pipeline failed: {e}") - print("\nTroubleshooting:") - print("- Check your API key is valid") - print("- Ensure ZenML is initialized: zenml init") - print("- Install requirements: pip install -r requirements.txt") - - -if __name__ == "__main__": - main() diff --git a/examples/pydantic_ai_eda/run_prompt_experiment.py b/examples/pydantic_ai_eda/run_prompt_experiment.py deleted file mode 100644 index c724f4e2182..00000000000 --- a/examples/pydantic_ai_eda/run_prompt_experiment.py +++ /dev/null @@ -1,120 +0,0 @@ -#!/usr/bin/env python3 -"""Run prompt experimentation pipeline to optimize Pydantic AI agents.""" - -import os - -from models import AgentConfig, DataSourceConfig -from pipelines.prompt_experiment_pipeline import prompt_experiment_pipeline - - -def main(): - """Run prompt A/B testing to find the best system prompt.""" - print("๐Ÿงช Pydantic AI Prompt Experimentation") - print("=" * 40) - - # Check for API keys - has_openai = bool(os.getenv("OPENAI_API_KEY")) - has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY")) - - if not (has_openai or has_anthropic): - print("โŒ No API keys found!") - print("Set OPENAI_API_KEY or ANTHROPIC_API_KEY environment variable") - return - - model_name = "gpt-4o-mini" if has_openai else "claude-3-haiku-20240307" - print(f"๐Ÿค– Using model: {model_name}") - - # Dataset configuration - source_config = DataSourceConfig( - source_type="hf", - source_path="scikit-learn/iris", - target_column="target", - ) - - agent_config = AgentConfig( - model_name=model_name, - max_tool_calls=4, # Reduce for faster testing - timeout_seconds=30, # Shorter timeout to avoid stalling - ) - - # Define prompt variants to test (simplified for speed) - prompt_variants = [ - # Variant 1: Concise - """You are a data analyst. Analyze the dataset quickly - focus on data quality score and key findings. Be concise.""", - # Variant 2: Quality-focused - """You are a data quality specialist. Calculate data quality score, identify missing data and duplicates. Provide specific recommendations.""", - # Variant 3: Business-oriented - """You are a business analyst. Is this data ready for ML? Provide go/no-go recommendation with quality score and business impact.""", - ] - - print( - f"๐Ÿ“Š Testing {len(prompt_variants)} prompt variants on: {source_config.source_path}" - ) - print( - "This will help identify the best performing prompt for your use case.\n" - ) - - try: - pipeline_run = prompt_experiment_pipeline( - source_config=source_config, - prompt_variants=prompt_variants, - agent_config=agent_config, - ) - - # Extract results from ZenML pipeline artifacts - print("๐Ÿ“ˆ EXPERIMENT RESULTS") - print("=" * 25) - print("โœ… Pipeline completed successfully!") - - # Get the artifact from the pipeline run - run_metadata = pipeline_run.dict() - print(f"๐Ÿ” Pipeline run ID: {pipeline_run.id}") - print(f"๐Ÿ“Š Check ZenML dashboard for detailed experiment results") - print(f"๐Ÿ† Results are stored as pipeline artifacts") - - # Try to access the step outputs - try: - step_names = list(pipeline_run.steps.keys()) - print(f"๐Ÿ“‹ Pipeline steps: {step_names}") - - if "compare_agent_prompts" in step_names: - step_output = pipeline_run.steps["compare_agent_prompts"] - print(f"๐ŸŽฏ Experiment data available in step outputs") - - # Try to load the actual results - outputs = step_output.outputs - if "prompt_comparison_results" in outputs: - experiment_data = outputs[ - "prompt_comparison_results" - ].load() - summary = experiment_data["experiment_summary"] - - print( - f"โœ… Successful runs: {summary['successful_runs']}/{summary['total_prompts_tested']}" - ) - print(f"๐Ÿ† Best prompt: {summary['best_prompt_variant']}") - print(f"โฑ๏ธ Average time: {summary['avg_execution_time']}s") - - print("\n๐Ÿ’ก RECOMMENDATIONS:") - for rec in experiment_data["recommendations"]: - print(f" โ€ข {rec}") - - except Exception as e: - print(f"โš ๏ธ Could not extract detailed results: {e}") - print("Check ZenML dashboard for full experiment analysis") - - print( - f"\nโœ… Prompt experiment completed! Check ZenML dashboard for detailed results." - ) - return pipeline_run - - except Exception as e: - print(f"โŒ Experiment failed: {e}") - print("\nTroubleshooting:") - print("- Check your API key is valid") - print("- Ensure ZenML is initialized: zenml init") - print("- Install requirements: pip install -r requirements.txt") - - -if __name__ == "__main__": - main() diff --git a/examples/pydantic_ai_eda/steps/__init__.py b/examples/pydantic_ai_eda/steps/__init__.py deleted file mode 100644 index 7be92fe9ac5..00000000000 --- a/examples/pydantic_ai_eda/steps/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -"""ZenML steps for Pydantic AI EDA pipeline. - -This module contains all the step functions used in the EDA pipeline: - -- ingest.py: Data ingestion from multiple sources (HF, local, warehouse) -- snapshot.py: Data snapshot creation with optional masking -- agent_tools.py: Pydantic AI agent tools and dependencies -- eda_agent.py: AI-powered EDA analysis step -- quality_gate.py: Data quality assessment and routing steps -""" - -from .eda_agent import run_eda_agent -from .ingest import ingest_data -from .prompt_experiment import compare_agent_prompts, evaluate_prompts_with_test_cases -from .quality_gate import evaluate_quality_gate, evaluate_quality_gate_with_routing - -__all__ = [ - "ingest_data", - "run_eda_agent", - "compare_agent_prompts", - "evaluate_prompts_with_test_cases", - "evaluate_quality_gate", - "evaluate_quality_gate_with_routing", -] \ No newline at end of file diff --git a/examples/pydantic_ai_eda/steps/ingest.py b/examples/pydantic_ai_eda/steps/ingest.py deleted file mode 100644 index 77352b5d6db..00000000000 --- a/examples/pydantic_ai_eda/steps/ingest.py +++ /dev/null @@ -1,374 +0,0 @@ -"""Data ingestion step for EDA pipeline. - -Supports loading data from multiple sources including HuggingFace datasets, -local files, and data warehouse connections. -""" - -import hashlib -import logging -from typing import Annotated, Any, Dict, Tuple - -import pandas as pd -from models import DataSourceConfig - -from zenml import step - -logger = logging.getLogger(__name__) - - -@step -def ingest_data( - source_config: DataSourceConfig, -) -> Tuple[ - Annotated[pd.DataFrame, "dataset"], - Annotated[Dict[str, Any], "ingestion_metadata"], -]: - """Ingest data from configured source. - - Loads data from HuggingFace, local files, or warehouse based on - source configuration. Returns both the DataFrame and metadata. - - Args: - source_config: Configuration specifying data source and parameters - - Returns: - Tuple of (raw_df, metadata) where metadata contains schema info, - row count, and content hash for traceability - """ - logger.info( - f"Ingesting data from {source_config.source_type}: {source_config.source_path}" - ) - - # Load data based on source type - if source_config.source_type == "hf": - df = _load_from_huggingface(source_config) - elif source_config.source_type == "local": - df = _load_from_local(source_config) - elif source_config.source_type == "warehouse": - df = _load_from_warehouse(source_config) - else: - raise ValueError( - f"Unsupported source type: {source_config.source_type}" - ) - - # Apply sampling if configured - if source_config.sample_size and len(df) > source_config.sample_size: - if source_config.sampling_strategy == "random": - df = df.sample(n=source_config.sample_size, random_state=42) - elif source_config.sampling_strategy == "first_n": - df = df.head(source_config.sample_size) - elif ( - source_config.sampling_strategy == "stratified" - and source_config.target_column - ): - df = _stratified_sample( - df, source_config.target_column, source_config.sample_size - ) - else: - logger.warning( - f"Unknown sampling strategy: {source_config.sampling_strategy}" - ) - df = df.sample(n=source_config.sample_size, random_state=42) - - # Note: For basic datasets like iris, pandas DataFrames should work fine with ZenML - # The dtype warnings come from ZenML's internal serialization of DataFrame metadata, - # not the actual data. This is normal for pandas DataFrames in ZenML. - - # Generate metadata - metadata = { - "source_type": source_config.source_type, - "source_path": source_config.source_path, - "original_rows": len(df), - "columns": len(df.columns), - "column_names": df.columns.tolist(), - "dtypes": df.dtypes.to_dict(), - "target_column": source_config.target_column, - "content_hash": _compute_content_hash(df), - "memory_usage_mb": df.memory_usage(deep=True).sum() / 1024 / 1024, - } - - # Add target column statistics if specified - if ( - source_config.target_column - and source_config.target_column in df.columns - ): - metadata["target_value_counts"] = ( - df[source_config.target_column].value_counts().to_dict() - ) - metadata["target_null_count"] = ( - df[source_config.target_column].isnull().sum() - ) - - logger.info( - f"Loaded dataset with {len(df)} rows and {len(df.columns)} columns" - ) - - return df, metadata - - -def _load_from_huggingface(config: DataSourceConfig) -> pd.DataFrame: - """Load dataset from HuggingFace Hub.""" - try: - from datasets import load_dataset - - # Parse dataset path (may include config/split) - parts = config.source_path.split("/") - if len(parts) >= 2: - dataset_name = "/".join(parts[:2]) - subset = parts[2] if len(parts) > 2 else None - else: - dataset_name = config.source_path - subset = None - - # Load dataset - dataset = load_dataset(dataset_name, subset, split="train") - df = dataset.to_pandas() - - logger.info(f"Loaded HuggingFace dataset: {dataset_name}") - return df - - except ImportError: - raise ImportError( - "datasets library required for HuggingFace loading. Install with: pip install datasets" - ) - except Exception as e: - raise RuntimeError( - f"Failed to load HuggingFace dataset {config.source_path}: {e}" - ) - - -def _load_from_local(config: DataSourceConfig) -> pd.DataFrame: - """Load dataset from local file.""" - try: - file_path = config.source_path - - if file_path.endswith(".csv"): - df = pd.read_csv(file_path) - elif file_path.endswith(".parquet"): - df = pd.read_parquet(file_path) - elif file_path.endswith(".json"): - df = pd.read_json(file_path) - elif file_path.endswith((".xlsx", ".xls")): - df = pd.read_excel(file_path) - else: - # Try CSV as fallback - df = pd.read_csv(file_path) - - logger.info(f"Loaded local file: {file_path}") - return df - - except Exception as e: - raise RuntimeError( - f"Failed to load local file {config.source_path}: {e}" - ) - - -def _load_from_warehouse(config: DataSourceConfig) -> pd.DataFrame: - """Load dataset from data warehouse connection.""" - try: - warehouse_config = config.warehouse_config or {} - - if "connection_string" in warehouse_config: - # Generic SQL connection - import sqlalchemy - - engine = sqlalchemy.create_engine( - warehouse_config["connection_string"] - ) - df = pd.read_sql(config.source_path, engine) - elif "type" in warehouse_config: - # Specific warehouse type - warehouse_type = warehouse_config["type"].lower() - - if warehouse_type == "bigquery": - df = _load_from_bigquery(config.source_path, warehouse_config) - elif warehouse_type == "snowflake": - df = _load_from_snowflake(config.source_path, warehouse_config) - elif warehouse_type == "redshift": - df = _load_from_redshift(config.source_path, warehouse_config) - else: - raise ValueError( - f"Unsupported warehouse type: {warehouse_type}" - ) - else: - raise ValueError( - "Warehouse config must specify connection_string or type" - ) - - logger.info(f"Loaded from warehouse: {config.source_path}") - return df - - except Exception as e: - raise RuntimeError( - f"Failed to load from warehouse {config.source_path}: {e}" - ) - - -def _load_from_bigquery( - table_path: str, config: Dict[str, Any] -) -> pd.DataFrame: - """Load data from Google BigQuery.""" - try: - import pandas_gbq - - project_id = config.get("project_id") - credentials = config.get("credentials_path") - - if table_path.startswith("SELECT"): - # It's a query - df = pandas_gbq.read_gbq( - table_path, project_id=project_id, credentials=credentials - ) - else: - # It's a table reference - query = f"SELECT * FROM `{table_path}`" - df = pandas_gbq.read_gbq( - query, project_id=project_id, credentials=credentials - ) - - return df - - except ImportError: - raise ImportError( - "pandas-gbq required for BigQuery. Install with: pip install pandas-gbq" - ) - - -def _load_from_snowflake( - table_path: str, config: Dict[str, Any] -) -> pd.DataFrame: - """Load data from Snowflake.""" - try: - import snowflake.connector - from snowflake.connector.pandas_tools import pd_writer - - conn = snowflake.connector.connect( - user=config["user"], - password=config["password"], - account=config["account"], - warehouse=config.get("warehouse"), - database=config.get("database"), - schema=config.get("schema"), - ) - - if table_path.upper().startswith("SELECT"): - query = table_path - else: - query = f"SELECT * FROM {table_path}" - - df = pd.read_sql(query, conn) - conn.close() - - return df - - except ImportError: - raise ImportError( - "snowflake-connector-python required. Install with: pip install snowflake-connector-python" - ) - - -def _load_from_redshift( - table_path: str, config: Dict[str, Any] -) -> pd.DataFrame: - """Load data from Amazon Redshift.""" - try: - import psycopg2 - import sqlalchemy - - connection_string = f"postgresql://{config['user']}:{config['password']}@{config['host']}:{config.get('port', 5439)}/{config['database']}" - engine = sqlalchemy.create_engine(connection_string) - - if table_path.upper().startswith("SELECT"): - query = table_path - else: - query = f"SELECT * FROM {table_path}" - - df = pd.read_sql(query, engine) - - return df - - except ImportError: - raise ImportError( - "psycopg2 required for Redshift. Install with: pip install psycopg2-binary" - ) - - -def _stratified_sample( - df: pd.DataFrame, target_column: str, sample_size: int -) -> pd.DataFrame: - """Perform stratified sampling based on target column.""" - try: - # Calculate proportional sample sizes for each class - value_counts = df[target_column].value_counts() - proportions = value_counts / len(df) - - sampled_dfs = [] - remaining_samples = sample_size - - for value, proportion in proportions.items(): - if remaining_samples <= 0: - break - - # Calculate sample size for this class - class_sample_size = max(1, int(proportion * sample_size)) - class_sample_size = min( - class_sample_size, remaining_samples, value_counts[value] - ) - - # Sample from this class - class_df = df[df[target_column] == value].sample( - n=class_sample_size, random_state=42 - ) - sampled_dfs.append(class_df) - - remaining_samples -= class_sample_size - - # Combine all samples - result_df = pd.concat(sampled_dfs, ignore_index=True) - - # Shuffle the final result - result_df = result_df.sample(frac=1, random_state=42).reset_index( - drop=True - ) - - logger.info( - f"Stratified sampling: {len(result_df)} samples across {len(sampled_dfs)} classes" - ) - return result_df - - except Exception as e: - logger.warning(f"Stratified sampling failed, using random: {e}") - return df.sample(n=sample_size, random_state=42) - - -def _compute_content_hash(df: pd.DataFrame) -> str: - """Compute hash of DataFrame content for change detection.""" - try: - # Create a string representation of the dataframe structure and sample - content_parts = [ - f"shape:{df.shape}", - f"columns:{sorted(df.columns.tolist())}", - f"dtypes:{sorted(df.dtypes.astype(str).tolist())}", - ] - - # Add sample of data if not too large - if len(df) <= 1000: - content_parts.append(f"data:{df.to_string()}") - else: - # Use a sample and summary stats - sample_df = ( - df.sample(n=100, random_state=42) if len(df) > 100 else df - ) - content_parts.extend( - [ - f"sample:{sample_df.to_string()}", - f"describe:{df.describe().to_string()}", - ] - ) - - content_str = "|".join(content_parts) - return hashlib.md5(content_str.encode()).hexdigest() - - except Exception as e: - logger.warning(f"Failed to compute content hash: {e}") - return f"error_{hashlib.md5(str(df.shape).encode()).hexdigest()}" diff --git a/examples/pydantic_ai_eda/steps/prompt_experiment.py b/examples/pydantic_ai_eda/steps/prompt_experiment.py deleted file mode 100644 index d66ae9b9e86..00000000000 --- a/examples/pydantic_ai_eda/steps/prompt_experiment.py +++ /dev/null @@ -1,766 +0,0 @@ -"""Advanced prompt experimentation step for Pydantic AI agent development.""" - -import json -import time -from pathlib import Path -from typing import Annotated, Any, Dict, List, Optional - -import pandas as pd -from models import AgentConfig, EDAReport -from pydantic_ai import Agent -from pydantic_ai.settings import ModelSettings -from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps - -from zenml import step -from zenml.logger import get_logger - -logger = get_logger(__name__) - - -@step -def evaluate_prompts_with_test_cases( - dataset_df: pd.DataFrame, - prompt_variants: List[str], - test_cases_path: str = "test_cases.json", - agent_config: AgentConfig = None, - use_llm_judge: bool = True, -) -> Annotated[Dict[str, Any], "comprehensive_prompt_evaluation"]: - """Advanced prompt evaluation using structured test cases and LLM judge. - - This step implements best practices from AI evaluation methodology: - - Structured test cases with categories - - LLM judge evaluation for quality assessment - - Tool usage tracking and analysis - - Statistical significance testing - - Args: - dataset_df: Dataset to test prompts against - prompt_variants: List of system prompts to compare - test_cases_path: Path to JSON file with structured test cases - agent_config: Base configuration for all agents - use_llm_judge: Whether to use LLM for response quality evaluation - - Returns: - Comprehensive evaluation results with quality scores and recommendations - """ - if agent_config is None: - agent_config = AgentConfig() - - # Load structured test cases - test_cases = _load_test_cases(test_cases_path) - if not test_cases: - # Fallback to simple evaluation if no test cases - return compare_agent_prompts(dataset_df, prompt_variants, agent_config) - - logger.info( - f"๐Ÿงช Running comprehensive evaluation: {len(prompt_variants)} prompts ร— {len(test_cases)} test cases" - ) - - # Initialize LLM judge if requested - llm_judge = None - if use_llm_judge: - llm_judge = _create_llm_judge(agent_config.model_name) - - all_results = [] - - # Test each prompt variant - for prompt_idx, system_prompt in enumerate(prompt_variants): - prompt_id = f"variant_{prompt_idx + 1}" - logger.info(f"Testing {prompt_id}/{len(prompt_variants)}") - - prompt_results = [] - - # Run each test case for this prompt - for test_case in test_cases: - case_result = _run_single_test_case( - dataset_df, - system_prompt, - test_case, - agent_config, - llm_judge, - prompt_id, - ) - prompt_results.append(case_result) - - # Aggregate results for this prompt - prompt_summary = _analyze_prompt_performance( - prompt_results, prompt_id, system_prompt - ) - all_results.append(prompt_summary) - - # Generate comprehensive comparison - final_analysis = _generate_comprehensive_analysis(all_results, test_cases) - - return final_analysis - - -@step -def compare_agent_prompts( - dataset_df: pd.DataFrame, - prompt_variants: List[str], - agent_config: AgentConfig = None, -) -> Annotated[Dict[str, Any], "prompt_comparison_results"]: - """Test multiple system prompts on the same dataset for agent optimization. - - This step helps developers iterate on agent prompts by running A/B tests - and comparing quality, performance, and output characteristics. - - Args: - dataset_df: Dataset to test prompts against - prompt_variants: List of system prompts to compare - agent_config: Base configuration for all agents - - Returns: - Comparison report with metrics for each prompt variant - """ - if agent_config is None: - agent_config = AgentConfig() - - logger.info( - f"๐Ÿงช Testing {len(prompt_variants)} prompt variants on dataset with {len(dataset_df)} rows" - ) - - results = [] - - for i, system_prompt in enumerate(prompt_variants): - prompt_id = f"variant_{i + 1}" - logger.info( - f"๐Ÿ”„ Testing prompt variant {i + 1}/{len(prompt_variants)}" - ) - print(f"๐Ÿ”„ Testing prompt variant {i + 1}/{len(prompt_variants)}") - - start_time = time.time() - - try: - # Initialize fresh dependencies for each test - deps = AnalystAgentDeps() - main_ref = deps.store(dataset_df) - - print(f" ๐Ÿค– Creating agent for variant {i + 1}...") - # Create agent with this prompt variant - test_agent = Agent( - f"openai:{agent_config.model_name}", - deps_type=AnalystAgentDeps, - output_type=EDAReport, - output_retries=3, - system_prompt=system_prompt, - model_settings=ModelSettings( - parallel_tool_calls=False - ), # Disable parallel to avoid hanging - ) - - # Register tools - print(f" ๐Ÿ”ง Registering {len(AGENT_TOOLS)} tools...") - for tool in AGENT_TOOLS: - test_agent.tool(tool) - - # Run analysis with consistent user prompt - user_prompt = f"""Analyze dataset '{main_ref}' ({dataset_df.shape[0]} rows, {dataset_df.shape[1]} cols). - -Focus on data quality, key patterns, and actionable insights.""" - - print(f" โšก Running analysis for variant {i + 1}...") - result = test_agent.run_sync(user_prompt, deps=deps) - eda_report = result.output - - execution_time = time.time() - start_time - - # Collect metrics for comparison - result_metrics = { - "prompt_id": prompt_id, - "system_prompt": system_prompt, - "success": True, - "execution_time_seconds": round(execution_time, 2), - "tool_calls_made": len(deps.query_history), - "data_quality_score": eda_report.data_quality_score, - "key_findings_count": len(eda_report.key_findings), - "risks_identified": len(eda_report.risks), - "fixes_suggested": len(eda_report.fixes), - "correlation_insights": len(eda_report.correlation_insights), - "headline_length": len(eda_report.headline), - "markdown_length": len(eda_report.markdown), - "tables_generated": len(deps.output) - - 1, # Exclude main dataset - "error": None, - } - - logger.info( - f"โœ… Variant {i + 1}: Score={eda_report.data_quality_score:.1f}, Tools={len(deps.query_history)}, Time={execution_time:.1f}s" - ) - - except Exception as e: - execution_time = time.time() - start_time - logger.warning(f"โŒ Variant {i + 1} failed: {str(e)}") - - result_metrics = { - "prompt_id": prompt_id, - "system_prompt": system_prompt, - "success": False, - "execution_time_seconds": round(execution_time, 2), - "tool_calls_made": 0, - "data_quality_score": 0.0, - "key_findings_count": 0, - "risks_identified": 0, - "fixes_suggested": 0, - "correlation_insights": 0, - "headline_length": 0, - "markdown_length": 0, - "tables_generated": 0, - "error": str(e), - } - - results.append(result_metrics) - - # Analyze results and determine best variant - successful_results = [r for r in results if r["success"]] - - if successful_results: - # Rank by composite score (quality + speed + thoroughness) - for result in successful_results: - # Composite score: 60% quality + 20% speed + 20% thoroughness - speed_score = max( - 0, 100 - result["execution_time_seconds"] * 10 - ) # Penalty for slowness - thoroughness_score = ( - result["key_findings_count"] * 10 - + result["risks_identified"] * 5 - + result["tables_generated"] * 5 - ) - - result["composite_score"] = ( - result["data_quality_score"] * 0.6 - + min(speed_score, 100) * 0.2 - + min(thoroughness_score, 100) * 0.2 - ) - - # Find best performer - best_result = max( - successful_results, key=lambda x: x["composite_score"] - ) - best_prompt_id = best_result["prompt_id"] - - logger.info( - f"๐Ÿ† Best performing prompt: {best_prompt_id} (score: {best_result['composite_score']:.1f})" - ) - else: - best_prompt_id = "none" - logger.warning("โŒ No prompts succeeded") - - # Generate summary comparison - summary = { - "experiment_summary": { - "total_prompts_tested": len(prompt_variants), - "successful_runs": len(successful_results), - "failed_runs": len(results) - len(successful_results), - "best_prompt_variant": best_prompt_id, - "avg_execution_time": round( - sum(r["execution_time_seconds"] for r in results) - / len(results), - 2, - ) - if results - else 0, - "dataset_info": { - "rows": len(dataset_df), - "columns": len(dataset_df.columns), - "column_names": list(dataset_df.columns), - }, - }, - "detailed_results": results, - "recommendations": _generate_prompt_recommendations(results), - } - - return summary - - -def _generate_prompt_recommendations( - results: List[Dict[str, Any]], -) -> List[str]: - """Generate recommendations based on prompt experiment results.""" - recommendations = [] - - successful_results = [r for r in results if r["success"]] - - if not successful_results: - return [ - "All prompts failed - check agent configuration and error messages" - ] - - # Performance analysis - avg_time = sum( - r["execution_time_seconds"] for r in successful_results - ) / len(successful_results) - avg_quality = sum( - r["data_quality_score"] for r in successful_results - ) / len(successful_results) - - if avg_time > 30: - recommendations.append( - "Consider shorter, more focused prompts to reduce execution time" - ) - - if avg_quality < 70: - recommendations.append( - "Consider more specific instructions for data quality assessment" - ) - - # Consistency analysis - quality_scores = [r["data_quality_score"] for r in successful_results] - quality_variance = max(quality_scores) - min(quality_scores) - - if quality_variance > 20: - recommendations.append( - "High variance in quality scores - consider more consistent prompt structure" - ) - else: - recommendations.append( - "Quality scores are consistent across prompts - good stability" - ) - - # Tool usage analysis - tool_calls = [r["tool_calls_made"] for r in successful_results] - if max(tool_calls) - min(tool_calls) > 3: - recommendations.append( - "Variable tool usage detected - consider standardizing analysis steps" - ) - - # Success rate analysis - success_rate = len(successful_results) / len(results) * 100 - if success_rate < 80: - recommendations.append( - "Low success rate - review prompt complexity and error handling" - ) - else: - recommendations.append( - f"Good success rate ({success_rate:.0f}%) - prompts are robust" - ) - - return recommendations - - -def _load_test_cases(test_cases_path: str) -> List[Dict[str, Any]]: - """Load structured test cases from JSON file.""" - try: - test_cases_file = Path(test_cases_path) - if not test_cases_file.exists(): - logger.warning(f"Test cases file not found: {test_cases_path}") - return [] - - with open(test_cases_file, "r") as f: - test_cases = json.load(f) - - logger.info( - f"Loaded {len(test_cases)} test cases from {test_cases_path}" - ) - return test_cases - - except Exception as e: - logger.error(f"Failed to load test cases: {e}") - return [] - - -def _create_llm_judge(base_model: str) -> Optional[Agent]: - """Create an LLM judge agent for evaluating responses.""" - try: - # Use a stronger model for evaluation if available - judge_model = ( - "gpt-4o" if "gpt" in base_model else "claude-3-5-sonnet-20241022" - ) - - judge_system_prompt = """You are an expert AI evaluator specializing in data analysis responses. - -Your job is to assess EDA (Exploratory Data Analysis) responses based on: - -1. **Accuracy** (1-5): Factual correctness and valid statistical insights -2. **Relevance** (1-5): How well the response addresses the specific query -3. **Completeness** (1-5): Whether all aspects of the query are covered -4. **Tool Usage** (1-5): Appropriate use of available analysis tools -5. **Actionability** (1-5): Quality of recommendations and insights - -Score each criterion from 1 (poor) to 5 (excellent). -Provide scores in JSON format: {"accuracy": X, "relevance": X, "completeness": X, "tool_usage": X, "actionability": X, "overall": X, "reasoning": "brief explanation"} - -Be objective and consistent in your evaluations.""" - - return Agent( - f"openai:{judge_model}" - if "gpt" in base_model - else f"anthropic:{judge_model}", - system_prompt=judge_system_prompt, - ) - - except Exception as e: - logger.warning(f"Failed to create LLM judge: {e}") - return None - - -def _run_single_test_case( - dataset_df: pd.DataFrame, - system_prompt: str, - test_case: Dict[str, Any], - agent_config: AgentConfig, - llm_judge: Optional[Agent], - prompt_id: str, -) -> Dict[str, Any]: - """Run a single test case and collect comprehensive metrics.""" - start_time = time.time() - - try: - # Initialize agent for this test - deps = AnalystAgentDeps() - main_ref = deps.store(dataset_df) - - test_agent = Agent( - f"openai:{agent_config.model_name}", - deps_type=AnalystAgentDeps, - output_type=EDAReport, - output_retries=3, - system_prompt=system_prompt, - model_settings=ModelSettings(parallel_tool_calls=True), - ) - - # Register tools - for tool in AGENT_TOOLS: - test_agent.tool(tool) - - # Run the test case query - full_query = f"Dataset reference: {main_ref}\n\n{test_case['query']}" - result = test_agent.run_sync(full_query, deps=deps) - eda_report = result.output - - execution_time = time.time() - start_time - - # Collect basic metrics - case_result = { - "test_id": test_case["id"], - "category": test_case.get("category", "general"), - "prompt_id": prompt_id, - "query": test_case["query"], - "success": True, - "execution_time": execution_time, - "tool_calls_made": len(deps.query_history), - "response": str(eda_report.markdown), - "data_quality_score": eda_report.data_quality_score, - "findings_count": len(eda_report.key_findings), - "risks_count": len(eda_report.risks), - "recommendations_count": len(eda_report.fixes), - "error": None, - } - - # Evaluate with LLM judge if available - if llm_judge: - judge_evaluation = _get_llm_judge_scores( - llm_judge, test_case, case_result - ) - case_result.update(judge_evaluation) - - # Check against expected metrics if available - expected_metrics = test_case.get("expected_metrics", {}) - case_result["meets_expectations"] = _check_expectations( - case_result, expected_metrics - ) - - return case_result - - except Exception as e: - execution_time = time.time() - start_time - return { - "test_id": test_case["id"], - "category": test_case.get("category", "general"), - "prompt_id": prompt_id, - "query": test_case["query"], - "success": False, - "execution_time": execution_time, - "error": str(e), - "meets_expectations": False, - } - - -def _get_llm_judge_scores( - llm_judge: Agent, test_case: Dict, case_result: Dict -) -> Dict: - """Get quality scores from LLM judge.""" - try: - eval_prompt = f""" -Query: {test_case["query"]} -Category: {test_case.get("category", "general")} - -Response to evaluate: -{case_result["response"][:2000]}... - -Data Quality Score Provided: {case_result["data_quality_score"]} -Tool Calls Made: {case_result["tool_calls_made"]} -Findings Count: {case_result["findings_count"]} - -Please evaluate this EDA response and provide scores in the requested JSON format.""" - - judge_response = llm_judge.run_sync(eval_prompt) - - # Parse JSON response - try: - scores = json.loads(str(judge_response.output)) - return { - "judge_accuracy": scores.get("accuracy", 0), - "judge_relevance": scores.get("relevance", 0), - "judge_completeness": scores.get("completeness", 0), - "judge_tool_usage": scores.get("tool_usage", 0), - "judge_actionability": scores.get("actionability", 0), - "judge_overall": scores.get("overall", 0), - "judge_reasoning": scores.get("reasoning", ""), - } - except json.JSONDecodeError: - logger.warning("LLM judge response was not valid JSON") - return { - "judge_overall": 3, - "judge_reasoning": "Failed to parse judge response", - } - - except Exception as e: - logger.warning(f"LLM judge evaluation failed: {e}") - return { - "judge_overall": 3, - "judge_reasoning": f"Evaluation error: {e}", - } - - -def _check_expectations(case_result: Dict, expected_metrics: Dict) -> bool: - """Check if results meet expected criteria.""" - if not expected_metrics: - return True - - checks = [] - - # Check minimum quality score - min_quality = expected_metrics.get("quality_score_min") - if min_quality: - checks.append(case_result.get("data_quality_score", 0) >= min_quality) - - # Check minimum recommendations - min_recs = expected_metrics.get("recommendations_min") - if min_recs: - checks.append(case_result.get("recommendations_count", 0) >= min_recs) - - # Check expected tool calls - expected_tools = expected_metrics.get("tool_calls_expected", []) - if expected_tools: - # This would need actual tool tracking - simplified for now - checks.append(case_result.get("tool_calls_made", 0) > 0) - - return all(checks) if checks else True - - -def _analyze_prompt_performance( - results: List[Dict], prompt_id: str, system_prompt: str -) -> Dict: - """Analyze performance across all test cases for a single prompt.""" - successful_results = [r for r in results if r["success"]] - - if not successful_results: - return { - "prompt_id": prompt_id, - "system_prompt": system_prompt, - "success_rate": 0, - "avg_scores": {}, - "category_performance": {}, - "overall_rating": "failed", - } - - # Calculate averages - avg_scores = { - "execution_time": sum(r["execution_time"] for r in successful_results) - / len(successful_results), - "tool_calls": sum(r["tool_calls_made"] for r in successful_results) - / len(successful_results), - "data_quality_score": sum( - r["data_quality_score"] for r in successful_results - ) - / len(successful_results), - "findings_count": sum(r["findings_count"] for r in successful_results) - / len(successful_results), - } - - # Add LLM judge scores if available - judge_scores = [r for r in successful_results if "judge_overall" in r] - if judge_scores: - avg_scores.update( - { - "judge_accuracy": sum( - r["judge_accuracy"] for r in judge_scores - ) - / len(judge_scores), - "judge_relevance": sum( - r["judge_relevance"] for r in judge_scores - ) - / len(judge_scores), - "judge_completeness": sum( - r["judge_completeness"] for r in judge_scores - ) - / len(judge_scores), - "judge_overall": sum(r["judge_overall"] for r in judge_scores) - / len(judge_scores), - } - ) - - # Category-wise performance - categories = {} - for result in successful_results: - cat = result["category"] - if cat not in categories: - categories[cat] = { - "count": 0, - "avg_score": 0, - "meets_expectations": 0, - } - categories[cat]["count"] += 1 - categories[cat]["avg_score"] += result["data_quality_score"] - if result.get("meets_expectations", False): - categories[cat]["meets_expectations"] += 1 - - for cat in categories: - categories[cat]["avg_score"] /= categories[cat]["count"] - categories[cat]["success_rate"] = ( - categories[cat]["meets_expectations"] / categories[cat]["count"] - ) - - # Overall rating - success_rate = len(successful_results) / len(results) - avg_judge_score = avg_scores.get( - "judge_overall", avg_scores.get("data_quality_score", 50) / 20 - ) - - if success_rate >= 0.8 and avg_judge_score >= 4: - rating = "excellent" - elif success_rate >= 0.6 and avg_judge_score >= 3: - rating = "good" - elif success_rate >= 0.4: - rating = "acceptable" - else: - rating = "poor" - - return { - "prompt_id": prompt_id, - "system_prompt": system_prompt, - "success_rate": success_rate, - "avg_scores": avg_scores, - "category_performance": categories, - "overall_rating": rating, - "detailed_results": results, - } - - -def _generate_comprehensive_analysis( - all_results: List[Dict], test_cases: List[Dict] -) -> Dict: - """Generate final comprehensive analysis of all prompt variants.""" - # Rank prompts by overall performance - ranked_prompts = sorted( - all_results, - key=lambda x: ( - x["success_rate"] * 0.4 - + x["avg_scores"].get("judge_overall", 3) * 0.3 - + (100 - x["avg_scores"]["execution_time"]) / 100 * 0.2 - + x["avg_scores"]["data_quality_score"] / 100 * 0.1 - ), - reverse=True, - ) - - best_prompt = ranked_prompts[0] if ranked_prompts else None - - # Category analysis - category_insights = {} - for test_case in test_cases: - cat = test_case.get("category", "general") - if cat not in category_insights: - category_insights[cat] = {"prompt_performance": []} - - for prompt_result in all_results: - cat_perf = prompt_result["category_performance"].get(cat, {}) - if cat_perf: - category_insights[cat]["prompt_performance"].append( - { - "prompt_id": prompt_result["prompt_id"], - "avg_score": cat_perf["avg_score"], - "success_rate": cat_perf["success_rate"], - } - ) - - return { - "evaluation_summary": { - "total_prompts_tested": len(all_results), - "total_test_cases": len(test_cases), - "best_prompt": best_prompt["prompt_id"] if best_prompt else "none", - "best_prompt_rating": best_prompt["overall_rating"] - if best_prompt - else "none", - "categories_tested": list(category_insights.keys()), - }, - "prompt_rankings": ranked_prompts, - "category_analysis": category_insights, - "detailed_results": all_results, - "recommendations": _generate_advanced_recommendations( - all_results, category_insights - ), - } - - -def _generate_advanced_recommendations( - all_results: List[Dict], category_insights: Dict -) -> List[str]: - """Generate advanced recommendations based on comprehensive analysis.""" - recommendations = [] - - if not all_results: - return ["No successful evaluations - check agent configuration"] - - # Success rate analysis - avg_success_rate = sum(r["success_rate"] for r in all_results) / len( - all_results - ) - if avg_success_rate < 0.6: - recommendations.append( - "Low overall success rate - consider simplifying prompts or checking tool integration" - ) - - # Performance consistency - execution_times = [r["avg_scores"]["execution_time"] for r in all_results] - time_variance = max(execution_times) - min(execution_times) - if time_variance > 30: - recommendations.append( - "High variance in execution time - optimize slower prompts for efficiency" - ) - - # Quality assessment - if any("judge_overall" in r["avg_scores"] for r in all_results): - judge_scores = [ - r["avg_scores"]["judge_overall"] - for r in all_results - if "judge_overall" in r["avg_scores"] - ] - avg_judge_score = sum(judge_scores) / len(judge_scores) - - if avg_judge_score < 3: - recommendations.append( - "LLM judge scores are low - review prompt clarity and specificity" - ) - elif avg_judge_score > 4: - recommendations.append( - "Excellent LLM judge scores - prompts are producing high-quality responses" - ) - - # Category-specific insights - for category, data in category_insights.items(): - if data["prompt_performance"]: - cat_scores = [p["avg_score"] for p in data["prompt_performance"]] - if min(cat_scores) < 60: - recommendations.append( - f"'{category}' category shows low scores - consider specialized prompts for this use case" - ) - - # Best practices - best_prompt = max(all_results, key=lambda x: x["success_rate"]) - if best_prompt["success_rate"] > 0.8: - recommendations.append( - f"'{best_prompt['prompt_id']}' shows strong performance - consider it as your baseline" - ) - - return recommendations diff --git a/examples/pydantic_ai_eda/steps/quality_gate.py b/examples/pydantic_ai_eda/steps/quality_gate.py deleted file mode 100644 index 7540e5062b1..00000000000 --- a/examples/pydantic_ai_eda/steps/quality_gate.py +++ /dev/null @@ -1,278 +0,0 @@ -"""Quality gate step for data quality assessment and pipeline routing.""" - -import logging -from typing import Annotated, Any, Dict, List, Tuple - -from models import QualityGateDecision - -from zenml import step - -logger = logging.getLogger(__name__) - - -def _check_quality_metric( - condition: bool, - pass_msg: str, - fail_msg: str, - blocking_issues: List[str], - decision_factors: List[str], -) -> None: - """Helper to standardize quality check pattern.""" - if condition: - blocking_issues.append(fail_msg) - decision_factors.append(f"{pass_msg.split(':')[0]}: โŒ") - else: - decision_factors.append(f"{pass_msg}: โœ…") - - -def evaluate_quality_gate( - report_json: Dict[str, Any], - min_quality_score: float = 70.0, - block_on_high_severity: bool = True, - max_missing_data_pct: float = 30.0, - require_target_column: bool = False, - target_column: str = None, -) -> Annotated[QualityGateDecision, "quality_gate_decision"]: - """Evaluate data quality and make gate pass/fail decision.""" - logger.info("Evaluating data quality gate") - - try: - # Extract metrics - quality_score = report_json.get("data_quality_score", 0.0) - fixes = report_json.get("fixes", []) - missing_data_analysis = report_json.get("missing_data_analysis", {}) - column_profiles = report_json.get("column_profiles", {}) - - blocking_issues = [] - decision_factors = [] - - # Quality score check - _check_quality_metric( - quality_score < min_quality_score, - f"Quality score: {quality_score:.1f}/{min_quality_score}", - f"Data quality score ({quality_score:.1f}) below minimum threshold ({min_quality_score})", - blocking_issues, - decision_factors, - ) - - # High severity issues check - high_severity_fixes = [f for f in fixes if f.get("severity") == "high"] - if block_on_high_severity and high_severity_fixes: - titles = [f.get("title", "Unknown") for f in high_severity_fixes] - _check_quality_metric( - True, - f"High severity issues: {len(high_severity_fixes)}", - f"High severity issues found: {', '.join(titles)}", - blocking_issues, - decision_factors, - ) - else: - decision_factors.append( - f"High severity issues: {len(high_severity_fixes)} โœ…" - ) - - # Missing data check - overall_missing_pct = missing_data_analysis.get( - "missing_percentage", 0.0 - ) - _check_quality_metric( - overall_missing_pct > max_missing_data_pct, - f"Missing data: {overall_missing_pct:.1f}%/{max_missing_data_pct}%", - f"Missing data percentage ({overall_missing_pct:.1f}%) exceeds threshold ({max_missing_data_pct}%)", - blocking_issues, - decision_factors, - ) - - # Target column check - if require_target_column: - if not target_column: - blocking_issues.append( - "Target column required but not specified" - ) - decision_factors.append("Target column: Not specified โŒ") - elif target_column not in column_profiles: - blocking_issues.append( - f"Required target column '{target_column}' not found" - ) - decision_factors.append( - f"Target column '{target_column}': Missing โŒ" - ) - else: - target_missing_pct = column_profiles[target_column].get( - "null_percentage", 0.0 - ) - if target_missing_pct > 50: - blocking_issues.append( - f"Target column '{target_column}' has {target_missing_pct:.1f}% missing values" - ) - decision_factors.append( - f"Target column '{target_column}': {target_missing_pct:.1f}% missing โŒ" - ) - else: - decision_factors.append( - f"Target column '{target_column}': Present โœ…" - ) - - # Generate recommendations - recommendations = _generate_recommendations( - quality_score, fixes, overall_missing_pct, min_quality_score - ) - - # Make decision - passed = len(blocking_issues) == 0 - decision_reason = ( - f"All quality checks passed. {', '.join(decision_factors)}" - if passed - else f"Quality gate failed. Issues: {'; '.join(blocking_issues)}" - ) - - logger.info( - "โœ… Quality gate PASSED" if passed else "โŒ Quality gate FAILED" - ) - if not passed: - logger.warning(f"Blocking issues: {blocking_issues}") - - return QualityGateDecision( - passed=passed, - quality_score=quality_score, - decision_reason=decision_reason, - blocking_issues=blocking_issues, - recommendations=recommendations, - metadata={ - "decision_factors": decision_factors, - "thresholds": { - "min_quality_score": min_quality_score, - "max_missing_data_pct": max_missing_data_pct, - "block_on_high_severity": block_on_high_severity, - "require_target_column": require_target_column, - }, - "metrics": { - "overall_missing_pct": overall_missing_pct, - "high_severity_count": len(high_severity_fixes), - "total_fixes": len(fixes), - "column_count": len(column_profiles), - }, - }, - ) - - except Exception as e: - logger.error(f"Quality gate evaluation failed: {e}") - return QualityGateDecision( - passed=False, - quality_score=0.0, - decision_reason=f"Quality gate evaluation failed: {str(e)}", - blocking_issues=[f"Technical error: {str(e)}"], - recommendations=[ - "Review EDA report format and quality gate configuration" - ], - metadata={"error": str(e)}, - ) - - -def _generate_recommendations( - quality_score: float, - fixes: list, - overall_missing_pct: float, - min_threshold: float, -) -> list: - """Generate actionable recommendations based on quality assessment.""" - recommendations = [] - - # Quality score recommendations - score_gap = min_threshold - quality_score - if score_gap > 30: - recommendations.append( - "Consider alternative data sources due to significant quality issues" - ) - elif score_gap > 15: - recommendations.append( - "Address high-priority data quality issues before proceeding" - ) - elif score_gap > 0: - recommendations.append("Minor quality improvements recommended") - - # Critical fixes - critical_fixes = [ - f for f in fixes if f.get("severity") in ["high", "critical"] - ] - if critical_fixes: - recommendations.append( - f"Implement {len(critical_fixes)} critical data quality fixes" - ) - - # Add specific fix types - fix_types = { - "missing": "Consider imputation strategies for missing data", - "duplicate": "Remove or consolidate duplicate records", - "outlier": "Investigate and handle outlier values", - } - - for fix in critical_fixes[:3]: - title = fix.get("title", "").lower() - for keyword, recommendation in fix_types.items(): - if keyword in title and recommendation not in recommendations: - recommendations.append(recommendation) - break - - # Missing data recommendations - if overall_missing_pct > 20: - recommendations.append( - "High missing data detected - consider imputation or collection improvements" - ) - elif overall_missing_pct > 10: - recommendations.append( - "Moderate missing data - review imputation strategies" - ) - - # Pipeline recommendations - if quality_score >= min_threshold and not critical_fixes: - recommendations.extend( - [ - "Data quality acceptable for downstream processing", - "Consider implementing quality monitoring", - ] - ) - elif score_gap <= min_threshold * 0.2: # Close to passing - recommendations.append( - "Data quality borderline - implement fixes and re-evaluate" - ) - else: - recommendations.append( - "Significant quality issues require attention before production use" - ) - - return recommendations - - -@step -def evaluate_quality_gate_with_routing( - report_json: Dict[str, Any], - min_quality_score: float = 70.0, - block_on_high_severity: bool = True, - max_missing_data_pct: float = 30.0, - require_target_column: bool = False, - target_column: str = None, -) -> Tuple[ - Annotated[QualityGateDecision, "quality_gate_decision"], - Annotated[str, "routing_message"], -]: - """Combined quality gate evaluation and routing decision.""" - decision = evaluate_quality_gate( - report_json, - min_quality_score, - block_on_high_severity, - max_missing_data_pct, - require_target_column, - target_column, - ) - - routing_message = ( - "๐Ÿš€ Data quality passed - proceed to downstream processing" - if decision.passed - else "๐Ÿ›‘ Data quality insufficient - review and improve data before proceeding" - ) - - logger.info(routing_message) if decision.passed else logger.warning( - routing_message - ) - return decision, routing_message diff --git a/examples/pydantic_ai_eda/test_cases.json b/examples/pydantic_ai_eda/test_cases.json deleted file mode 100644 index 27fae0b4c3b..00000000000 --- a/examples/pydantic_ai_eda/test_cases.json +++ /dev/null @@ -1,47 +0,0 @@ -[ - { - "id": "iris-quality-1", - "query": "Analyze the iris dataset for data quality issues, missing values, and overall readiness for machine learning.", - "category": "data_quality", - "expected_metrics": { - "should_identify": ["no missing values", "clean dataset", "high quality score"], - "quality_score_min": 85 - } - }, - { - "id": "iris-distribution-1", - "query": "Examine the distribution of features in this dataset and identify any patterns or anomalies.", - "category": "distribution_analysis", - "expected_metrics": { - "should_identify": ["sepal/petal measurements", "species distribution", "correlation patterns"], - "tool_calls_expected": ["describe", "analyze_correlations"] - } - }, - { - "id": "iris-ml-readiness-1", - "query": "Is this dataset ready for training a classification model? What preprocessing steps are needed?", - "category": "ml_readiness", - "expected_metrics": { - "should_identify": ["classification target", "feature scaling", "data preparation"], - "quality_score_min": 80 - } - }, - { - "id": "iris-edge-case-1", - "query": "Find any outliers, duplicates, or data inconsistencies that could impact model performance.", - "category": "edge_cases", - "expected_metrics": { - "should_identify": ["outlier analysis", "duplicate check", "consistency validation"], - "tool_calls_expected": ["run_sql"] - } - }, - { - "id": "iris-business-1", - "query": "From a business perspective, what insights can you extract from this data and what are the key recommendations?", - "category": "business_insights", - "expected_metrics": { - "should_identify": ["actionable insights", "business recommendations", "data value assessment"], - "recommendations_min": 2 - } - } -] \ No newline at end of file From aa6f4c1712f965b67dc6a8a90ccab0a7ba968c56 Mon Sep 17 00:00:00 2001 From: Hamza Tahir Date: Mon, 25 Aug 2025 23:56:59 +0200 Subject: [PATCH 08/14] Optimize prompt variants and tag the best performer --- examples/prompt_optimization/models.py | 26 +++- .../pipelines/prompt_optimization_pipeline.py | 24 ++-- examples/prompt_optimization/run.py | 49 ++++--- .../prompt_optimization/steps/eda_agent.py | 7 +- examples/prompt_optimization/steps/ingest.py | 22 ++- .../steps/prompt_optimization.py | 128 +++++++++--------- 6 files changed, 147 insertions(+), 109 deletions(-) diff --git a/examples/prompt_optimization/models.py b/examples/prompt_optimization/models.py index 5b0c1ea7b33..c49fa562c3f 100644 --- a/examples/prompt_optimization/models.py +++ b/examples/prompt_optimization/models.py @@ -9,9 +9,15 @@ class DataSourceConfig(BaseModel): """Data source configuration.""" source_type: str = Field(description="Data source type: hf or local") - source_path: str = Field(description="Path or identifier for the data source") - target_column: Optional[str] = Field(None, description="Optional target column name") - sample_size: Optional[int] = Field(None, description="Number of rows to sample") + source_path: str = Field( + description="Path or identifier for the data source" + ) + target_column: Optional[str] = Field( + None, description="Optional target column name" + ) + sample_size: Optional[int] = Field( + None, description="Number of rows to sample" + ) class AgentConfig(BaseModel): @@ -19,13 +25,19 @@ class AgentConfig(BaseModel): model_name: str = Field("gpt-4o-mini", description="Language model to use") max_tool_calls: int = Field(6, description="Maximum tool calls allowed") - timeout_seconds: int = Field(60, description="Max execution time in seconds") + timeout_seconds: int = Field( + 60, description="Max execution time in seconds" + ) class EDAReport(BaseModel): """Simple EDA analysis report from AI agent.""" headline: str = Field(description="Executive summary of key findings") - key_findings: List[str] = Field(default_factory=list, description="Important discoveries") - data_quality_score: float = Field(description="Overall quality score (0-100)") - markdown: str = Field(description="Full markdown report") \ No newline at end of file + key_findings: List[str] = Field( + default_factory=list, description="Important discoveries" + ) + data_quality_score: float = Field( + description="Overall quality score (0-100)" + ) + markdown: str = Field(description="Full markdown report") diff --git a/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py b/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py index d756a6f875d..7d936865dbc 100644 --- a/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py +++ b/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py @@ -18,32 +18,34 @@ def prompt_optimization_pipeline( agent_config: Optional[AgentConfig] = None, ) -> Dict[str, Any]: """Optimize prompts and tag the best one for production use. - + This pipeline demonstrates ZenML's artifact management by testing multiple prompt variants and tagging the best performer. - + Args: - source_config: Data source configuration + source_config: Data source configuration prompt_variants: List of prompt strings to test agent_config: AI agent configuration - + Returns: Pipeline results with best prompt and metadata """ logger.info("๐Ÿงช Starting prompt optimization pipeline") - + # Step 1: Load data dataset_df, metadata = ingest_data(source_config=source_config) - + # Step 2: Test prompts and tag the best one best_prompt = compare_prompts_and_tag_best( dataset_df=dataset_df, - prompt_variants=prompt_variants, + prompt_variants=prompt_variants, agent_config=agent_config, ) - - logger.info("โœ… Prompt optimization completed - best prompt tagged with 'optimized'") - + + logger.info( + "โœ… Prompt optimization completed - best prompt tagged with 'optimized'" + ) + return { "best_prompt": best_prompt, "metadata": metadata, @@ -51,4 +53,4 @@ def prompt_optimization_pipeline( "source": f"{source_config.source_type}:{source_config.source_path}", "variants_tested": len(prompt_variants), }, - } \ No newline at end of file + } diff --git a/examples/prompt_optimization/run.py b/examples/prompt_optimization/run.py index 7ff348c79c0..6eb4ead81f3 100644 --- a/examples/prompt_optimization/run.py +++ b/examples/prompt_optimization/run.py @@ -5,20 +5,31 @@ from typing import Optional import click - from models import AgentConfig, DataSourceConfig from pipelines.production_eda_pipeline import production_eda_pipeline from pipelines.prompt_optimization_pipeline import prompt_optimization_pipeline -@click.command(help="Run prompt optimization and/or production EDA pipelines (both by default)") -@click.option("--optimization-pipeline", is_flag=True, help="Run prompt optimization") +@click.command( + help="Run prompt optimization and/or production EDA pipelines (both by default)" +) +@click.option( + "--optimization-pipeline", is_flag=True, help="Run prompt optimization" +) @click.option("--production-pipeline", is_flag=True, help="Run production EDA") -@click.option("--data-source", default="hf:scikit-learn/iris", help="Data source (type:path)") +@click.option( + "--data-source", + default="hf:scikit-learn/iris", + help="Data source (type:path)", +) @click.option("--target-column", default="target", help="Target column name") -@click.option("--model-name", help="Model name (auto-detected if not specified)") +@click.option( + "--model-name", help="Model name (auto-detected if not specified)" +) @click.option("--max-tool-calls", default=6, type=int, help="Max tool calls") -@click.option("--timeout-seconds", default=60, type=int, help="Timeout seconds") +@click.option( + "--timeout-seconds", default=60, type=int, help="Timeout seconds" +) @click.option("--no-cache", is_flag=True, help="Disable caching") def main( optimization_pipeline: bool = False, @@ -47,14 +58,14 @@ def main( # Auto-detect model if model_name is None: model_name = "gpt-4o-mini" if has_openai else "claude-3-haiku-20240307" - + # Parse data source try: source_type, source_path = data_source.split(":", 1) except ValueError: click.echo(f"โŒ Invalid data source: {data_source}") return - + # Create configs source_config = DataSourceConfig( source_type=source_type, @@ -67,19 +78,19 @@ def main( max_tool_calls=max_tool_calls, timeout_seconds=timeout_seconds, ) - + pipeline_options = {"enable_cache": not no_cache} - + # Stage 1: Prompt optimization if optimization_pipeline: click.echo("๐Ÿงช Running prompt optimization...") - + prompt_variants = [ "You are a data analyst. Quickly assess data quality (0-100 score) and identify key patterns. Be concise.", "You are a data quality specialist. Calculate quality score (0-100), find missing data, duplicates, and correlations. Provide specific recommendations.", - "You are a business analyst. Assess ML readiness with quality score. Focus on business impact and actionable insights." + "You are a business analyst. Assess ML readiness with quality score. Focus on business impact and actionable insights.", ] - + try: prompt_optimization_pipeline.with_options(**pipeline_options)( source_config=source_config, @@ -87,24 +98,24 @@ def main( agent_config=agent_config, ) click.echo("โœ… Optimization completed - best prompt tagged") - + except Exception as e: click.echo(f"โŒ Optimization failed: {e}") - - # Stage 2: Production analysis + + # Stage 2: Production analysis if production_pipeline: click.echo("๐Ÿญ Running production analysis...") - + try: production_eda_pipeline.with_options(**pipeline_options)( source_config=source_config, agent_config=agent_config, ) click.echo("โœ… Production analysis completed") - + except Exception as e: click.echo(f"โŒ Production analysis failed: {e}") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/prompt_optimization/steps/eda_agent.py b/examples/prompt_optimization/steps/eda_agent.py index 5e95060568f..d795821f24f 100644 --- a/examples/prompt_optimization/steps/eda_agent.py +++ b/examples/prompt_optimization/steps/eda_agent.py @@ -7,9 +7,8 @@ from pydantic_ai.settings import ModelSettings from zenml import step -from zenml.types import MarkdownString - from zenml.logger import get_logger +from zenml.types import MarkdownString logger = get_logger(__name__) @@ -39,13 +38,13 @@ def run_eda_agent( Annotated[Dict[str, pd.DataFrame], "analysis_tables"], ]: """Run Pydantic AI agent for EDA analysis with optional custom prompt. - + Args: dataset_df: Dataset to analyze dataset_metadata: Metadata about the dataset agent_config: Configuration for the AI agent custom_system_prompt: Optional custom system prompt (overrides default) - + Returns: Tuple of EDA outputs: markdown report, JSON report, SQL log, analysis tables """ diff --git a/examples/prompt_optimization/steps/ingest.py b/examples/prompt_optimization/steps/ingest.py index 22cad0677cc..b35f002ce9f 100644 --- a/examples/prompt_optimization/steps/ingest.py +++ b/examples/prompt_optimization/steps/ingest.py @@ -26,7 +26,9 @@ def ingest_data( Returns: Tuple of (dataframe, metadata) """ - logger.info(f"Loading data from {source_config.source_type}:{source_config.source_path}") + logger.info( + f"Loading data from {source_config.source_type}:{source_config.source_path}" + ) # Load data based on source type if source_config.source_type == "hf": @@ -34,12 +36,16 @@ def ingest_data( elif source_config.source_type == "local": df = _load_from_local(source_config) else: - raise ValueError(f"Unsupported source type: {source_config.source_type}") + raise ValueError( + f"Unsupported source type: {source_config.source_type}" + ) # Apply sampling if configured if source_config.sample_size and len(df) > source_config.sample_size: df = df.sample(n=source_config.sample_size, random_state=42) - logger.info(f"Sampled {source_config.sample_size} rows from {len(df)} total") + logger.info( + f"Sampled {source_config.sample_size} rows from {len(df)} total" + ) # Generate simple metadata metadata = { @@ -63,12 +69,14 @@ def _load_from_huggingface(config: DataSourceConfig) -> pd.DataFrame: # Simple dataset loading dataset = load_dataset(config.source_path, split="train") df = dataset.to_pandas() - + logger.info(f"Loaded HuggingFace dataset: {config.source_path}") return df except ImportError: - raise ImportError("datasets library required. Install with: pip install datasets") + raise ImportError( + "datasets library required. Install with: pip install datasets" + ) except Exception as e: raise RuntimeError(f"Failed to load dataset {config.source_path}: {e}") @@ -77,7 +85,7 @@ def _load_from_local(config: DataSourceConfig) -> pd.DataFrame: """Load dataset from local file.""" try: file_path = config.source_path - + if file_path.endswith(".csv"): df = pd.read_csv(file_path) elif file_path.endswith(".json"): @@ -90,4 +98,4 @@ def _load_from_local(config: DataSourceConfig) -> pd.DataFrame: return df except Exception as e: - raise RuntimeError(f"Failed to load file {config.source_path}: {e}") \ No newline at end of file + raise RuntimeError(f"Failed to load file {config.source_path}: {e}") diff --git a/examples/prompt_optimization/steps/prompt_optimization.py b/examples/prompt_optimization/steps/prompt_optimization.py index e946c677081..458d8ed3dbb 100644 --- a/examples/prompt_optimization/steps/prompt_optimization.py +++ b/examples/prompt_optimization/steps/prompt_optimization.py @@ -23,19 +23,19 @@ def compare_prompts_and_tag_best( agent_config: AgentConfig = None, ) -> Annotated[str, ArtifactConfig(name="best_prompt")]: """Compare prompt variants and tag the best one with exclusive 'optimized' tag. - + This step demonstrates ZenML's artifact management by: 1. Testing multiple prompt variants - 2. Finding the best performer + 2. Finding the best performer 3. Returning it as a tagged artifact that other pipelines can find - + The 'optimized' tag is exclusive, so only one prompt can be 'optimized' at a time. - + Args: dataset_df: Dataset to test prompts against prompt_variants: List of system prompts to compare agent_config: Configuration for AI agents - + Returns: The best performing prompt string with exclusive 'optimized' tag """ @@ -43,20 +43,20 @@ def compare_prompts_and_tag_best( agent_config = AgentConfig() logger.info(f"๐Ÿงช Testing {len(prompt_variants)} prompt variants") - + results = [] - + for i, system_prompt in enumerate(prompt_variants): prompt_id = f"variant_{i + 1}" logger.info(f"Testing {prompt_id}...") - + start_time = time.time() - + try: # Create agent with this prompt deps = AnalystAgentDeps() main_ref = deps.store(dataset_df) - + agent = Agent( f"openai:{agent_config.model_name}", deps_type=AnalystAgentDeps, @@ -64,97 +64,103 @@ def compare_prompts_and_tag_best( system_prompt=system_prompt, model_settings=ModelSettings(parallel_tool_calls=False), ) - + for tool in AGENT_TOOLS: agent.tool(tool) - + # Run analysis user_prompt = f"Analyze dataset '{main_ref}' - focus on data quality and key insights." result = agent.run_sync(user_prompt, deps=deps) eda_report = result.output - + execution_time = time.time() - start_time - + # Score this variant score = ( - eda_report.data_quality_score * 0.7 + # Primary metric - (100 - min(execution_time * 2, 100)) * 0.2 + # Speed bonus - len(eda_report.key_findings) * 5 * 0.1 # Thoroughness + eda_report.data_quality_score * 0.7 # Primary metric + + (100 - min(execution_time * 2, 100)) * 0.2 # Speed bonus + + len(eda_report.key_findings) * 5 * 0.1 # Thoroughness ) - - results.append({ - "prompt_id": prompt_id, - "prompt": system_prompt, - "score": score, - "quality_score": eda_report.data_quality_score, - "execution_time": execution_time, - "findings_count": len(eda_report.key_findings), - "success": True, - }) - - logger.info(f"โœ… {prompt_id}: score={score:.1f}, time={execution_time:.1f}s") - + + results.append( + { + "prompt_id": prompt_id, + "prompt": system_prompt, + "score": score, + "quality_score": eda_report.data_quality_score, + "execution_time": execution_time, + "findings_count": len(eda_report.key_findings), + "success": True, + } + ) + + logger.info( + f"โœ… {prompt_id}: score={score:.1f}, time={execution_time:.1f}s" + ) + except Exception as e: logger.warning(f"โŒ {prompt_id} failed: {e}") - results.append({ - "prompt_id": prompt_id, - "prompt": system_prompt, - "score": 0, - "success": False, - "error": str(e), - }) - + results.append( + { + "prompt_id": prompt_id, + "prompt": system_prompt, + "score": 0, + "success": False, + "error": str(e), + } + ) + # Find best performer successful_results = [r for r in results if r["success"]] - + if not successful_results: logger.warning("All prompts failed, using first as fallback") best_prompt = prompt_variants[0] else: best_result = max(successful_results, key=lambda x: x["score"]) best_prompt = best_result["prompt"] - logger.info(f"๐Ÿ† Best prompt: {best_result['prompt_id']} (score: {best_result['score']:.1f})") - + logger.info( + f"๐Ÿ† Best prompt: {best_result['prompt_id']} (score: {best_result['score']:.1f})" + ) + logger.info("๐Ÿ’พ Best prompt will be stored with exclusive 'optimized' tag") - + # Add exclusive tag to this step's output artifact - add_tags( - tags=[Tag(name="optimized", exclusive=True)], - infer_artifact=True - ) - + add_tags(tags=[Tag(name="optimized", exclusive=True)], infer_artifact=True) + return best_prompt def get_optimized_prompt(): """Retrieve the optimized prompt using ZenML's tag-based filtering. - + This demonstrates ZenML's tag filtering by finding artifacts tagged with 'optimized'. Since 'optimized' is an exclusive tag, there will be at most one such artifact. - + Returns: The optimized prompt from the latest optimization run, or default if none found """ try: client = Client() - + # Find artifacts tagged with 'optimized' (our exclusive tag) - artifacts = client.list_artifact_versions( - tags=["optimized"], - size=1 - ) - + artifacts = client.list_artifact_versions(tags=["optimized"], size=1) + if artifacts.items: optimized_prompt = artifacts.items[0] - logger.info(f"๐ŸŽฏ Retrieved optimized prompt from artifact: {optimized_prompt.id}") + logger.info( + f"๐ŸŽฏ Retrieved optimized prompt from artifact: {optimized_prompt.id}" + ) logger.info(f" Artifact created: {optimized_prompt.created}") return optimized_prompt else: - logger.info("๐Ÿ” No optimized prompt found (no artifacts with 'optimized' tag)") - + logger.info( + "๐Ÿ” No optimized prompt found (no artifacts with 'optimized' tag)" + ) + except Exception as e: logger.warning(f"Failed to retrieve optimized prompt: {e}") - + # Fallback to default prompt if no optimization artifacts found logger.info("๐Ÿ“ Using default prompt (run optimization pipeline first)") default_prompt = """You are a data analyst. Perform comprehensive EDA analysis. @@ -166,5 +172,5 @@ def get_optimized_prompt(): - Provide 3-5 actionable recommendations Be thorough but efficient.""" - - return ExternalArtifact(value=default_prompt) \ No newline at end of file + + return ExternalArtifact(value=default_prompt) From f6ba7d96f427aeea8b6e6daaf244e5c04d59ba2d Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 29 Sep 2025 11:15:51 +0200 Subject: [PATCH 09/14] Use argparse not click --- examples/prompt_optimization/run.py | 145 +++++++++++++++++----------- 1 file changed, 90 insertions(+), 55 deletions(-) diff --git a/examples/prompt_optimization/run.py b/examples/prompt_optimization/run.py index 6eb4ead81f3..309eb612f82 100644 --- a/examples/prompt_optimization/run.py +++ b/examples/prompt_optimization/run.py @@ -1,89 +1,124 @@ #!/usr/bin/env python3 """Two-Stage Prompt Optimization Example with ZenML.""" +import argparse import os +import sys from typing import Optional -import click from models import AgentConfig, DataSourceConfig from pipelines.production_eda_pipeline import production_eda_pipeline from pipelines.prompt_optimization_pipeline import prompt_optimization_pipeline -@click.command( - help="Run prompt optimization and/or production EDA pipelines (both by default)" -) -@click.option( - "--optimization-pipeline", is_flag=True, help="Run prompt optimization" -) -@click.option("--production-pipeline", is_flag=True, help="Run production EDA") -@click.option( - "--data-source", - default="hf:scikit-learn/iris", - help="Data source (type:path)", -) -@click.option("--target-column", default="target", help="Target column name") -@click.option( - "--model-name", help="Model name (auto-detected if not specified)" -) -@click.option("--max-tool-calls", default=6, type=int, help="Max tool calls") -@click.option( - "--timeout-seconds", default=60, type=int, help="Timeout seconds" -) -@click.option("--no-cache", is_flag=True, help="Disable caching") -def main( - optimization_pipeline: bool = False, - production_pipeline: bool = False, - data_source: str = "hf:scikit-learn/iris", - target_column: str = "target", - model_name: Optional[str] = None, - max_tool_calls: int = 6, - timeout_seconds: int = 60, - no_cache: bool = False, -): +def build_parser() -> argparse.ArgumentParser: + """Build the CLI parser. + + Rationale: + - Uses only the Python standard library to avoid extra dependencies. + - Mirrors the original Click-based flag and option semantics to keep the + README usage and user experience unchanged. + """ + parser = argparse.ArgumentParser( + description="Run prompt optimization and/or production EDA pipelines (both by default)" + ) + parser.add_argument( + "--optimization-pipeline", + action="store_true", + help="Run prompt optimization", + ) + parser.add_argument( + "--production-pipeline", + action="store_true", + help="Run production EDA", + ) + parser.add_argument( + "--data-source", + default="hf:scikit-learn/iris", + help="Data source (type:path)", + ) + parser.add_argument( + "--target-column", + default="target", + help="Target column name", + ) + parser.add_argument( + "--model-name", + help="Model name (auto-detected if not specified)", + ) + parser.add_argument( + "--max-tool-calls", + default=6, + type=int, + help="Max tool calls", + ) + parser.add_argument( + "--timeout-seconds", + default=60, + type=int, + help="Timeout seconds", + ) + parser.add_argument( + "--no-cache", + action="store_true", + help="Disable caching", + ) + return parser + + +def main() -> int: """Run prompt optimization and/or production EDA pipelines.""" + parser = build_parser() + args = parser.parse_args() + # Default: run both pipelines if no flags specified + optimization_pipeline = args.optimization_pipeline + production_pipeline = args.production_pipeline if not optimization_pipeline and not production_pipeline: optimization_pipeline = True production_pipeline = True - # Check API keys + # Check API keys. We support either OpenAI or Anthropic; at least one is required. has_openai = bool(os.getenv("OPENAI_API_KEY")) has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY")) if not (has_openai or has_anthropic): - click.echo("โŒ Set OPENAI_API_KEY or ANTHROPIC_API_KEY") - return + print("โŒ Set OPENAI_API_KEY or ANTHROPIC_API_KEY", file=sys.stderr) + return 1 - # Auto-detect model + # Auto-detect model if not explicitly provided. This keeps a sensible default + # that aligns with whichever provider the user configured. + model_name: Optional[str] = args.model_name if model_name is None: model_name = "gpt-4o-mini" if has_openai else "claude-3-haiku-20240307" - # Parse data source + # Parse data source "type:path" into its components early so we can fail-fast + # with a helpful error if the format is invalid. try: - source_type, source_path = data_source.split(":", 1) + source_type, source_path = args.data_source.split(":", 1) except ValueError: - click.echo(f"โŒ Invalid data source: {data_source}") - return + print(f"โŒ Invalid data source: {args.data_source}", file=sys.stderr) + return 2 - # Create configs + # Create configs passed to the pipelines source_config = DataSourceConfig( source_type=source_type, source_path=source_path, - target_column=target_column, + target_column=args.target_column, ) agent_config = AgentConfig( model_name=model_name, - max_tool_calls=max_tool_calls, - timeout_seconds=timeout_seconds, + max_tool_calls=args.max_tool_calls, + timeout_seconds=args.timeout_seconds, ) - pipeline_options = {"enable_cache": not no_cache} + # ZenML run options: keep parity with the original example + pipeline_options = {"enable_cache": not args.no_cache} # Stage 1: Prompt optimization if optimization_pipeline: - click.echo("๐Ÿงช Running prompt optimization...") + print("๐Ÿงช Running prompt optimization...") prompt_variants = [ "You are a data analyst. Quickly assess data quality (0-100 score) and identify key patterns. Be concise.", @@ -97,25 +132,25 @@ def main( prompt_variants=prompt_variants, agent_config=agent_config, ) - click.echo("โœ… Optimization completed - best prompt tagged") - + print("โœ… Optimization completed - best prompt tagged") except Exception as e: - click.echo(f"โŒ Optimization failed: {e}") + # Best-effort behavior: log the error and continue to the next stage + print(f"โŒ Optimization failed: {e}", file=sys.stderr) # Stage 2: Production analysis if production_pipeline: - click.echo("๐Ÿญ Running production analysis...") - + print("๐Ÿญ Running production analysis...") try: production_eda_pipeline.with_options(**pipeline_options)( source_config=source_config, agent_config=agent_config, ) - click.echo("โœ… Production analysis completed") - + print("โœ… Production analysis completed") except Exception as e: - click.echo(f"โŒ Production analysis failed: {e}") + print(f"โŒ Production analysis failed: {e}", file=sys.stderr) + + return 0 if __name__ == "__main__": - main() + raise SystemExit(main()) From dc04d9b1b0d9f24cd4480aca16c006ccdf6c41fc Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 29 Sep 2025 11:29:25 +0200 Subject: [PATCH 10/14] Fix Pydantic validation bug --- .../prompt_optimization/steps/eda_agent.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/prompt_optimization/steps/eda_agent.py b/examples/prompt_optimization/steps/eda_agent.py index d795821f24f..79e3a8b5c89 100644 --- a/examples/prompt_optimization/steps/eda_agent.py +++ b/examples/prompt_optimization/steps/eda_agent.py @@ -110,19 +110,18 @@ def run_eda_agent( result = analyst_agent.run_sync(user_prompt, deps=deps) eda_report = result.output except Exception as e: - # Simple fallback - create basic report eda_report = EDAReport( - headline=f"Basic analysis of {dataset_df.shape[0]} rows, {dataset_df.shape[1]} columns", + headline=f"Analysis failed for dataset with {dataset_df.shape[0]} rows", key_findings=[ - f"Dataset contains {len(dataset_df)} rows and {len(dataset_df.columns)} columns" + f"Dataset contains {len(dataset_df)} rows and {len(dataset_df.columns)} columns.", + "The AI agent failed to generate a report.", ], - risks=["Analysis failed - using basic fallback"], - fixes=[], - data_quality_score=50.0, - markdown=f"# EDA Report\n\nBasic analysis failed: {str(e)}\n\nDataset shape: {dataset_df.shape}", - column_profiles={}, - correlation_insights=[], - missing_data_analysis={}, + data_quality_score=0.0, + markdown=( + f"# EDA Report Failed\n\n" + f"Analysis failed with error: {str(e)}\n\n" + f"Dataset shape: {dataset_df.shape}" + ), ) # Return results From b00bac1e5c2bf2a76bfb076df91a92ef5670b3f9 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 29 Sep 2025 11:58:46 +0200 Subject: [PATCH 11/14] Align prompt optimization prompts --- .../prompt_optimization/steps/eda_agent.py | 23 +---- .../steps/prompt_optimization.py | 91 +++++++++++++------ .../prompt_optimization/steps/prompt_text.py | 46 ++++++++++ 3 files changed, 111 insertions(+), 49 deletions(-) create mode 100644 examples/prompt_optimization/steps/prompt_text.py diff --git a/examples/prompt_optimization/steps/eda_agent.py b/examples/prompt_optimization/steps/eda_agent.py index 79e3a8b5c89..82079ea7863 100644 --- a/examples/prompt_optimization/steps/eda_agent.py +++ b/examples/prompt_optimization/steps/eda_agent.py @@ -23,6 +23,7 @@ from models import AgentConfig, EDAReport from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps +from steps.prompt_text import DEFAULT_SYSTEM_PROMPT, build_user_prompt @step @@ -69,15 +70,7 @@ def run_eda_agent( system_prompt = custom_system_prompt logger.info("๐ŸŽฏ Using custom optimized system prompt for analysis") else: - system_prompt = """You are a data analyst. Perform quick but insightful EDA. - -FOCUS ON: -- Data quality score (0-100) based on missing data and duplicates -- Key patterns and distributions -- Notable correlations or anomalies -- 2-3 actionable recommendations - -Be concise but specific with numbers. Aim for quality insights, not exhaustive analysis.""" + system_prompt = DEFAULT_SYSTEM_PROMPT logger.info("๐Ÿ“ Using default system prompt for analysis") analyst_agent = Agent( @@ -95,16 +88,8 @@ def run_eda_agent( for tool in AGENT_TOOLS: analyst_agent.tool(tool) - # Run focused analysis - user_prompt = f"""Quick EDA analysis for dataset '{main_ref}' ({dataset_df.shape[0]} rows, {dataset_df.shape[1]} cols). - -STEPS (keep it fast): -1. display('{main_ref}') - check data structure -2. describe('{main_ref}') - get key stats -3. run_sql('{main_ref}', 'SELECT COUNT(*) as total, COUNT(DISTINCT *) as unique FROM dataset') - check duplicates -4. If multiple numeric columns: analyze_correlations('{main_ref}') - -Generate EDAReport with data quality score and 2-3 key insights.""" + # Run focused analysis using shared user prompt builder + user_prompt = build_user_prompt(main_ref, dataset_df) try: result = analyst_agent.run_sync(user_prompt, deps=deps) diff --git a/examples/prompt_optimization/steps/prompt_optimization.py b/examples/prompt_optimization/steps/prompt_optimization.py index 458d8ed3dbb..239a9c6af75 100644 --- a/examples/prompt_optimization/steps/prompt_optimization.py +++ b/examples/prompt_optimization/steps/prompt_optimization.py @@ -8,13 +8,50 @@ from pydantic_ai import Agent from pydantic_ai.settings import ModelSettings from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps +from steps.prompt_text import DEFAULT_SYSTEM_PROMPT, build_user_prompt -from zenml import ArtifactConfig, ExternalArtifact, Tag, add_tags, step +from zenml import ArtifactConfig, Tag, add_tags, step from zenml.client import Client from zenml.logger import get_logger logger = get_logger(__name__) +# Scoring weights reflect priorities: quality first (insightful outputs), +# then speed (encourage lower latency), then findings (reward coverage). +WEIGHT_QUALITY: float = 0.7 +WEIGHT_SPEED: float = 0.2 +WEIGHT_FINDINGS: float = 0.1 + +# Linear time penalty: each second reduces the speed score by this many points +# until the score floors at 0 (capped at 100 points of penalty). +SPEED_PENALTY_PER_SECOND: float = 2.0 + +# Reward per key finding discovered by the agent before applying the findings weight. +# Keeping this explicit makes it easy to tune coverage incentives. +FINDINGS_SCORE_PER_ITEM: float = 0.5 + + +def compute_prompt_score( + eda_report: EDAReport, execution_time: float +) -> float: + """Compute a prompt's score from EDA results and runtime. + + This makes scoring trade-offs explicit and tunable via module-level constants: + - Prioritize report quality (WEIGHT_QUALITY) + - Encourage faster execution via a linear time penalty converted to a 0โ€“100 speed score (WEIGHT_SPEED) + - Reward thoroughness by crediting key findings (WEIGHT_FINDINGS) + """ + speed_score = max( + 0.0, + 100.0 - min(execution_time * SPEED_PENALTY_PER_SECOND, 100.0), + ) + findings_score = len(eda_report.key_findings) * FINDINGS_SCORE_PER_ITEM + return ( + eda_report.data_quality_score * WEIGHT_QUALITY + + speed_score * WEIGHT_SPEED + + findings_score * WEIGHT_FINDINGS + ) + @step def compare_prompts_and_tag_best( @@ -69,18 +106,14 @@ def compare_prompts_and_tag_best( agent.tool(tool) # Run analysis - user_prompt = f"Analyze dataset '{main_ref}' - focus on data quality and key insights." + user_prompt = build_user_prompt(main_ref, dataset_df) result = agent.run_sync(user_prompt, deps=deps) eda_report = result.output execution_time = time.time() - start_time # Score this variant - score = ( - eda_report.data_quality_score * 0.7 # Primary metric - + (100 - min(execution_time * 2, 100)) * 0.2 # Speed bonus - + len(eda_report.key_findings) * 5 * 0.1 # Thoroughness - ) + score = compute_prompt_score(eda_report, execution_time) results.append( { @@ -114,8 +147,10 @@ def compare_prompts_and_tag_best( successful_results = [r for r in results if r["success"]] if not successful_results: - logger.warning("All prompts failed, using first as fallback") - best_prompt = prompt_variants[0] + logger.warning( + "All prompts failed, falling back to DEFAULT_SYSTEM_PROMPT" + ) + best_prompt = DEFAULT_SYSTEM_PROMPT else: best_result = max(successful_results, key=lambda x: x["score"]) best_prompt = best_result["prompt"] @@ -131,14 +166,15 @@ def compare_prompts_and_tag_best( return best_prompt -def get_optimized_prompt(): +def get_optimized_prompt() -> str: """Retrieve the optimized prompt using ZenML's tag-based filtering. This demonstrates ZenML's tag filtering by finding artifacts tagged with 'optimized'. Since 'optimized' is an exclusive tag, there will be at most one such artifact. Returns: - The optimized prompt from the latest optimization run, or default if none found + The optimized prompt from the latest optimization run as a plain string, + or DEFAULT_SYSTEM_PROMPT if none found or retrieval fails. """ try: client = Client() @@ -147,30 +183,25 @@ def get_optimized_prompt(): artifacts = client.list_artifact_versions(tags=["optimized"], size=1) if artifacts.items: - optimized_prompt = artifacts.items[0] + optimized_artifact = artifacts.items[0] + prompt_value = optimized_artifact.load() logger.info( - f"๐ŸŽฏ Retrieved optimized prompt from artifact: {optimized_prompt.id}" + f"๐ŸŽฏ Retrieved optimized prompt from artifact: {optimized_artifact.id}" ) - logger.info(f" Artifact created: {optimized_prompt.created}") - return optimized_prompt + logger.info(f" Artifact created: {optimized_artifact.created}") + return prompt_value else: logger.info( - "๐Ÿ” No optimized prompt found (no artifacts with 'optimized' tag)" + "๐Ÿ” No optimized prompt found (no artifacts with 'optimized' tag). Using DEFAULT_SYSTEM_PROMPT." ) except Exception as e: - logger.warning(f"Failed to retrieve optimized prompt: {e}") - - # Fallback to default prompt if no optimization artifacts found - logger.info("๐Ÿ“ Using default prompt (run optimization pipeline first)") - default_prompt = """You are a data analyst. Perform comprehensive EDA analysis. - -FOCUS ON: -- Calculate data quality score (0-100) based on completeness and consistency -- Identify missing data patterns and duplicates -- Find key correlations and patterns -- Provide 3-5 actionable recommendations - -Be thorough but efficient.""" + logger.warning( + f"Failed to retrieve optimized prompt: {e}. Falling back to DEFAULT_SYSTEM_PROMPT." + ) - return ExternalArtifact(value=default_prompt) + # Fallback to default system prompt if no optimization artifacts found or retrieval failed + logger.info( + "๐Ÿ“ Using default system prompt (run optimization pipeline first)" + ) + return DEFAULT_SYSTEM_PROMPT diff --git a/examples/prompt_optimization/steps/prompt_text.py b/examples/prompt_optimization/steps/prompt_text.py new file mode 100644 index 00000000000..6e80eb20541 --- /dev/null +++ b/examples/prompt_optimization/steps/prompt_text.py @@ -0,0 +1,46 @@ +"""Shared prompt text for the prompt optimization example. + +This module centralizes the default system prompt and the user prompt +builder so both steps can reuse a single source of truth. +""" + +from __future__ import annotations + +import pandas as pd + +# Canonical default system prompt used by the EDA agent when no custom prompt +# is provided. Kept identical to the original in run_eda_agent to preserve behavior. +DEFAULT_SYSTEM_PROMPT: str = """You are a data analyst. Perform quick but insightful EDA. + +FOCUS ON: +- Data quality score (0-100) based on missing data and duplicates +- Key patterns and distributions +- Notable correlations or anomalies +- 2-3 actionable recommendations + +Be concise but specific with numbers. Aim for quality insights, not exhaustive analysis.""" + + +def build_user_prompt(main_ref: str, df: pd.DataFrame) -> str: + """Build the user prompt for the EDA agent. + + Produces the same structured instructions used previously, including dataset + shape metadata and the numbered action steps that guide the agent to use tools. + + Args: + main_ref: Reference key for the stored dataset (e.g., 'Out[1]'). + df: The pandas DataFrame being analyzed, used to derive rows and columns. + + Returns: + A formatted user prompt string guiding the agent through quick EDA steps. + """ + rows, cols = df.shape + return ( + f"Quick EDA analysis for dataset '{main_ref}' ({rows} rows, {cols} cols).\n\n" + "STEPS (keep it fast):\n" + f"1. display('{main_ref}') - check data structure \n" + f"2. describe('{main_ref}') - get key stats\n" + f"3. run_sql('{main_ref}', 'SELECT COUNT(*) as total FROM dataset') - check row count\n" + f"4. If multiple numeric columns: analyze_correlations('{main_ref}')\n\n" + "Generate EDAReport with data quality score and 2-3 key insights." + ) From 4a61bcc4ca179ead191b30d0e704f87868c1ead4 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 29 Sep 2025 12:14:53 +0200 Subject: [PATCH 12/14] Update Agents file with instructions about commit messages --- AGENTS.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index ea4df462504..225008d9be6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -107,6 +107,9 @@ Use filesystem navigation tools to explore the codebase structure as needed. - Use imperative mood: "Add feature" not "Added feature" - Reference issue numbers when applicable: "Fix user auth bug (#1234)" - For multi-line messages, add a blank line after the summary + +Codex-style agents must review the diff (when not already tracked), craft a concise summary line, and include a detailed body covering the key changes. The body should describe the main code adjustments so reviewers can understand the scope from the commit message alone. + - Example: ``` Add retry logic to artifact upload From 7df954513c42fdee3282571e10fc90eced808583 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 29 Sep 2025 12:15:08 +0200 Subject: [PATCH 13/14] Remove partial comment --- examples/prompt_optimization/requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/prompt_optimization/requirements.txt b/examples/prompt_optimization/requirements.txt index 2e210ed5bfb..f5ce8dc2336 100644 --- a/examples/prompt_optimization/requirements.txt +++ b/examples/prompt_optimization/requirements.txt @@ -1,4 +1,4 @@ -# ZenML (use existing installation) +# ZenML zenml # Core dependencies with version constraints for compatibility @@ -6,7 +6,7 @@ pandas>=2.0.0,<3.0.0 numpy>=1.24.0,<2.0.0 duckdb>=1.0.0,<2.0.0 -# Pydantic compatibility - CRITICAL: Must be <2.10 for ZenML compatibility +# Pydantic compatibility pydantic # AI/ML frameworks @@ -18,4 +18,4 @@ anthropic>=0.30.0 # Alternative LLM provider datasets>=2.14.0,<3.0.0 # HuggingFace datasets (optional for local files) # Utility libraries -nest-asyncio>=1.5.6,<2.0.0 # For async compatibility in notebooks \ No newline at end of file +nest-asyncio>=1.5.6,<2.0.0 # For async compatibility in notebooks From 9f4f5449840bc62b425d0d2018424a15e158f35b Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Mon, 29 Sep 2025 15:22:13 +0200 Subject: [PATCH 14/14] Enhance prompt optimization example with scoring and provider auto-detection This update significantly improves the prompt optimization example with several key features: - Auto-detect LLM providers (OpenAI/Anthropic) from model names or environment - Add configurable scoring system with quality, speed, and findings weights - Introduce scoreboard artifact to track and compare prompt performance - Improve production pipeline with graceful fallback to default prompts - Expand CLI with options for custom prompts, sampling, and scoring config - Support fully-qualified model names (provider:model format) The scoring system uses normalized weights and caps to prevent gaming, while the provider auto-detection simplifies setup for users switching between models. --- examples/prompt_optimization/README.md | 85 +++++-- examples/prompt_optimization/models.py | 177 +++++++++++++- .../pipelines/production_eda_pipeline.py | 25 +- .../pipelines/prompt_optimization_pipeline.py | 18 +- examples/prompt_optimization/requirements.txt | 5 +- examples/prompt_optimization/run.py | 222 ++++++++++++++++-- .../prompt_optimization/steps/agent_tools.py | 60 ++++- .../prompt_optimization/steps/eda_agent.py | 28 ++- examples/prompt_optimization/steps/ingest.py | 5 +- .../steps/prompt_optimization.py | 203 +++++++++------- 10 files changed, 663 insertions(+), 165 deletions(-) diff --git a/examples/prompt_optimization/README.md b/examples/prompt_optimization/README.md index cfa8579ae68..d8ae92ab6b5 100644 --- a/examples/prompt_optimization/README.md +++ b/examples/prompt_optimization/README.md @@ -7,15 +7,18 @@ This example demonstrates **ZenML's artifact management** through a two-stage AI **Stage 1: Prompt Optimization** - Tests multiple prompt variants against sample data - Compares performance using data quality scores and execution time +- Emits a scoreboard artifact that summarizes quality, speed, findings, and success per prompt - Tags the best-performing prompt with an **exclusive ZenML tag** - Stores the optimized prompt in ZenML's artifact registry **Stage 2: Production Analysis** -- Automatically retrieves the tagged optimal prompt from the registry -- Runs production EDA analysis using the best prompt +- Automatically attempts to retrieve the latest optimized prompt from the registry +- Falls back to the default system prompt if no optimized prompt is available or retrieval fails +- Runs production EDA analysis using the selected prompt +- Returns a `used_optimized_prompt` boolean to indicate whether the optimized prompt was actually used - Demonstrates real artifact sharing between pipeline runs -This showcases how ZenML enables **reproducible ML workflows** where optimization results automatically flow into production systems. +This showcases how ZenML enables **reproducible ML workflows** where optimization results automatically flow into production systems, with safe fallbacks. ## Quick Start @@ -42,6 +45,10 @@ python run.py # Run individual stages python run.py --optimization-pipeline # Stage 1: Find best prompt python run.py --production-pipeline # Stage 2: Use optimized prompt + +# Force a specific provider/model (override auto-detection) +python run.py --provider openai --model-name gpt-4o-mini +python run.py --provider anthropic --model-name claude-3-haiku-20240307 ``` ## Data Sources @@ -58,6 +65,9 @@ python run.py --data-source "local:./my_data.csv" # Specify target column python run.py --data-source "local:sales.csv" --target-column "revenue" + +# Sample a subset of rows for faster iterations +python run.py --data-source "hf:scikit-learn/iris" --sample-size 500 ``` ## Key ZenML Concepts Demonstrated @@ -79,32 +89,77 @@ python run.py --data-source "local:sales.csv" --target-column "revenue" ## Configuration Options +### Provider and Model Selection ```bash -# Model selection -python run.py --model-name "gpt-4o-mini" -python run.py --model-name "claude-3-haiku-20240307" +# Auto (default): infer provider from model name or environment keys +python run.py + +# Force provider explicitly +python run.py --provider openai --model-name gpt-4o-mini +python run.py --provider anthropic --model-name claude-3-haiku-20240307 + +# Fully-qualified model names are also supported +python run.py --model-name "openai:gpt-4o-mini" +python run.py --model-name "anthropic:claude-3-haiku-20240307" +``` + +### Scoring Configuration +Configure how prompts are ranked during optimization: +```bash +# Weights for the aggregate score +python run.py --weight-quality 0.7 --weight-speed 0.2 --weight-findings 0.1 + +# Speed penalty (points per second) applied to compute a speed score +python run.py --speed-penalty-per-second 2.0 + +# Findings scoring (base points per finding) and cap +python run.py --findings-score-per-item 0.5 --findings-cap 20 +``` -# Performance tuning +*Why a cap?* It prevents a variant that simply emits an excessively long list of "findings" from dominating the overall score. Capping keeps the aggregate score bounded and comparable across variants while still rewarding useful coverage. + +### Custom Prompt Files +Provide your own prompt variants via a file (UTF-8, one prompt per line; blank lines ignored): +```bash +python run.py --prompts-file ./my_prompts.txt +``` + +### Sampling +Downsample large datasets to speed up experiments: +```bash +python run.py --sample-size 500 +``` + +### Budgets and Timeouts +Budgets are enforced at the tool boundary for deterministic behavior and cost control: +```bash +# Tool-call budget and overall time budget for the agent python run.py --max-tool-calls 8 --timeout-seconds 120 +``` +- The agent tools check and enforce these budgets during execution. -# Development options -python run.py --no-cache # Disable caching for fresh runs +### Caching +```bash +# Disable caching for fresh runs +python run.py --no-cache ``` ## Expected Output When you run the complete workflow, you'll see: -1. **Optimization Stage**: Testing of 3 prompt variants with performance metrics -2. **Tagging**: Best prompt automatically tagged in ZenML registry -3. **Production Stage**: Retrieval and use of optimized prompt -4. **Results**: EDA analysis with data quality scores and recommendations +1. **Optimization Stage**: Testing of multiple prompt variants with performance metrics +2. **Scoreboard**: CLI prints a compact top-3 summary (score, time, findings, success) and a short preview of the best prompt; a full scoreboard artifact is saved +3. **Tagging**: Best prompt automatically tagged in ZenML registry (exclusive "optimized" tag) +4. **Production Stage**: Retrieval and use of optimized prompt if available; otherwise falls back to the default +5. **Results**: EDA analysis with data quality scores and recommendations, plus a `used_optimized_prompt` flag indicating whether an optimized prompt was actually used -The ZenML dashboard will show the complete lineage from optimization to production use. +The ZenML dashboard will show the complete lineage from optimization to production use, including the prompt scoreboard and tagged best prompt. ## Next Steps - **View Results**: Check the ZenML dashboard for pipeline runs and artifacts -- **Customize Prompts**: Modify the prompt variants in `run.py` for your domain +- **Customize Prompts**: Provide your own variants via `--prompts-file` +- **Tune Scoring**: Adjust weights and penalties to match your evaluation criteria - **Scale Up**: Deploy with remote orchestrators for production workloads - **Integrate**: Connect to your existing data sources and ML pipelines \ No newline at end of file diff --git a/examples/prompt_optimization/models.py b/examples/prompt_optimization/models.py index c49fa562c3f..4b60283a52c 100644 --- a/examples/prompt_optimization/models.py +++ b/examples/prompt_optimization/models.py @@ -1,8 +1,9 @@ """Simple data models for prompt optimization example.""" -from typing import List, Optional +import os +from typing import List, Literal, Optional, Tuple -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator, model_validator class DataSourceConfig(BaseModel): @@ -20,15 +21,115 @@ class DataSourceConfig(BaseModel): ) +def normalize_model_name(name: str) -> str: + """Normalize a model name by stripping any leading provider prefix. + + This prevents accidental double-prefixing when constructing a provider:model id. + + Examples: + - 'openai:gpt-4o-mini' -> 'gpt-4o-mini' + - 'anthropic:claude-3-haiku-20240307' -> 'claude-3-haiku-20240307' + - 'gpt-4o-mini' -> 'gpt-4o-mini' (unchanged) + """ + if not isinstance(name, str): + return name + value = name.strip() + if ":" in value: + maybe_provider, remainder = value.split(":", 1) + if maybe_provider.lower() in ("openai", "anthropic"): + return remainder.strip() + return value + + +def infer_provider( + model_name: Optional[str] = None, +) -> Literal["openai", "anthropic"]: + """Infer the provider from model-name hints first, then environment variables. + + Rules: + - If `model_name` clearly references Claude โ†’ Anthropic. + - If `model_name` clearly references GPT / -o models โ†’ OpenAI. + - Otherwise use environment keys as tie-breakers. + - If both or neither keys exist and there is no hint, default to OpenAI. + """ + if isinstance(model_name, str): + mn = model_name.lower().strip() + # Handle common hints robustly + if mn.startswith("claude") or "claude" in mn: + return "anthropic" + if mn.startswith("gpt") or "gpt-" in mn or "-o" in mn: + # e.g., gpt-4o, gpt-4o-mini, gpt-4.1 + return "openai" + + has_openai = bool(os.getenv("OPENAI_API_KEY")) + has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY")) + + if has_openai and not has_anthropic: + return "openai" + if has_anthropic and not has_openai: + return "anthropic" + + # Both present or neither present โ†’ default to OpenAI for GPT-style defaults + return "openai" + + class AgentConfig(BaseModel): """AI agent configuration.""" + provider: Optional[Literal["openai", "anthropic"]] = Field( + default=None, + description="Optional model provider. If omitted, inferred from the model name or environment.", + ) model_name: str = Field("gpt-4o-mini", description="Language model to use") max_tool_calls: int = Field(6, description="Maximum tool calls allowed") timeout_seconds: int = Field( 60, description="Max execution time in seconds" ) + @field_validator("model_name", mode="before") + @classmethod + def _normalize_model_name(cls, v: object) -> object: + """Normalize the model name early to remove accidental provider prefixes.""" + if isinstance(v, str): + return normalize_model_name(v) + return v + + @model_validator(mode="before") + @classmethod + def _set_provider_from_prefix(cls, data: object) -> object: + """If a provider prefix is present in model_name, capture it into `provider`. + + This preserves the user's explicit intent (e.g., 'openai:gpt-4o-mini') + even though `model_name` is normalized by the field validator. + """ + if not isinstance(data, dict): + return data + provider = data.get("provider") + raw_name = data.get("model_name") + if (provider is None or provider == "") and isinstance(raw_name, str): + value = raw_name.strip() + if ":" in value: + maybe_provider = value.split(":", 1)[0].strip().lower() + if maybe_provider in ("openai", "anthropic"): + data["provider"] = maybe_provider + return data + + def model_id(self) -> str: + """Return a 'provider:model' string, inferring the provider when necessary. + + Precedence: + 1) Explicit provider captured from model_name prefix or set via `provider` + 2) Heuristic inference using model_name hints or environment variables + """ + # Normalize again defensively to avoid double-prefixing if callers mutate model_name later + name = normalize_model_name(self.model_name) + provider: Literal["openai", "anthropic"] = ( + self.provider + if self.provider is not None + else infer_provider(name) + ) + return f"{provider}:{name}" + class EDAReport(BaseModel): """Simple EDA analysis report from AI agent.""" @@ -41,3 +142,75 @@ class EDAReport(BaseModel): description="Overall quality score (0-100)" ) markdown: str = Field(description="Full markdown report") + + +class ScoringConfig(BaseModel): + """Scoring config with weights, penalties, and caps.""" + + weight_quality: float = Field( + 0.7, + ge=0.0, + description="Weight for the quality component.", + ) + weight_speed: float = Field( + 0.2, + ge=0.0, + description="Weight for the speed/latency component.", + ) + weight_findings: float = Field( + 0.1, + ge=0.0, + description="Weight for the findings coverage component.", + ) + speed_penalty_per_second: float = Field( + 2.0, + ge=0.0, + description="Linear penalty (points per second) applied when computing a speed score.", + ) + findings_score_per_item: float = Field( + 0.5, + ge=0.0, + description="Base points credited per key finding before applying weights.", + ) + findings_cap: int = Field( + 20, + ge=0, + description="Maximum number of findings to credit when scoring.", + ) + + @property + def normalized_weights(self) -> Tuple[float, float, float]: + """Return (quality, speed, findings) weights normalized to sum to 1.""" + total = self.weight_quality + self.weight_speed + self.weight_findings + if total <= 0: + return (1.0, 0.0, 0.0) + return ( + self.weight_quality / total, + self.weight_speed / total, + self.weight_findings / total, + ) + + +class VariantScore(BaseModel): + """Score summary for a single prompt variant.""" + + prompt_id: str = Field( + description="Stable identifier for the variant (e.g., 'variant_1')." + ) + prompt: str = Field(description="The system prompt text evaluated.") + score: float = Field(description="Aggregate score used for ranking.") + quality_score: float = Field( + description="Quality score produced by the agent/report (0-100)." + ) + execution_time: float = Field( + ge=0.0, description="Execution time in seconds for the evaluation." + ) + findings_count: int = Field( + ge=0, description="Number of key findings counted for this run." + ) + success: bool = Field( + description="Whether the evaluation completed successfully." + ) + error: Optional[str] = Field( + None, description="Optional error message if success is False." + ) diff --git a/examples/prompt_optimization/pipelines/production_eda_pipeline.py b/examples/prompt_optimization/pipelines/production_eda_pipeline.py index a8f02ae1084..c60cd633459 100644 --- a/examples/prompt_optimization/pipelines/production_eda_pipeline.py +++ b/examples/prompt_optimization/pipelines/production_eda_pipeline.py @@ -16,29 +16,34 @@ def production_eda_pipeline( source_config: DataSourceConfig, agent_config: Optional[AgentConfig] = None, ) -> Dict[str, Any]: - """Production EDA pipeline using optimized prompts from the registry. + """Production EDA pipeline using an optimized prompt when available. - This pipeline demonstrates ZenML's artifact retrieval by fetching - previously optimized prompts for production analysis. + The pipeline automatically attempts to retrieve the latest optimized prompt + from the artifact registry (tagged 'optimized') and falls back to the + default system prompt if none is available or retrieval fails. Args: source_config: Data source configuration agent_config: AI agent configuration - use_optimized_prompt: Whether to use optimized prompt from registry Returns: - EDA results and metadata + EDA results and metadata, including whether an optimized prompt was used. """ logger.info("๐Ÿญ Starting production EDA pipeline") - # Step 1: Get optimized prompt - optimized_prompt = get_optimized_prompt() - logger.info("๐ŸŽฏ Retrieved optimized prompt") + # Step 1: Get optimized prompt or fall back to default + optimized_prompt, used_optimized = get_optimized_prompt() + if used_optimized: + logger.info("๐ŸŽฏ Using optimized prompt retrieved from artifact") + else: + logger.warning( + "๐Ÿ” Optimized prompt not found; falling back to default system prompt" + ) # Step 2: Load data dataset_df, metadata = ingest_data(source_config=source_config) - # Step 3: Run EDA analysis with optimized prompt + # Step 3: Run EDA analysis with the selected prompt report_markdown, report_json, sql_log, analysis_tables = run_eda_agent( dataset_df=dataset_df, dataset_metadata=metadata, @@ -53,6 +58,6 @@ def production_eda_pipeline( "report_json": report_json, "sql_log": sql_log, "analysis_tables": analysis_tables, - "used_optimized_prompt": optimized_prompt is not None, + "used_optimized_prompt": used_optimized, "metadata": metadata, } diff --git a/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py b/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py index 7d936865dbc..18ecc4b48b2 100644 --- a/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py +++ b/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional -from models import AgentConfig, DataSourceConfig +from models import AgentConfig, DataSourceConfig, ScoringConfig from steps import compare_prompts_and_tag_best, ingest_data from zenml import pipeline @@ -16,6 +16,7 @@ def prompt_optimization_pipeline( source_config: DataSourceConfig, prompt_variants: List[str], agent_config: Optional[AgentConfig] = None, + scoring_config: Optional[ScoringConfig] = None, ) -> Dict[str, Any]: """Optimize prompts and tag the best one for production use. @@ -26,28 +27,35 @@ def prompt_optimization_pipeline( source_config: Data source configuration prompt_variants: List of prompt strings to test agent_config: AI agent configuration + scoring_config: Optional scoring configuration controlling ranking Returns: - Pipeline results with best prompt and metadata + Pipeline results with best prompt, scoreboard, effective scoring config, and metadata. """ logger.info("๐Ÿงช Starting prompt optimization pipeline") # Step 1: Load data dataset_df, metadata = ingest_data(source_config=source_config) - # Step 2: Test prompts and tag the best one - best_prompt = compare_prompts_and_tag_best( + # Use a single ScoringConfig instance for both the step and return payload + effective_scoring = scoring_config or ScoringConfig() + + # Step 2: Test prompts and tag the best one, collecting a scoreboard + best_prompt, scoreboard = compare_prompts_and_tag_best( dataset_df=dataset_df, prompt_variants=prompt_variants, agent_config=agent_config, + scoring_config=effective_scoring, ) logger.info( - "โœ… Prompt optimization completed - best prompt tagged with 'optimized'" + "โœ… Prompt optimization completed - best prompt tagged and scoreboard generated" ) return { "best_prompt": best_prompt, + "scoreboard": [entry.model_dump() for entry in scoreboard], + "scoring": effective_scoring.model_dump(), "metadata": metadata, "config": { "source": f"{source_config.source_type}:{source_config.source_path}", diff --git a/examples/prompt_optimization/requirements.txt b/examples/prompt_optimization/requirements.txt index f5ce8dc2336..3215a20c163 100644 --- a/examples/prompt_optimization/requirements.txt +++ b/examples/prompt_optimization/requirements.txt @@ -1,5 +1,5 @@ # ZenML -zenml +zenml[server] # Core dependencies with version constraints for compatibility pandas>=2.0.0,<3.0.0 @@ -7,7 +7,7 @@ numpy>=1.24.0,<2.0.0 duckdb>=1.0.0,<2.0.0 # Pydantic compatibility -pydantic +pydantic>=2.6,<3 # AI/ML frameworks pydantic-ai[logfire]>=0.4.0 @@ -18,4 +18,3 @@ anthropic>=0.30.0 # Alternative LLM provider datasets>=2.14.0,<3.0.0 # HuggingFace datasets (optional for local files) # Utility libraries -nest-asyncio>=1.5.6,<2.0.0 # For async compatibility in notebooks diff --git a/examples/prompt_optimization/run.py b/examples/prompt_optimization/run.py index 309eb612f82..f80206dc584 100644 --- a/examples/prompt_optimization/run.py +++ b/examples/prompt_optimization/run.py @@ -4,9 +4,9 @@ import argparse import os import sys -from typing import Optional +from typing import List, Optional -from models import AgentConfig, DataSourceConfig +from models import AgentConfig, DataSourceConfig, ScoringConfig from pipelines.production_eda_pipeline import production_eda_pipeline from pipelines.prompt_optimization_pipeline import prompt_optimization_pipeline @@ -22,6 +22,10 @@ def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Run prompt optimization and/or production EDA pipelines (both by default)" ) + + # Instantiate defaults from ScoringConfig so CLI flags reflect canonical values + scoring_defaults = ScoringConfig() + parser.add_argument( "--optimization-pipeline", action="store_true", @@ -35,7 +39,7 @@ def build_parser() -> argparse.ArgumentParser: parser.add_argument( "--data-source", default="hf:scikit-learn/iris", - help="Data source (type:path)", + help="Data source in 'type:path' format, e.g., 'hf:scikit-learn/iris' or 'local:./data.csv'", ) parser.add_argument( "--target-column", @@ -44,7 +48,22 @@ def build_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--model-name", - help="Model name (auto-detected if not specified)", + help="Model name (auto-detected if not specified). Can be fully-qualified like 'openai:gpt-4o-mini' or 'anthropic:claude-3-haiku-20240307'.", + ) + parser.add_argument( + "--provider", + choices=["auto", "openai", "anthropic"], + default="auto", + help="Model provider to use. 'auto' infers from model/name/environment.", + ) + parser.add_argument( + "--sample-size", + type=int, + help="Optional row sample size for the dataset.", + ) + parser.add_argument( + "--prompts-file", + help="Path to a newline-delimited prompts file (UTF-8). One prompt per line; blank lines are ignored.", ) parser.add_argument( "--max-tool-calls", @@ -63,6 +82,44 @@ def build_parser() -> argparse.ArgumentParser: action="store_true", help="Disable caching", ) + + # Scoring knobs (seeded from ScoringConfig defaults) + parser.add_argument( + "--weight-quality", + type=float, + default=scoring_defaults.weight_quality, + help="Weight for the quality component.", + ) + parser.add_argument( + "--weight-speed", + type=float, + default=scoring_defaults.weight_speed, + help="Weight for the speed/latency component.", + ) + parser.add_argument( + "--weight-findings", + type=float, + default=scoring_defaults.weight_findings, + help="Weight for the findings coverage component.", + ) + parser.add_argument( + "--speed-penalty-per-second", + type=float, + default=scoring_defaults.speed_penalty_per_second, + help="Penalty points per second applied when computing a speed score.", + ) + parser.add_argument( + "--findings-score-per-item", + type=float, + default=scoring_defaults.findings_score_per_item, + help="Base points credited per key finding before applying weights.", + ) + parser.add_argument( + "--findings-cap", + type=int, + default=scoring_defaults.findings_cap, + help="Maximum number of findings to credit when scoring.", + ) return parser @@ -86,53 +143,174 @@ def main() -> int: print("โŒ Set OPENAI_API_KEY or ANTHROPIC_API_KEY", file=sys.stderr) return 1 - # Auto-detect model if not explicitly provided. This keeps a sensible default - # that aligns with whichever provider the user configured. - model_name: Optional[str] = args.model_name - if model_name is None: - model_name = "gpt-4o-mini" if has_openai else "claude-3-haiku-20240307" + # Derive raw model name from CLI or sensible defaults considering provider/env + raw_model_name: Optional[str] = args.model_name + if raw_model_name is None: + if args.provider == "openai": + raw_model_name = "gpt-4o-mini" + elif args.provider == "anthropic": + raw_model_name = "claude-3-haiku-20240307" + else: + raw_model_name = ( + "gpt-4o-mini" if has_openai else "claude-3-haiku-20240307" + ) + + # Detect and strip provider prefixes (e.g., 'openai:gpt-4o-mini'), preserving explicit provider + explicit_provider: Optional[str] = None + normalized_model_name = raw_model_name.strip() + if ":" in normalized_model_name: + maybe_provider, remainder = normalized_model_name.split(":", 1) + maybe_provider = maybe_provider.strip().lower() + if maybe_provider in ("openai", "anthropic"): + explicit_provider = maybe_provider + normalized_model_name = remainder.strip() - # Parse data source "type:path" into its components early so we can fail-fast - # with a helpful error if the format is invalid. + # Combine with --provider to determine provider choice + if args.provider == "auto": + provider_choice: Optional[str] = explicit_provider or None + else: + provider_choice = args.provider + + # Validate that the chosen provider has its corresponding API key configured + if provider_choice == "openai" and not has_openai: + print( + "โŒ You selected --provider openai (or model prefix 'openai:') but OPENAI_API_KEY is not set.", + file=sys.stderr, + ) + return 1 + if provider_choice == "anthropic" and not has_anthropic: + print( + "โŒ You selected --provider anthropic (or model prefix 'anthropic:') but ANTHROPIC_API_KEY is not set.", + file=sys.stderr, + ) + return 1 + + # Parse data source "type:path" into its components try: source_type, source_path = args.data_source.split(":", 1) except ValueError: - print(f"โŒ Invalid data source: {args.data_source}", file=sys.stderr) + print( + "โŒ Invalid --data-source. Expected 'type:path' (e.g., 'hf:scikit-learn/iris' or 'local:./data.csv').", + file=sys.stderr, + ) + print( + f"Provided value: {args.data_source}. Update the flag or use one of the examples above.", + file=sys.stderr, + ) return 2 + # Prepare prompt variants: file-based if provided; otherwise use baked-in defaults + prompt_variants: List[str] + if args.prompts_file: + try: + with open(args.prompts_file, "r", encoding="utf-8") as f: + prompt_variants = [ + line.strip() + for line in f.read().splitlines() + if line.strip() + ] + if not prompt_variants: + print( + f"โŒ Prompts file is empty after removing blank lines: {args.prompts_file}", + file=sys.stderr, + ) + print( + "Add one prompt per line, or omit --prompts-file to use the built-in defaults.", + file=sys.stderr, + ) + return 4 + except OSError: + print( + f"โŒ Unable to read prompts file: {args.prompts_file}", + file=sys.stderr, + ) + print( + "Provide a valid path with --prompts-file or omit the flag to use the built-in prompts.", + file=sys.stderr, + ) + return 3 + else: + prompt_variants = [ + "You are a data analyst. Quickly assess data quality (0-100 score) and identify key patterns. Be concise.", + "You are a data quality specialist. Calculate quality score (0-100), find missing data, duplicates, and correlations. Provide specific recommendations.", + "You are a business analyst. Assess ML readiness with quality score. Focus on business impact and actionable insights.", + ] + # Create configs passed to the pipelines source_config = DataSourceConfig( source_type=source_type, source_path=source_path, target_column=args.target_column, + sample_size=args.sample_size, ) agent_config = AgentConfig( - model_name=model_name, + model_name=normalized_model_name, + provider=provider_choice, # None => auto inference in AgentConfig max_tool_calls=args.max_tool_calls, timeout_seconds=args.timeout_seconds, ) + scoring_config = ScoringConfig( + weight_quality=args.weight_quality, + weight_speed=args.weight_speed, + weight_findings=args.weight_findings, + speed_penalty_per_second=args.speed_penalty_per_second, + findings_score_per_item=args.findings_score_per_item, + findings_cap=args.findings_cap, + ) + # ZenML run options: keep parity with the original example pipeline_options = {"enable_cache": not args.no_cache} + # Concise transparency logs + sample_info = args.sample_size if args.sample_size is not None else "all" + print( + f"โ„น๏ธ Provider: {provider_choice or 'auto'} | Model: {normalized_model_name} | Data: {source_type}:{source_path} (target={args.target_column}, sample={sample_info})" + ) + print( + f"โ„น๏ธ Prompt variants: {len(prompt_variants)} | Tool calls: {args.max_tool_calls} | Timeout: {args.timeout_seconds}s" + ) + # Stage 1: Prompt optimization if optimization_pipeline: print("๐Ÿงช Running prompt optimization...") - - prompt_variants = [ - "You are a data analyst. Quickly assess data quality (0-100 score) and identify key patterns. Be concise.", - "You are a data quality specialist. Calculate quality score (0-100), find missing data, duplicates, and correlations. Provide specific recommendations.", - "You are a business analyst. Assess ML readiness with quality score. Focus on business impact and actionable insights.", - ] - try: - prompt_optimization_pipeline.with_options(**pipeline_options)( + optimization_result = prompt_optimization_pipeline.with_options( + **pipeline_options + )( source_config=source_config, prompt_variants=prompt_variants, agent_config=agent_config, + scoring_config=scoring_config, ) - print("โœ… Optimization completed - best prompt tagged") + print("โœ… Optimization completed") + # Pretty-print a compact scoreboard summary if available + if isinstance(optimization_result, dict): + scoreboard = optimization_result.get("scoreboard") or [] + best_prompt = optimization_result.get("best_prompt") + if isinstance(scoreboard, list) and scoreboard: + top = sorted( + scoreboard, + key=lambda x: x.get("score", 0.0), + reverse=True, + )[:3] + print("๐Ÿ“Š Scoreboard (top 3):") + for entry in top: + pid = entry.get("prompt_id", "?") + sc = entry.get("score", 0.0) + t = entry.get("execution_time", 0.0) + f = entry.get("findings_count", 0) + ok = entry.get("success", False) + mark = "โœ…" if ok else "โŒ" + print( + f"- {pid} | score: {sc:.1f} | time: {t:.1f}s | findings: {f} | {mark}" + ) + if isinstance(best_prompt, str) and best_prompt: + preview = best_prompt.replace("\n", " ")[:80] + print( + f"๐Ÿ“ Best prompt preview: {preview}{'...' if len(best_prompt) > 80 else ''}" + ) except Exception as e: # Best-effort behavior: log the error and continue to the next stage print(f"โŒ Optimization failed: {e}", file=sys.stderr) diff --git a/examples/prompt_optimization/steps/agent_tools.py b/examples/prompt_optimization/steps/agent_tools.py index c04309031b2..cfaa394f3e8 100644 --- a/examples/prompt_optimization/steps/agent_tools.py +++ b/examples/prompt_optimization/steps/agent_tools.py @@ -1,7 +1,10 @@ """Simple Pydantic AI agent tools for EDA analysis.""" +import time from dataclasses import dataclass, field -from typing import Any, Dict, List +from functools import wraps +from threading import Lock +from typing import Any, Callable, Dict, List, Optional import duckdb import pandas as pd @@ -12,13 +15,18 @@ class AnalystAgentDeps: """Simple storage for analysis results with Out[n] references.""" - output: dict[str, pd.DataFrame] = field(default_factory=dict) + output: Dict[str, pd.DataFrame] = field(default_factory=dict) query_history: List[Dict[str, Any]] = field(default_factory=list) + tool_calls: int = 0 + started_at: float = field(default_factory=time.monotonic) + time_budget_s: Optional[float] = None + lock: Lock = field(default_factory=Lock, repr=False, compare=False) def store(self, value: pd.DataFrame) -> str: """Store the output and return reference like Out[1] for the LLM.""" - ref = f"Out[{len(self.output) + 1}]" - self.output[ref] = value + with self.lock: + ref = f"Out[{len(self.output) + 1}]" + self.output[ref] = value return ref def get(self, ref: str) -> pd.DataFrame: @@ -44,14 +52,16 @@ def run_sql(ctx: RunContext[AnalystAgentDeps], dataset: str, sql: str) -> str: result = duckdb.query_df( df=data, virtual_table_name="dataset", sql_query=sql ) - ref = ctx.deps.store(result.df()) + df = result.df() + rows = len(df) + ref = ctx.deps.store(df) # Log the query for tracking ctx.deps.query_history.append( - {"sql": sql, "result_ref": ref, "rows_returned": len(result.df())} + {"sql": sql, "result_ref": ref, "rows_returned": rows} ) - return f"Query executed successfully. Result stored as `{ref}` ({len(result.df())} rows)." + return f"Query executed successfully. Result stored as `{ref}` ({rows} rows)." except Exception as e: raise ModelRetry(f"SQL query failed: {str(e)}") @@ -157,5 +167,41 @@ def analyze_correlations( return f"Correlation analysis error: {str(e)}" +def budget_wrapper(max_tool_calls: Optional[int]): + """Return a decorator that enforces time/tool-call budgets for tools. + + It reads `time_budget_s`, `started_at`, and `tool_calls` from ctx.deps and + raises ModelRetry when limits are exceeded. + """ + + def with_budget(tool: Callable) -> Callable: + @wraps(tool) + def _wrapped(ctx: RunContext[AnalystAgentDeps], *args, **kwargs): + # Enforce budgets atomically to be safe under parallel tool execution + with ctx.deps.lock: + tb = getattr(ctx.deps, "time_budget_s", None) + if ( + tb is not None + and (time.monotonic() - ctx.deps.started_at) > tb + ): + raise ModelRetry("Time budget exceeded.") + + if ( + max_tool_calls is not None + and ctx.deps.tool_calls >= max_tool_calls + ): + raise ModelRetry("Tool-call budget exceeded.") + + # Increment tool call count after passing checks + ctx.deps.tool_calls += 1 + + # Execute the actual tool logic outside the lock + return tool(ctx, *args, **kwargs) + + return _wrapped + + return with_budget + + # Enhanced tool registry AGENT_TOOLS = [run_sql, display, describe, analyze_correlations] diff --git a/examples/prompt_optimization/steps/eda_agent.py b/examples/prompt_optimization/steps/eda_agent.py index 82079ea7863..f25b7ceea7e 100644 --- a/examples/prompt_optimization/steps/eda_agent.py +++ b/examples/prompt_optimization/steps/eda_agent.py @@ -22,7 +22,7 @@ LOGFIRE_AVAILABLE = False from models import AgentConfig, EDAReport -from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps +from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps, budget_wrapper from steps.prompt_text import DEFAULT_SYSTEM_PROMPT, build_user_prompt @@ -35,7 +35,7 @@ def run_eda_agent( ) -> Tuple[ Annotated[MarkdownString, "eda_report_markdown"], Annotated[Dict[str, Any], "eda_report_json"], - Annotated[List[Dict[str, str]], "sql_execution_log"], + Annotated[List[Dict[str, Any]], "sql_execution_log"], Annotated[Dict[str, pd.DataFrame], "analysis_tables"], ]: """Run Pydantic AI agent for EDA analysis with optional custom prompt. @@ -61,8 +61,12 @@ def run_eda_agent( except Exception as e: print(f"Warning: Failed to configure Logfire: {e}") - # Initialize agent dependencies and store the dataset - deps = AnalystAgentDeps() + # Initialize agent dependencies with time budget and store the dataset + deps = AnalystAgentDeps( + time_budget_s=float(agent_config.timeout_seconds) + if agent_config + else None + ) main_ref = deps.store(dataset_df) # Create the EDA analyst agent with system prompt (custom or default) @@ -73,8 +77,11 @@ def run_eda_agent( system_prompt = DEFAULT_SYSTEM_PROMPT logger.info("๐Ÿ“ Using default system prompt for analysis") + # Provider:model id computed once for both agent creation and metadata + provider_model = agent_config.model_id() + analyst_agent = Agent( - f"openai:{agent_config.model_name}", + provider_model, deps_type=AnalystAgentDeps, output_type=EDAReport, output_retries=3, # Allow more retries for result validation @@ -84,9 +91,12 @@ def run_eda_agent( ), ) - # Register tools + # Use shared budget wrapper to enforce time/tool-call limits for all tools + wrapper = budget_wrapper(getattr(agent_config, "max_tool_calls", None)) + + # Register tools with budget enforcement for tool in AGENT_TOOLS: - analyst_agent.tool(tool) + analyst_agent.tool(wrapper(tool)) # Run focused analysis using shared user prompt builder user_prompt = build_user_prompt(main_ref, dataset_df) @@ -118,7 +128,9 @@ def run_eda_agent( "data_quality_score": eda_report.data_quality_score, "agent_metadata": { "model": agent_config.model_name, - "tool_calls": len(deps.query_history), + "provider_model": provider_model, + "tool_calls": deps.tool_calls, + "sql_queries": len(deps.query_history), }, }, deps.query_history, diff --git a/examples/prompt_optimization/steps/ingest.py b/examples/prompt_optimization/steps/ingest.py index b35f002ce9f..867f73f8947 100644 --- a/examples/prompt_optimization/steps/ingest.py +++ b/examples/prompt_optimization/steps/ingest.py @@ -41,10 +41,11 @@ def ingest_data( ) # Apply sampling if configured - if source_config.sample_size and len(df) > source_config.sample_size: + total = len(df) + if source_config.sample_size and total > source_config.sample_size: df = df.sample(n=source_config.sample_size, random_state=42) logger.info( - f"Sampled {source_config.sample_size} rows from {len(df)} total" + f"Sampled {source_config.sample_size} rows from {total} total" ) # Generate simple metadata diff --git a/examples/prompt_optimization/steps/prompt_optimization.py b/examples/prompt_optimization/steps/prompt_optimization.py index 239a9c6af75..ef71469abc3 100644 --- a/examples/prompt_optimization/steps/prompt_optimization.py +++ b/examples/prompt_optimization/steps/prompt_optimization.py @@ -1,13 +1,13 @@ """Simple prompt optimization step for demonstrating ZenML artifact management.""" import time -from typing import Annotated +from typing import Annotated, List, Optional, Tuple import pandas as pd -from models import AgentConfig, EDAReport +from models import AgentConfig, EDAReport, ScoringConfig, VariantScore from pydantic_ai import Agent from pydantic_ai.settings import ModelSettings -from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps +from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps, budget_wrapper from steps.prompt_text import DEFAULT_SYSTEM_PROMPT, build_user_prompt from zenml import ArtifactConfig, Tag, add_tags, step @@ -16,72 +16,71 @@ logger = get_logger(__name__) -# Scoring weights reflect priorities: quality first (insightful outputs), -# then speed (encourage lower latency), then findings (reward coverage). -WEIGHT_QUALITY: float = 0.7 -WEIGHT_SPEED: float = 0.2 -WEIGHT_FINDINGS: float = 0.1 - -# Linear time penalty: each second reduces the speed score by this many points -# until the score floors at 0 (capped at 100 points of penalty). -SPEED_PENALTY_PER_SECOND: float = 2.0 - -# Reward per key finding discovered by the agent before applying the findings weight. -# Keeping this explicit makes it easy to tune coverage incentives. -FINDINGS_SCORE_PER_ITEM: float = 0.5 - +# NOTE: Quality and speed are both on a 0โ€“100 scale. Findings accrue raw points +# (capped via `findings_cap`) before weights are applied. After weight normalization, +# the aggregate score remains roughly bounded within 0โ€“100. def compute_prompt_score( - eda_report: EDAReport, execution_time: float + eda_report: EDAReport, execution_time: float, scoring: ScoringConfig ) -> float: - """Compute a prompt's score from EDA results and runtime. - - This makes scoring trade-offs explicit and tunable via module-level constants: - - Prioritize report quality (WEIGHT_QUALITY) - - Encourage faster execution via a linear time penalty converted to a 0โ€“100 speed score (WEIGHT_SPEED) - - Reward thoroughness by crediting key findings (WEIGHT_FINDINGS) - """ - speed_score = max( - 0.0, - 100.0 - min(execution_time * SPEED_PENALTY_PER_SECOND, 100.0), + wq, ws, wf = scoring.normalized_weights + speed_penalty = min( + execution_time * scoring.speed_penalty_per_second, 100.0 ) - findings_score = len(eda_report.key_findings) * FINDINGS_SCORE_PER_ITEM + speed_score = max(0.0, 100.0 - speed_penalty) + credited_findings = min(len(eda_report.key_findings), scoring.findings_cap) + findings_score = credited_findings * scoring.findings_score_per_item return ( - eda_report.data_quality_score * WEIGHT_QUALITY - + speed_score * WEIGHT_SPEED - + findings_score * WEIGHT_FINDINGS + eda_report.data_quality_score * wq + + speed_score * ws + + findings_score * wf ) @step def compare_prompts_and_tag_best( dataset_df: pd.DataFrame, - prompt_variants: list[str], - agent_config: AgentConfig = None, -) -> Annotated[str, ArtifactConfig(name="best_prompt")]: - """Compare prompt variants and tag the best one with exclusive 'optimized' tag. - - This step demonstrates ZenML's artifact management by: - 1. Testing multiple prompt variants - 2. Finding the best performer - 3. Returning it as a tagged artifact that other pipelines can find - - The 'optimized' tag is exclusive, so only one prompt can be 'optimized' at a time. + prompt_variants: List[str], + agent_config: AgentConfig | None = None, + scoring_config: Optional[ScoringConfig] = None, +) -> Tuple[ + Annotated[str, ArtifactConfig(name="best_prompt")], + Annotated[List[VariantScore], ArtifactConfig(name="prompt_scoreboard")], +]: + """Compare prompt variants, compute scores, and emit best prompt + scoreboard. + + Behavior: + - Provider/model inference owned by AgentConfig.model_id() + - Uniform time/tool-call budget enforcement via tool wrappers + - Score each variant using ScoringConfig; record success/failure with timing + - Tag best prompt with exclusive 'optimized' tag only if at least one success Args: dataset_df: Dataset to test prompts against prompt_variants: List of system prompts to compare - agent_config: Configuration for AI agents + agent_config: Configuration for AI agents (defaults applied if None) + scoring_config: Scoring configuration (defaults applied if None) Returns: - The best performing prompt string with exclusive 'optimized' tag + Tuple of: + - Best performing prompt string (artifact name 'best_prompt') + - Scoreboard entries for all variants (artifact name 'prompt_scoreboard') """ if agent_config is None: agent_config = AgentConfig() + if scoring_config is None: + scoring_config = ScoringConfig() logger.info(f"๐Ÿงช Testing {len(prompt_variants)} prompt variants") - results = [] + # Compute provider:model id once for consistent use across variants + provider_model = agent_config.model_id() + + # Prepare a shared wrapper that enforces budgets for all tools + wrapper = budget_wrapper(getattr(agent_config, "max_tool_calls", None)) + + # Collect VariantScore entries for all variants (success and failure) + scoreboard: List[VariantScore] = [] for i, system_prompt in enumerate(prompt_variants): prompt_id = f"variant_{i + 1}" @@ -90,20 +89,28 @@ def compare_prompts_and_tag_best( start_time = time.time() try: - # Create agent with this prompt - deps = AnalystAgentDeps() + # Create agent with this prompt and per-variant deps with time budget + deps = AnalystAgentDeps( + time_budget_s=float(agent_config.timeout_seconds) + if agent_config + else None + ) + deps.tool_calls = ( + 0 # Ensure tool-call counter is reset per variant + ) main_ref = deps.store(dataset_df) agent = Agent( - f"openai:{agent_config.model_name}", + provider_model, deps_type=AnalystAgentDeps, output_type=EDAReport, system_prompt=system_prompt, model_settings=ModelSettings(parallel_tool_calls=False), ) + # Register tools with shared budget enforcement for tool in AGENT_TOOLS: - agent.tool(tool) + agent.tool(wrapper(tool)) # Run analysis user_prompt = build_user_prompt(main_ref, dataset_df) @@ -112,19 +119,22 @@ def compare_prompts_and_tag_best( execution_time = time.time() - start_time - # Score this variant - score = compute_prompt_score(eda_report, execution_time) - - results.append( - { - "prompt_id": prompt_id, - "prompt": system_prompt, - "score": score, - "quality_score": eda_report.data_quality_score, - "execution_time": execution_time, - "findings_count": len(eda_report.key_findings), - "success": True, - } + # Score this variant using provided scoring configuration + score = compute_prompt_score( + eda_report, execution_time, scoring_config + ) + + scoreboard.append( + VariantScore( + prompt_id=prompt_id, + prompt=system_prompt, + score=score, + quality_score=eda_report.data_quality_score, + execution_time=execution_time, + findings_count=len(eda_report.key_findings), + success=True, + error=None, + ) ) logger.info( @@ -132,49 +142,59 @@ def compare_prompts_and_tag_best( ) except Exception as e: + execution_time = time.time() - start_time logger.warning(f"โŒ {prompt_id} failed: {e}") - results.append( - { - "prompt_id": prompt_id, - "prompt": system_prompt, - "score": 0, - "success": False, - "error": str(e), - } + scoreboard.append( + VariantScore( + prompt_id=prompt_id, + prompt=system_prompt, + score=0.0, + quality_score=0.0, + execution_time=execution_time, + findings_count=0, + success=False, + error=str(e), + ) ) - # Find best performer - successful_results = [r for r in results if r["success"]] + # Determine best performer among successful variants + successful_results = [entry for entry in scoreboard if entry.success] if not successful_results: logger.warning( "All prompts failed, falling back to DEFAULT_SYSTEM_PROMPT" ) best_prompt = DEFAULT_SYSTEM_PROMPT + logger.info("โญ๏ธ Skipping best prompt tagging since all variants failed") else: - best_result = max(successful_results, key=lambda x: x["score"]) - best_prompt = best_result["prompt"] + best_result = max(successful_results, key=lambda x: x.score) + best_prompt = best_result.prompt logger.info( - f"๐Ÿ† Best prompt: {best_result['prompt_id']} (score: {best_result['score']:.1f})" + f"๐Ÿ† Best prompt: {best_result.prompt_id} (score: {best_result.score:.1f})" ) - logger.info("๐Ÿ’พ Best prompt will be stored with exclusive 'optimized' tag") - - # Add exclusive tag to this step's output artifact - add_tags(tags=[Tag(name="optimized", exclusive=True)], infer_artifact=True) + logger.info( + "๐Ÿ’พ Best prompt will be stored with exclusive 'optimized' tag" + ) + # Explicitly target the named output artifact to avoid multi-output ambiguity + add_tags( + tags=[Tag(name="optimized", exclusive=True)], + output_name="best_prompt", + ) - return best_prompt + return best_prompt, scoreboard -def get_optimized_prompt() -> str: - """Retrieve the optimized prompt using ZenML's tag-based filtering. +def get_optimized_prompt() -> Tuple[str, bool]: + """Retrieve the optimized prompt from a tagged artifact, with safe fallback. This demonstrates ZenML's tag filtering by finding artifacts tagged with 'optimized'. Since 'optimized' is an exclusive tag, there will be at most one such artifact. Returns: - The optimized prompt from the latest optimization run as a plain string, - or DEFAULT_SYSTEM_PROMPT if none found or retrieval fails. + Tuple[str, bool]: (prompt, from_artifact) where: + - prompt: The optimized prompt text if found; DEFAULT_SYSTEM_PROMPT otherwise. + - from_artifact: True if the prompt was retrieved from an artifact; False on fallback. """ try: client = Client() @@ -189,19 +209,20 @@ def get_optimized_prompt() -> str: f"๐ŸŽฏ Retrieved optimized prompt from artifact: {optimized_artifact.id}" ) logger.info(f" Artifact created: {optimized_artifact.created}") - return prompt_value + return prompt_value, True else: logger.info( - "๐Ÿ” No optimized prompt found (no artifacts with 'optimized' tag). Using DEFAULT_SYSTEM_PROMPT." + "๐Ÿ” No optimized prompt found (no artifacts with 'optimized' tag). " + "Falling back to default (from_artifact=False)." ) except Exception as e: logger.warning( - f"Failed to retrieve optimized prompt: {e}. Falling back to DEFAULT_SYSTEM_PROMPT." + f"Failed to retrieve optimized prompt: {e}. Falling back to DEFAULT_SYSTEM_PROMPT (from_artifact=False)." ) - # Fallback to default system prompt if no optimization artifacts found or retrieval failed + # Fallback to default system prompt if lookup fails logger.info( - "๐Ÿ“ Using default system prompt (run optimization pipeline first)" + "๐Ÿ“ Using default system prompt (run optimization pipeline first). from_artifact=False" ) - return DEFAULT_SYSTEM_PROMPT + return DEFAULT_SYSTEM_PROMPT, False