diff --git a/AGENTS.md b/AGENTS.md index ea4df46250..225008d9be 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -107,6 +107,9 @@ Use filesystem navigation tools to explore the codebase structure as needed. - Use imperative mood: "Add feature" not "Added feature" - Reference issue numbers when applicable: "Fix user auth bug (#1234)" - For multi-line messages, add a blank line after the summary + +Codex-style agents must review the diff (when not already tracked), craft a concise summary line, and include a detailed body covering the key changes. The body should describe the main code adjustments so reviewers can understand the scope from the commit message alone. + - Example: ``` Add retry logic to artifact upload diff --git a/examples/prompt_optimization/README.md b/examples/prompt_optimization/README.md new file mode 100644 index 0000000000..d8ae92ab6b --- /dev/null +++ b/examples/prompt_optimization/README.md @@ -0,0 +1,165 @@ +# Prompt Optimization with ZenML + +This example demonstrates **ZenML's artifact management** through a two-stage AI prompt optimization workflow using **Pydantic AI** for exploratory data analysis. + +## What This Example Shows + +**Stage 1: Prompt Optimization** +- Tests multiple prompt variants against sample data +- Compares performance using data quality scores and execution time +- Emits a scoreboard artifact that summarizes quality, speed, findings, and success per prompt +- Tags the best-performing prompt with an **exclusive ZenML tag** +- Stores the optimized prompt in ZenML's artifact registry + +**Stage 2: Production Analysis** +- Automatically attempts to retrieve the latest optimized prompt from the registry +- Falls back to the default system prompt if no optimized prompt is available or retrieval fails +- Runs production EDA analysis using the selected prompt +- Returns a `used_optimized_prompt` boolean to indicate whether the optimized prompt was actually used +- Demonstrates real artifact sharing between pipeline runs + +This showcases how ZenML enables **reproducible ML workflows** where optimization results automatically flow into production systems, with safe fallbacks. 
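The Stage 2 fallback boils down to a few lines. Below is a condensed sketch of the `get_optimized_prompt` step defined in `steps/prompt_optimization.py` (simplified for illustration; the real step also catches lookup errors and logs which path was taken):

```python
from typing import Tuple

from steps.prompt_text import DEFAULT_SYSTEM_PROMPT
from zenml.client import Client


def get_optimized_prompt() -> Tuple[str, bool]:
    """Return (prompt, from_artifact): the tagged prompt if one exists, else the default."""
    artifacts = Client().list_artifact_versions(tags=["optimized"], size=1)
    if artifacts.items:
        # 'optimized' is an exclusive tag, so at most one artifact version matches.
        return artifacts.items[0].load(), True
    return DEFAULT_SYSTEM_PROMPT, False
```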
+ +## Quick Start + +### Prerequisites +```bash +# Install ZenML and initialize +pip install "zenml[server]" +zenml init + +# Install example dependencies +pip install -r requirements.txt + +# Set your API key (OpenAI or Anthropic) +export OPENAI_API_KEY="your-openai-key" +# OR +export ANTHROPIC_API_KEY="your-anthropic-key" +``` + +### Run the Example +```bash +# Complete two-stage workflow (default behavior) +python run.py + +# Run individual stages +python run.py --optimization-pipeline # Stage 1: Find best prompt +python run.py --production-pipeline # Stage 2: Use optimized prompt + +# Force a specific provider/model (override auto-detection) +python run.py --provider openai --model-name gpt-4o-mini +python run.py --provider anthropic --model-name claude-3-haiku-20240307 +``` + +## Data Sources + +The example supports multiple data sources: + +```bash +# HuggingFace datasets (default) +python run.py --data-source "hf:scikit-learn/iris" +python run.py --data-source "hf:scikit-learn/wine" + +# Local CSV files +python run.py --data-source "local:./my_data.csv" + +# Specify target column +python run.py --data-source "local:sales.csv" --target-column "revenue" + +# Sample a subset of rows for faster iterations +python run.py --data-source "hf:scikit-learn/iris" --sample-size 500 +``` + +## Key ZenML Concepts Demonstrated + +### Artifact Management +- **Exclusive Tagging**: Only one prompt can have the "optimized" tag at a time +- **Artifact Registry**: Centralized storage for ML artifacts with versioning +- **Cross-Pipeline Sharing**: Production pipeline automatically finds optimization results + +### Pipeline Orchestration +- **Multi-Stage Workflows**: Optimization โ†’ Production with artifact passing +- **Conditional Execution**: Production pipeline adapts based on available artifacts +- **Lineage Tracking**: Full traceability from prompt testing to production use + +### AI Integration +- **Model Flexibility**: Works with OpenAI GPT or Anthropic Claude models +- **Performance Testing**: Systematic comparison of prompt variants +- **Production Deployment**: Seamless transition from experimentation to production + +## Configuration Options + +### Provider and Model Selection +```bash +# Auto (default): infer provider from model name or environment keys +python run.py + +# Force provider explicitly +python run.py --provider openai --model-name gpt-4o-mini +python run.py --provider anthropic --model-name claude-3-haiku-20240307 + +# Fully-qualified model names are also supported +python run.py --model-name "openai:gpt-4o-mini" +python run.py --model-name "anthropic:claude-3-haiku-20240307" +``` + +### Scoring Configuration +Configure how prompts are ranked during optimization: +```bash +# Weights for the aggregate score +python run.py --weight-quality 0.7 --weight-speed 0.2 --weight-findings 0.1 + +# Speed penalty (points per second) applied to compute a speed score +python run.py --speed-penalty-per-second 2.0 + +# Findings scoring (base points per finding) and cap +python run.py --findings-score-per-item 0.5 --findings-cap 20 +``` + +*Why a cap?* It prevents a variant that simply emits an excessively long list of "findings" from dominating the overall score. Capping keeps the aggregate score bounded and comparable across variants while still rewarding useful coverage. 
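For intuition, here is a simplified sketch of how the example's `compute_prompt_score` helper (in `steps/prompt_optimization.py`) combines the three components, with illustrative numbers; the real implementation first normalizes the weights so they sum to 1 (the defaults below already do):

```python
def prompt_score(
    quality: float, seconds: float, findings: int,
    wq: float = 0.7, ws: float = 0.2, wf: float = 0.1,
    penalty_per_s: float = 2.0, per_finding: float = 0.5, cap: int = 20,
) -> float:
    # Quality is already on a 0-100 scale; speed becomes a 0-100 score by
    # subtracting a linear per-second penalty; findings earn capped raw points.
    speed_score = max(0.0, 100.0 - min(seconds * penalty_per_s, 100.0))
    findings_score = min(findings, cap) * per_finding
    return quality * wq + speed_score * ws + findings_score * wf


# Example: quality 85, 12 s runtime, 6 findings
# -> 85*0.7 + (100 - 24)*0.2 + (6*0.5)*0.1 = 59.5 + 15.2 + 0.3 ≈ 75
print(prompt_score(85.0, 12.0, 6))
```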
+ +### Custom Prompt Files +Provide your own prompt variants via a file (UTF-8, one prompt per line; blank lines ignored): +```bash +python run.py --prompts-file ./my_prompts.txt +``` + +### Sampling +Downsample large datasets to speed up experiments: +```bash +python run.py --sample-size 500 +``` + +### Budgets and Timeouts +Budgets are enforced at the tool boundary for deterministic behavior and cost control: +```bash +# Tool-call budget and overall time budget for the agent +python run.py --max-tool-calls 8 --timeout-seconds 120 +``` +- The agent tools check and enforce these budgets during execution. + +### Caching +```bash +# Disable caching for fresh runs +python run.py --no-cache +``` + +## Expected Output + +When you run the complete workflow, you'll see: + +1. **Optimization Stage**: Testing of multiple prompt variants with performance metrics +2. **Scoreboard**: CLI prints a compact top-3 summary (score, time, findings, success) and a short preview of the best prompt; a full scoreboard artifact is saved +3. **Tagging**: Best prompt automatically tagged in ZenML registry (exclusive "optimized" tag) +4. **Production Stage**: Retrieval and use of optimized prompt if available; otherwise falls back to the default +5. **Results**: EDA analysis with data quality scores and recommendations, plus a `used_optimized_prompt` flag indicating whether an optimized prompt was actually used + +The ZenML dashboard will show the complete lineage from optimization to production use, including the prompt scoreboard and tagged best prompt. + +## Next Steps + +- **View Results**: Check the ZenML dashboard for pipeline runs and artifacts +- **Customize Prompts**: Provide your own variants via `--prompts-file` +- **Tune Scoring**: Adjust weights and penalties to match your evaluation criteria +- **Scale Up**: Deploy with remote orchestrators for production workloads +- **Integrate**: Connect to your existing data sources and ML pipelines \ No newline at end of file diff --git a/examples/prompt_optimization/__init__.py b/examples/prompt_optimization/__init__.py new file mode 100644 index 0000000000..f25ceb38fa --- /dev/null +++ b/examples/prompt_optimization/__init__.py @@ -0,0 +1,5 @@ +"""Pydantic AI EDA pipeline example for ZenML. + +This example demonstrates how to build an AI-powered Exploratory Data Analysis +pipeline using ZenML and Pydantic AI for automated data analysis and quality assessment. +""" \ No newline at end of file diff --git a/examples/prompt_optimization/models.py b/examples/prompt_optimization/models.py new file mode 100644 index 0000000000..4b60283a52 --- /dev/null +++ b/examples/prompt_optimization/models.py @@ -0,0 +1,216 @@ +"""Simple data models for prompt optimization example.""" + +import os +from typing import List, Literal, Optional, Tuple + +from pydantic import BaseModel, Field, field_validator, model_validator + + +class DataSourceConfig(BaseModel): + """Data source configuration.""" + + source_type: str = Field(description="Data source type: hf or local") + source_path: str = Field( + description="Path or identifier for the data source" + ) + target_column: Optional[str] = Field( + None, description="Optional target column name" + ) + sample_size: Optional[int] = Field( + None, description="Number of rows to sample" + ) + + +def normalize_model_name(name: str) -> str: + """Normalize a model name by stripping any leading provider prefix. + + This prevents accidental double-prefixing when constructing a provider:model id. 
+ + Examples: + - 'openai:gpt-4o-mini' -> 'gpt-4o-mini' + - 'anthropic:claude-3-haiku-20240307' -> 'claude-3-haiku-20240307' + - 'gpt-4o-mini' -> 'gpt-4o-mini' (unchanged) + """ + if not isinstance(name, str): + return name + value = name.strip() + if ":" in value: + maybe_provider, remainder = value.split(":", 1) + if maybe_provider.lower() in ("openai", "anthropic"): + return remainder.strip() + return value + + +def infer_provider( + model_name: Optional[str] = None, +) -> Literal["openai", "anthropic"]: + """Infer the provider from model-name hints first, then environment variables. + + Rules: + - If `model_name` clearly references Claude โ†’ Anthropic. + - If `model_name` clearly references GPT / -o models โ†’ OpenAI. + - Otherwise use environment keys as tie-breakers. + - If both or neither keys exist and there is no hint, default to OpenAI. + """ + if isinstance(model_name, str): + mn = model_name.lower().strip() + # Handle common hints robustly + if mn.startswith("claude") or "claude" in mn: + return "anthropic" + if mn.startswith("gpt") or "gpt-" in mn or "-o" in mn: + # e.g., gpt-4o, gpt-4o-mini, gpt-4.1 + return "openai" + + has_openai = bool(os.getenv("OPENAI_API_KEY")) + has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY")) + + if has_openai and not has_anthropic: + return "openai" + if has_anthropic and not has_openai: + return "anthropic" + + # Both present or neither present โ†’ default to OpenAI for GPT-style defaults + return "openai" + + +class AgentConfig(BaseModel): + """AI agent configuration.""" + + provider: Optional[Literal["openai", "anthropic"]] = Field( + default=None, + description="Optional model provider. If omitted, inferred from the model name or environment.", + ) + model_name: str = Field("gpt-4o-mini", description="Language model to use") + max_tool_calls: int = Field(6, description="Maximum tool calls allowed") + timeout_seconds: int = Field( + 60, description="Max execution time in seconds" + ) + + @field_validator("model_name", mode="before") + @classmethod + def _normalize_model_name(cls, v: object) -> object: + """Normalize the model name early to remove accidental provider prefixes.""" + if isinstance(v, str): + return normalize_model_name(v) + return v + + @model_validator(mode="before") + @classmethod + def _set_provider_from_prefix(cls, data: object) -> object: + """If a provider prefix is present in model_name, capture it into `provider`. + + This preserves the user's explicit intent (e.g., 'openai:gpt-4o-mini') + even though `model_name` is normalized by the field validator. + """ + if not isinstance(data, dict): + return data + provider = data.get("provider") + raw_name = data.get("model_name") + if (provider is None or provider == "") and isinstance(raw_name, str): + value = raw_name.strip() + if ":" in value: + maybe_provider = value.split(":", 1)[0].strip().lower() + if maybe_provider in ("openai", "anthropic"): + data["provider"] = maybe_provider + return data + + def model_id(self) -> str: + """Return a 'provider:model' string, inferring the provider when necessary. 
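Illustrative results (following the normalization and inference rules described below):
- AgentConfig(model_name='openai:gpt-4o-mini').model_id() -> 'openai:gpt-4o-mini'
- AgentConfig(model_name='claude-3-haiku-20240307').model_id() -> 'anthropic:claude-3-haiku-20240307'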
+ + Precedence: + 1) Explicit provider captured from model_name prefix or set via `provider` + 2) Heuristic inference using model_name hints or environment variables + """ + # Normalize again defensively to avoid double-prefixing if callers mutate model_name later + name = normalize_model_name(self.model_name) + provider: Literal["openai", "anthropic"] = ( + self.provider + if self.provider is not None + else infer_provider(name) + ) + return f"{provider}:{name}" + + +class EDAReport(BaseModel): + """Simple EDA analysis report from AI agent.""" + + headline: str = Field(description="Executive summary of key findings") + key_findings: List[str] = Field( + default_factory=list, description="Important discoveries" + ) + data_quality_score: float = Field( + description="Overall quality score (0-100)" + ) + markdown: str = Field(description="Full markdown report") + + +class ScoringConfig(BaseModel): + """Scoring config with weights, penalties, and caps.""" + + weight_quality: float = Field( + 0.7, + ge=0.0, + description="Weight for the quality component.", + ) + weight_speed: float = Field( + 0.2, + ge=0.0, + description="Weight for the speed/latency component.", + ) + weight_findings: float = Field( + 0.1, + ge=0.0, + description="Weight for the findings coverage component.", + ) + speed_penalty_per_second: float = Field( + 2.0, + ge=0.0, + description="Linear penalty (points per second) applied when computing a speed score.", + ) + findings_score_per_item: float = Field( + 0.5, + ge=0.0, + description="Base points credited per key finding before applying weights.", + ) + findings_cap: int = Field( + 20, + ge=0, + description="Maximum number of findings to credit when scoring.", + ) + + @property + def normalized_weights(self) -> Tuple[float, float, float]: + """Return (quality, speed, findings) weights normalized to sum to 1.""" + total = self.weight_quality + self.weight_speed + self.weight_findings + if total <= 0: + return (1.0, 0.0, 0.0) + return ( + self.weight_quality / total, + self.weight_speed / total, + self.weight_findings / total, + ) + + +class VariantScore(BaseModel): + """Score summary for a single prompt variant.""" + + prompt_id: str = Field( + description="Stable identifier for the variant (e.g., 'variant_1')." + ) + prompt: str = Field(description="The system prompt text evaluated.") + score: float = Field(description="Aggregate score used for ranking.") + quality_score: float = Field( + description="Quality score produced by the agent/report (0-100)." + ) + execution_time: float = Field( + ge=0.0, description="Execution time in seconds for the evaluation." + ) + findings_count: int = Field( + ge=0, description="Number of key findings counted for this run." + ) + success: bool = Field( + description="Whether the evaluation completed successfully." + ) + error: Optional[str] = Field( + None, description="Optional error message if success is False." + ) diff --git a/examples/prompt_optimization/pipelines/__init__.py b/examples/prompt_optimization/pipelines/__init__.py new file mode 100644 index 0000000000..e278ea485d --- /dev/null +++ b/examples/prompt_optimization/pipelines/__init__.py @@ -0,0 +1,15 @@ +"""ZenML pipelines for prompt optimization example. 
+ +This module contains pipeline definitions: + +- prompt_optimization_pipeline.py: Test prompts and tag the best one +- production_eda_pipeline.py: Use optimized prompt for EDA analysis +""" + +from .production_eda_pipeline import production_eda_pipeline +from .prompt_optimization_pipeline import prompt_optimization_pipeline + +__all__ = [ + "prompt_optimization_pipeline", + "production_eda_pipeline", +] \ No newline at end of file diff --git a/examples/prompt_optimization/pipelines/production_eda_pipeline.py b/examples/prompt_optimization/pipelines/production_eda_pipeline.py new file mode 100644 index 0000000000..c60cd63345 --- /dev/null +++ b/examples/prompt_optimization/pipelines/production_eda_pipeline.py @@ -0,0 +1,63 @@ +"""Simple production EDA pipeline using optimized prompts.""" + +from typing import Any, Dict, Optional + +from models import AgentConfig, DataSourceConfig +from steps import get_optimized_prompt, ingest_data, run_eda_agent + +from zenml import pipeline +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@pipeline +def production_eda_pipeline( + source_config: DataSourceConfig, + agent_config: Optional[AgentConfig] = None, +) -> Dict[str, Any]: + """Production EDA pipeline using an optimized prompt when available. + + The pipeline automatically attempts to retrieve the latest optimized prompt + from the artifact registry (tagged 'optimized') and falls back to the + default system prompt if none is available or retrieval fails. + + Args: + source_config: Data source configuration + agent_config: AI agent configuration + + Returns: + EDA results and metadata, including whether an optimized prompt was used. + """ + logger.info("๐Ÿญ Starting production EDA pipeline") + + # Step 1: Get optimized prompt or fall back to default + optimized_prompt, used_optimized = get_optimized_prompt() + if used_optimized: + logger.info("๐ŸŽฏ Using optimized prompt retrieved from artifact") + else: + logger.warning( + "๐Ÿ” Optimized prompt not found; falling back to default system prompt" + ) + + # Step 2: Load data + dataset_df, metadata = ingest_data(source_config=source_config) + + # Step 3: Run EDA analysis with the selected prompt + report_markdown, report_json, sql_log, analysis_tables = run_eda_agent( + dataset_df=dataset_df, + dataset_metadata=metadata, + agent_config=agent_config, + custom_system_prompt=optimized_prompt, + ) + + logger.info("โœ… Production EDA pipeline completed") + + return { + "report_markdown": report_markdown, + "report_json": report_json, + "sql_log": sql_log, + "analysis_tables": analysis_tables, + "used_optimized_prompt": used_optimized, + "metadata": metadata, + } diff --git a/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py b/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py new file mode 100644 index 0000000000..18ecc4b48b --- /dev/null +++ b/examples/prompt_optimization/pipelines/prompt_optimization_pipeline.py @@ -0,0 +1,64 @@ +"""Simple prompt optimization pipeline for ZenML artifact management demo.""" + +from typing import Any, Dict, List, Optional + +from models import AgentConfig, DataSourceConfig, ScoringConfig +from steps import compare_prompts_and_tag_best, ingest_data + +from zenml import pipeline +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@pipeline +def prompt_optimization_pipeline( + source_config: DataSourceConfig, + prompt_variants: List[str], + agent_config: Optional[AgentConfig] = None, + scoring_config: Optional[ScoringConfig] = None, +) 
-> Dict[str, Any]: + """Optimize prompts and tag the best one for production use. + + This pipeline demonstrates ZenML's artifact management by testing + multiple prompt variants and tagging the best performer. + + Args: + source_config: Data source configuration + prompt_variants: List of prompt strings to test + agent_config: AI agent configuration + scoring_config: Optional scoring configuration controlling ranking + + Returns: + Pipeline results with best prompt, scoreboard, effective scoring config, and metadata. + """ + logger.info("๐Ÿงช Starting prompt optimization pipeline") + + # Step 1: Load data + dataset_df, metadata = ingest_data(source_config=source_config) + + # Use a single ScoringConfig instance for both the step and return payload + effective_scoring = scoring_config or ScoringConfig() + + # Step 2: Test prompts and tag the best one, collecting a scoreboard + best_prompt, scoreboard = compare_prompts_and_tag_best( + dataset_df=dataset_df, + prompt_variants=prompt_variants, + agent_config=agent_config, + scoring_config=effective_scoring, + ) + + logger.info( + "โœ… Prompt optimization completed - best prompt tagged and scoreboard generated" + ) + + return { + "best_prompt": best_prompt, + "scoreboard": [entry.model_dump() for entry in scoreboard], + "scoring": effective_scoring.model_dump(), + "metadata": metadata, + "config": { + "source": f"{source_config.source_type}:{source_config.source_path}", + "variants_tested": len(prompt_variants), + }, + } diff --git a/examples/prompt_optimization/requirements.txt b/examples/prompt_optimization/requirements.txt new file mode 100644 index 0000000000..3215a20c16 --- /dev/null +++ b/examples/prompt_optimization/requirements.txt @@ -0,0 +1,20 @@ +# ZenML +zenml[server] + +# Core dependencies with version constraints for compatibility +pandas>=2.0.0,<3.0.0 +numpy>=1.24.0,<2.0.0 +duckdb>=1.0.0,<2.0.0 + +# Pydantic compatibility +pydantic>=2.6,<3 + +# AI/ML frameworks +pydantic-ai[logfire]>=0.4.0 +openai>=1.0.0,<2.0.0 +anthropic>=0.30.0 # Alternative LLM provider + +# Data source connectors +datasets>=2.14.0,<3.0.0 # HuggingFace datasets (optional for local files) + +# Utility libraries diff --git a/examples/prompt_optimization/run.py b/examples/prompt_optimization/run.py new file mode 100644 index 0000000000..f80206dc58 --- /dev/null +++ b/examples/prompt_optimization/run.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +"""Two-Stage Prompt Optimization Example with ZenML.""" + +import argparse +import os +import sys +from typing import List, Optional + +from models import AgentConfig, DataSourceConfig, ScoringConfig +from pipelines.production_eda_pipeline import production_eda_pipeline +from pipelines.prompt_optimization_pipeline import prompt_optimization_pipeline + + +def build_parser() -> argparse.ArgumentParser: + """Build the CLI parser. + + Rationale: + - Uses only the Python standard library to avoid extra dependencies. + - Mirrors the original Click-based flag and option semantics to keep the + README usage and user experience unchanged. 
+ """ + parser = argparse.ArgumentParser( + description="Run prompt optimization and/or production EDA pipelines (both by default)" + ) + + # Instantiate defaults from ScoringConfig so CLI flags reflect canonical values + scoring_defaults = ScoringConfig() + + parser.add_argument( + "--optimization-pipeline", + action="store_true", + help="Run prompt optimization", + ) + parser.add_argument( + "--production-pipeline", + action="store_true", + help="Run production EDA", + ) + parser.add_argument( + "--data-source", + default="hf:scikit-learn/iris", + help="Data source in 'type:path' format, e.g., 'hf:scikit-learn/iris' or 'local:./data.csv'", + ) + parser.add_argument( + "--target-column", + default="target", + help="Target column name", + ) + parser.add_argument( + "--model-name", + help="Model name (auto-detected if not specified). Can be fully-qualified like 'openai:gpt-4o-mini' or 'anthropic:claude-3-haiku-20240307'.", + ) + parser.add_argument( + "--provider", + choices=["auto", "openai", "anthropic"], + default="auto", + help="Model provider to use. 'auto' infers from model/name/environment.", + ) + parser.add_argument( + "--sample-size", + type=int, + help="Optional row sample size for the dataset.", + ) + parser.add_argument( + "--prompts-file", + help="Path to a newline-delimited prompts file (UTF-8). One prompt per line; blank lines are ignored.", + ) + parser.add_argument( + "--max-tool-calls", + default=6, + type=int, + help="Max tool calls", + ) + parser.add_argument( + "--timeout-seconds", + default=60, + type=int, + help="Timeout seconds", + ) + parser.add_argument( + "--no-cache", + action="store_true", + help="Disable caching", + ) + + # Scoring knobs (seeded from ScoringConfig defaults) + parser.add_argument( + "--weight-quality", + type=float, + default=scoring_defaults.weight_quality, + help="Weight for the quality component.", + ) + parser.add_argument( + "--weight-speed", + type=float, + default=scoring_defaults.weight_speed, + help="Weight for the speed/latency component.", + ) + parser.add_argument( + "--weight-findings", + type=float, + default=scoring_defaults.weight_findings, + help="Weight for the findings coverage component.", + ) + parser.add_argument( + "--speed-penalty-per-second", + type=float, + default=scoring_defaults.speed_penalty_per_second, + help="Penalty points per second applied when computing a speed score.", + ) + parser.add_argument( + "--findings-score-per-item", + type=float, + default=scoring_defaults.findings_score_per_item, + help="Base points credited per key finding before applying weights.", + ) + parser.add_argument( + "--findings-cap", + type=int, + default=scoring_defaults.findings_cap, + help="Maximum number of findings to credit when scoring.", + ) + return parser + + +def main() -> int: + """Run prompt optimization and/or production EDA pipelines.""" + parser = build_parser() + args = parser.parse_args() + + # Default: run both pipelines if no flags specified + optimization_pipeline = args.optimization_pipeline + production_pipeline = args.production_pipeline + if not optimization_pipeline and not production_pipeline: + optimization_pipeline = True + production_pipeline = True + + # Check API keys. We support either OpenAI or Anthropic; at least one is required. 
+ has_openai = bool(os.getenv("OPENAI_API_KEY")) + has_anthropic = bool(os.getenv("ANTHROPIC_API_KEY")) + + if not (has_openai or has_anthropic): + print("โŒ Set OPENAI_API_KEY or ANTHROPIC_API_KEY", file=sys.stderr) + return 1 + + # Derive raw model name from CLI or sensible defaults considering provider/env + raw_model_name: Optional[str] = args.model_name + if raw_model_name is None: + if args.provider == "openai": + raw_model_name = "gpt-4o-mini" + elif args.provider == "anthropic": + raw_model_name = "claude-3-haiku-20240307" + else: + raw_model_name = ( + "gpt-4o-mini" if has_openai else "claude-3-haiku-20240307" + ) + + # Detect and strip provider prefixes (e.g., 'openai:gpt-4o-mini'), preserving explicit provider + explicit_provider: Optional[str] = None + normalized_model_name = raw_model_name.strip() + if ":" in normalized_model_name: + maybe_provider, remainder = normalized_model_name.split(":", 1) + maybe_provider = maybe_provider.strip().lower() + if maybe_provider in ("openai", "anthropic"): + explicit_provider = maybe_provider + normalized_model_name = remainder.strip() + + # Combine with --provider to determine provider choice + if args.provider == "auto": + provider_choice: Optional[str] = explicit_provider or None + else: + provider_choice = args.provider + + # Validate that the chosen provider has its corresponding API key configured + if provider_choice == "openai" and not has_openai: + print( + "โŒ You selected --provider openai (or model prefix 'openai:') but OPENAI_API_KEY is not set.", + file=sys.stderr, + ) + return 1 + if provider_choice == "anthropic" and not has_anthropic: + print( + "โŒ You selected --provider anthropic (or model prefix 'anthropic:') but ANTHROPIC_API_KEY is not set.", + file=sys.stderr, + ) + return 1 + + # Parse data source "type:path" into its components + try: + source_type, source_path = args.data_source.split(":", 1) + except ValueError: + print( + "โŒ Invalid --data-source. Expected 'type:path' (e.g., 'hf:scikit-learn/iris' or 'local:./data.csv').", + file=sys.stderr, + ) + print( + f"Provided value: {args.data_source}. Update the flag or use one of the examples above.", + file=sys.stderr, + ) + return 2 + + # Prepare prompt variants: file-based if provided; otherwise use baked-in defaults + prompt_variants: List[str] + if args.prompts_file: + try: + with open(args.prompts_file, "r", encoding="utf-8") as f: + prompt_variants = [ + line.strip() + for line in f.read().splitlines() + if line.strip() + ] + if not prompt_variants: + print( + f"โŒ Prompts file is empty after removing blank lines: {args.prompts_file}", + file=sys.stderr, + ) + print( + "Add one prompt per line, or omit --prompts-file to use the built-in defaults.", + file=sys.stderr, + ) + return 4 + except OSError: + print( + f"โŒ Unable to read prompts file: {args.prompts_file}", + file=sys.stderr, + ) + print( + "Provide a valid path with --prompts-file or omit the flag to use the built-in prompts.", + file=sys.stderr, + ) + return 3 + else: + prompt_variants = [ + "You are a data analyst. Quickly assess data quality (0-100 score) and identify key patterns. Be concise.", + "You are a data quality specialist. Calculate quality score (0-100), find missing data, duplicates, and correlations. Provide specific recommendations.", + "You are a business analyst. Assess ML readiness with quality score. 
Focus on business impact and actionable insights.", + ] + + # Create configs passed to the pipelines + source_config = DataSourceConfig( + source_type=source_type, + source_path=source_path, + target_column=args.target_column, + sample_size=args.sample_size, + ) + + agent_config = AgentConfig( + model_name=normalized_model_name, + provider=provider_choice, # None => auto inference in AgentConfig + max_tool_calls=args.max_tool_calls, + timeout_seconds=args.timeout_seconds, + ) + + scoring_config = ScoringConfig( + weight_quality=args.weight_quality, + weight_speed=args.weight_speed, + weight_findings=args.weight_findings, + speed_penalty_per_second=args.speed_penalty_per_second, + findings_score_per_item=args.findings_score_per_item, + findings_cap=args.findings_cap, + ) + + # ZenML run options: keep parity with the original example + pipeline_options = {"enable_cache": not args.no_cache} + + # Concise transparency logs + sample_info = args.sample_size if args.sample_size is not None else "all" + print( + f"โ„น๏ธ Provider: {provider_choice or 'auto'} | Model: {normalized_model_name} | Data: {source_type}:{source_path} (target={args.target_column}, sample={sample_info})" + ) + print( + f"โ„น๏ธ Prompt variants: {len(prompt_variants)} | Tool calls: {args.max_tool_calls} | Timeout: {args.timeout_seconds}s" + ) + + # Stage 1: Prompt optimization + if optimization_pipeline: + print("๐Ÿงช Running prompt optimization...") + try: + optimization_result = prompt_optimization_pipeline.with_options( + **pipeline_options + )( + source_config=source_config, + prompt_variants=prompt_variants, + agent_config=agent_config, + scoring_config=scoring_config, + ) + print("โœ… Optimization completed") + # Pretty-print a compact scoreboard summary if available + if isinstance(optimization_result, dict): + scoreboard = optimization_result.get("scoreboard") or [] + best_prompt = optimization_result.get("best_prompt") + if isinstance(scoreboard, list) and scoreboard: + top = sorted( + scoreboard, + key=lambda x: x.get("score", 0.0), + reverse=True, + )[:3] + print("๐Ÿ“Š Scoreboard (top 3):") + for entry in top: + pid = entry.get("prompt_id", "?") + sc = entry.get("score", 0.0) + t = entry.get("execution_time", 0.0) + f = entry.get("findings_count", 0) + ok = entry.get("success", False) + mark = "โœ…" if ok else "โŒ" + print( + f"- {pid} | score: {sc:.1f} | time: {t:.1f}s | findings: {f} | {mark}" + ) + if isinstance(best_prompt, str) and best_prompt: + preview = best_prompt.replace("\n", " ")[:80] + print( + f"๐Ÿ“ Best prompt preview: {preview}{'...' if len(best_prompt) > 80 else ''}" + ) + except Exception as e: + # Best-effort behavior: log the error and continue to the next stage + print(f"โŒ Optimization failed: {e}", file=sys.stderr) + + # Stage 2: Production analysis + if production_pipeline: + print("๐Ÿญ Running production analysis...") + try: + production_eda_pipeline.with_options(**pipeline_options)( + source_config=source_config, + agent_config=agent_config, + ) + print("โœ… Production analysis completed") + except Exception as e: + print(f"โŒ Production analysis failed: {e}", file=sys.stderr) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/prompt_optimization/steps/__init__.py b/examples/prompt_optimization/steps/__init__.py new file mode 100644 index 0000000000..3f28fca444 --- /dev/null +++ b/examples/prompt_optimization/steps/__init__.py @@ -0,0 +1,18 @@ +"""ZenML steps for prompt optimization example. 
+ +Simple, focused steps that demonstrate ZenML's core capabilities: +- Data ingestion +- AI-powered analysis +- Prompt optimization and tagging +""" + +from .eda_agent import run_eda_agent +from .ingest import ingest_data +from .prompt_optimization import compare_prompts_and_tag_best, get_optimized_prompt + +__all__ = [ + "ingest_data", + "run_eda_agent", + "compare_prompts_and_tag_best", + "get_optimized_prompt", +] \ No newline at end of file diff --git a/examples/prompt_optimization/steps/agent_tools.py b/examples/prompt_optimization/steps/agent_tools.py new file mode 100644 index 0000000000..cfaa394f3e --- /dev/null +++ b/examples/prompt_optimization/steps/agent_tools.py @@ -0,0 +1,207 @@ +"""Simple Pydantic AI agent tools for EDA analysis.""" + +import time +from dataclasses import dataclass, field +from functools import wraps +from threading import Lock +from typing import Any, Callable, Dict, List, Optional + +import duckdb +import pandas as pd +from pydantic_ai import ModelRetry, RunContext + + +@dataclass +class AnalystAgentDeps: + """Simple storage for analysis results with Out[n] references.""" + + output: Dict[str, pd.DataFrame] = field(default_factory=dict) + query_history: List[Dict[str, Any]] = field(default_factory=list) + tool_calls: int = 0 + started_at: float = field(default_factory=time.monotonic) + time_budget_s: Optional[float] = None + lock: Lock = field(default_factory=Lock, repr=False, compare=False) + + def store(self, value: pd.DataFrame) -> str: + """Store the output and return reference like Out[1] for the LLM.""" + with self.lock: + ref = f"Out[{len(self.output) + 1}]" + self.output[ref] = value + return ref + + def get(self, ref: str) -> pd.DataFrame: + if ref not in self.output: + raise ModelRetry( + f"Error: {ref} is not a valid variable reference. Check the previous messages and try again." + ) + return self.output[ref] + + +def run_sql(ctx: RunContext[AnalystAgentDeps], dataset: str, sql: str) -> str: + """Run SQL query on a DataFrame using DuckDB. + + Note: Use 'dataset' as the table name in your SQL queries. + + Args: + ctx: Pydantic AI agent RunContext + dataset: reference to the DataFrame (e.g., 'Out[1]') + sql: SQL query to execute + """ + try: + data = ctx.deps.get(dataset) + result = duckdb.query_df( + df=data, virtual_table_name="dataset", sql_query=sql + ) + df = result.df() + rows = len(df) + ref = ctx.deps.store(df) + + # Log the query for tracking + ctx.deps.query_history.append( + {"sql": sql, "result_ref": ref, "rows_returned": rows} + ) + + return f"Query executed successfully. Result stored as `{ref}` ({rows} rows)." + except Exception as e: + raise ModelRetry(f"SQL query failed: {str(e)}") + + +def display( + ctx: RunContext[AnalystAgentDeps], dataset: str, rows: int = 5 +) -> str: + """Display the first few rows of a dataset. + + Args: + ctx: Pydantic AI agent RunContext + dataset: reference to the DataFrame + rows: number of rows to display (default: 5) + """ + try: + data = ctx.deps.get(dataset) + return f"Dataset {dataset} preview:\n{data.head(rows).to_string()}" + except Exception as e: + return f"Display error: {str(e)}" + + +def describe(ctx: RunContext[AnalystAgentDeps], dataset: str) -> str: + """Get statistical summary of a dataset. 
+ + Args: + ctx: Pydantic AI agent RunContext + dataset: reference to the DataFrame + """ + try: + data = ctx.deps.get(dataset) + + # Basic info + info = [ + f"Dataset {dataset} comprehensive summary:", + f"Shape: {data.shape[0]:,} rows ร— {data.shape[1]} columns", + f"Memory usage: {data.memory_usage(deep=True).sum() / 1024**2:.2f} MB", + "", + ] + + # Data types and missing info + info.append("Column Information:") + for col in data.columns: + dtype = str(data[col].dtype) + null_count = data[col].isnull().sum() + null_pct = (null_count / len(data)) * 100 + unique_count = data[col].nunique() + + info.append( + f" {col}: {dtype} | {null_count} nulls ({null_pct:.1f}%) | {unique_count} unique" + ) + + info.append("\nStatistical Summary:") + info.append(data.describe(include="all").to_string()) + + return "\n".join(info) + except Exception as e: + return f"Describe error: {str(e)}" + + +def analyze_correlations( + ctx: RunContext[AnalystAgentDeps], dataset: str +) -> str: + """Analyze correlations between numeric variables. + + Args: + ctx: Pydantic AI agent RunContext + dataset: reference to the DataFrame + """ + try: + data = ctx.deps.get(dataset) + numeric_data = data.select_dtypes(include=["number"]) + + if len(numeric_data.columns) < 2: + return "Need at least 2 numeric columns for correlation analysis." + + corr_matrix = numeric_data.corr() + + # Store correlation matrix + corr_ref = ctx.deps.store(corr_matrix) + + # Find strong correlations + strong_corrs = [] + for i, col1 in enumerate(numeric_data.columns): + for j, col2 in enumerate(numeric_data.columns[i + 1 :], i + 1): + corr_val = corr_matrix.iloc[i, j] + if abs(corr_val) > 0.7: + strong_corrs.append(f"{col1} โ†” {col2}: {corr_val:.3f}") + + result = [ + f"Correlation analysis for {len(numeric_data.columns)} numeric columns:", + f"Correlation matrix stored as {corr_ref}", + "", + ] + + if strong_corrs: + result.append("Strong correlations (|r| > 0.7):") + result.extend(f" {corr}" for corr in strong_corrs) + else: + result.append("No strong correlations (|r| > 0.7) found.") + + return "\n".join(result) + except Exception as e: + return f"Correlation analysis error: {str(e)}" + + +def budget_wrapper(max_tool_calls: Optional[int]): + """Return a decorator that enforces time/tool-call budgets for tools. + + It reads `time_budget_s`, `started_at`, and `tool_calls` from ctx.deps and + raises ModelRetry when limits are exceeded. 
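Usage sketch (this mirrors how the steps in this example register tools):

    wrapper = budget_wrapper(max_tool_calls=6)
    for tool in AGENT_TOOLS:
        agent.tool(wrapper(tool))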
+ """ + + def with_budget(tool: Callable) -> Callable: + @wraps(tool) + def _wrapped(ctx: RunContext[AnalystAgentDeps], *args, **kwargs): + # Enforce budgets atomically to be safe under parallel tool execution + with ctx.deps.lock: + tb = getattr(ctx.deps, "time_budget_s", None) + if ( + tb is not None + and (time.monotonic() - ctx.deps.started_at) > tb + ): + raise ModelRetry("Time budget exceeded.") + + if ( + max_tool_calls is not None + and ctx.deps.tool_calls >= max_tool_calls + ): + raise ModelRetry("Tool-call budget exceeded.") + + # Increment tool call count after passing checks + ctx.deps.tool_calls += 1 + + # Execute the actual tool logic outside the lock + return tool(ctx, *args, **kwargs) + + return _wrapped + + return with_budget + + +# Enhanced tool registry +AGENT_TOOLS = [run_sql, display, describe, analyze_correlations] diff --git a/examples/prompt_optimization/steps/eda_agent.py b/examples/prompt_optimization/steps/eda_agent.py new file mode 100644 index 0000000000..f25b7ceea7 --- /dev/null +++ b/examples/prompt_optimization/steps/eda_agent.py @@ -0,0 +1,142 @@ +"""Simple EDA agent step using Pydantic AI.""" + +from typing import Annotated, Any, Dict, List, Tuple + +import pandas as pd +from pydantic_ai import Agent +from pydantic_ai.settings import ModelSettings + +from zenml import step +from zenml.logger import get_logger +from zenml.types import MarkdownString + +logger = get_logger(__name__) + + +# Logfire for observability +try: + import logfire + + LOGFIRE_AVAILABLE = True +except ImportError: + LOGFIRE_AVAILABLE = False + +from models import AgentConfig, EDAReport +from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps, budget_wrapper +from steps.prompt_text import DEFAULT_SYSTEM_PROMPT, build_user_prompt + + +@step +def run_eda_agent( + dataset_df: pd.DataFrame, + dataset_metadata: Dict[str, Any], + agent_config: AgentConfig = None, + custom_system_prompt: str = None, +) -> Tuple[ + Annotated[MarkdownString, "eda_report_markdown"], + Annotated[Dict[str, Any], "eda_report_json"], + Annotated[List[Dict[str, Any]], "sql_execution_log"], + Annotated[Dict[str, pd.DataFrame], "analysis_tables"], +]: + """Run Pydantic AI agent for EDA analysis with optional custom prompt. 
+ + Args: + dataset_df: Dataset to analyze + dataset_metadata: Metadata about the dataset + agent_config: Configuration for the AI agent + custom_system_prompt: Optional custom system prompt (overrides default) + + Returns: + Tuple of EDA outputs: markdown report, JSON report, SQL log, analysis tables + """ + if agent_config is None: + agent_config = AgentConfig() + + # Configure Logfire for observability + if LOGFIRE_AVAILABLE: + try: + logfire.configure() + logfire.instrument_pydantic_ai() + logfire.info("EDA agent starting", dataset_shape=dataset_df.shape) + except Exception as e: + print(f"Warning: Failed to configure Logfire: {e}") + + # Initialize agent dependencies with time budget and store the dataset + deps = AnalystAgentDeps( + time_budget_s=float(agent_config.timeout_seconds) + if agent_config + else None + ) + main_ref = deps.store(dataset_df) + + # Create the EDA analyst agent with system prompt (custom or default) + if custom_system_prompt: + system_prompt = custom_system_prompt + logger.info("๐ŸŽฏ Using custom optimized system prompt for analysis") + else: + system_prompt = DEFAULT_SYSTEM_PROMPT + logger.info("๐Ÿ“ Using default system prompt for analysis") + + # Provider:model id computed once for both agent creation and metadata + provider_model = agent_config.model_id() + + analyst_agent = Agent( + provider_model, + deps_type=AnalystAgentDeps, + output_type=EDAReport, + output_retries=3, # Allow more retries for result validation + system_prompt=system_prompt, + model_settings=ModelSettings( + parallel_tool_calls=True, + ), + ) + + # Use shared budget wrapper to enforce time/tool-call limits for all tools + wrapper = budget_wrapper(getattr(agent_config, "max_tool_calls", None)) + + # Register tools with budget enforcement + for tool in AGENT_TOOLS: + analyst_agent.tool(wrapper(tool)) + + # Run focused analysis using shared user prompt builder + user_prompt = build_user_prompt(main_ref, dataset_df) + + try: + result = analyst_agent.run_sync(user_prompt, deps=deps) + eda_report = result.output + except Exception as e: + eda_report = EDAReport( + headline=f"Analysis failed for dataset with {dataset_df.shape[0]} rows", + key_findings=[ + f"Dataset contains {len(dataset_df)} rows and {len(dataset_df.columns)} columns.", + "The AI agent failed to generate a report.", + ], + data_quality_score=0.0, + markdown=( + f"# EDA Report Failed\n\n" + f"Analysis failed with error: {str(e)}\n\n" + f"Dataset shape: {dataset_df.shape}" + ), + ) + + # Return results + return ( + MarkdownString(eda_report.markdown), + { + "headline": eda_report.headline, + "key_findings": eda_report.key_findings, + "data_quality_score": eda_report.data_quality_score, + "agent_metadata": { + "model": agent_config.model_name, + "provider_model": provider_model, + "tool_calls": deps.tool_calls, + "sql_queries": len(deps.query_history), + }, + }, + deps.query_history, + { + ref: df + for ref, df in deps.output.items() + if ref != main_ref and len(df) <= 1000 + }, + ) diff --git a/examples/prompt_optimization/steps/ingest.py b/examples/prompt_optimization/steps/ingest.py new file mode 100644 index 0000000000..867f73f894 --- /dev/null +++ b/examples/prompt_optimization/steps/ingest.py @@ -0,0 +1,102 @@ +"""Simple data ingestion step for prompt optimization example.""" + +from typing import Annotated, Any, Dict, Tuple + +import pandas as pd +from models import DataSourceConfig + +from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@step +def ingest_data( + source_config: 
DataSourceConfig, +) -> Tuple[ + Annotated[pd.DataFrame, "dataset"], + Annotated[Dict[str, Any], "ingestion_metadata"], +]: + """Simple data ingestion from HuggingFace or local files. + + Args: + source_config: Configuration specifying data source + + Returns: + Tuple of (dataframe, metadata) + """ + logger.info( + f"Loading data from {source_config.source_type}:{source_config.source_path}" + ) + + # Load data based on source type + if source_config.source_type == "hf": + df = _load_from_huggingface(source_config) + elif source_config.source_type == "local": + df = _load_from_local(source_config) + else: + raise ValueError( + f"Unsupported source type: {source_config.source_type}" + ) + + # Apply sampling if configured + total = len(df) + if source_config.sample_size and total > source_config.sample_size: + df = df.sample(n=source_config.sample_size, random_state=42) + logger.info( + f"Sampled {source_config.sample_size} rows from {total} total" + ) + + # Generate simple metadata + metadata = { + "source_type": source_config.source_type, + "source_path": source_config.source_path, + "rows": len(df), + "columns": len(df.columns), + "column_names": df.columns.tolist(), + "target_column": source_config.target_column, + } + + logger.info(f"Loaded dataset: {len(df)} rows ร— {len(df.columns)} columns") + return df, metadata + + +def _load_from_huggingface(config: DataSourceConfig) -> pd.DataFrame: + """Load dataset from HuggingFace Hub.""" + try: + from datasets import load_dataset + + # Simple dataset loading + dataset = load_dataset(config.source_path, split="train") + df = dataset.to_pandas() + + logger.info(f"Loaded HuggingFace dataset: {config.source_path}") + return df + + except ImportError: + raise ImportError( + "datasets library required. Install with: pip install datasets" + ) + except Exception as e: + raise RuntimeError(f"Failed to load dataset {config.source_path}: {e}") + + +def _load_from_local(config: DataSourceConfig) -> pd.DataFrame: + """Load dataset from local file.""" + try: + file_path = config.source_path + + if file_path.endswith(".csv"): + df = pd.read_csv(file_path) + elif file_path.endswith(".json"): + df = pd.read_json(file_path) + else: + # Try CSV as fallback + df = pd.read_csv(file_path) + + logger.info(f"Loaded local file: {file_path}") + return df + + except Exception as e: + raise RuntimeError(f"Failed to load file {config.source_path}: {e}") diff --git a/examples/prompt_optimization/steps/prompt_optimization.py b/examples/prompt_optimization/steps/prompt_optimization.py new file mode 100644 index 0000000000..ef71469abc --- /dev/null +++ b/examples/prompt_optimization/steps/prompt_optimization.py @@ -0,0 +1,228 @@ +"""Simple prompt optimization step for demonstrating ZenML artifact management.""" + +import time +from typing import Annotated, List, Optional, Tuple + +import pandas as pd +from models import AgentConfig, EDAReport, ScoringConfig, VariantScore +from pydantic_ai import Agent +from pydantic_ai.settings import ModelSettings +from steps.agent_tools import AGENT_TOOLS, AnalystAgentDeps, budget_wrapper +from steps.prompt_text import DEFAULT_SYSTEM_PROMPT, build_user_prompt + +from zenml import ArtifactConfig, Tag, add_tags, step +from zenml.client import Client +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +# NOTE: Quality and speed are both on a 0โ€“100 scale. Findings accrue raw points +# (capped via `findings_cap`) before weights are applied. 
After weight normalization, +# the aggregate score remains roughly bounded within 0โ€“100. +def compute_prompt_score( + eda_report: EDAReport, execution_time: float, scoring: ScoringConfig +) -> float: + wq, ws, wf = scoring.normalized_weights + speed_penalty = min( + execution_time * scoring.speed_penalty_per_second, 100.0 + ) + speed_score = max(0.0, 100.0 - speed_penalty) + credited_findings = min(len(eda_report.key_findings), scoring.findings_cap) + findings_score = credited_findings * scoring.findings_score_per_item + return ( + eda_report.data_quality_score * wq + + speed_score * ws + + findings_score * wf + ) + + +@step +def compare_prompts_and_tag_best( + dataset_df: pd.DataFrame, + prompt_variants: List[str], + agent_config: AgentConfig | None = None, + scoring_config: Optional[ScoringConfig] = None, +) -> Tuple[ + Annotated[str, ArtifactConfig(name="best_prompt")], + Annotated[List[VariantScore], ArtifactConfig(name="prompt_scoreboard")], +]: + """Compare prompt variants, compute scores, and emit best prompt + scoreboard. + + Behavior: + - Provider/model inference owned by AgentConfig.model_id() + - Uniform time/tool-call budget enforcement via tool wrappers + - Score each variant using ScoringConfig; record success/failure with timing + - Tag best prompt with exclusive 'optimized' tag only if at least one success + + Args: + dataset_df: Dataset to test prompts against + prompt_variants: List of system prompts to compare + agent_config: Configuration for AI agents (defaults applied if None) + scoring_config: Scoring configuration (defaults applied if None) + + Returns: + Tuple of: + - Best performing prompt string (artifact name 'best_prompt') + - Scoreboard entries for all variants (artifact name 'prompt_scoreboard') + """ + if agent_config is None: + agent_config = AgentConfig() + if scoring_config is None: + scoring_config = ScoringConfig() + + logger.info(f"๐Ÿงช Testing {len(prompt_variants)} prompt variants") + + # Compute provider:model id once for consistent use across variants + provider_model = agent_config.model_id() + + # Prepare a shared wrapper that enforces budgets for all tools + wrapper = budget_wrapper(getattr(agent_config, "max_tool_calls", None)) + + # Collect VariantScore entries for all variants (success and failure) + scoreboard: List[VariantScore] = [] + + for i, system_prompt in enumerate(prompt_variants): + prompt_id = f"variant_{i + 1}" + logger.info(f"Testing {prompt_id}...") + + start_time = time.time() + + try: + # Create agent with this prompt and per-variant deps with time budget + deps = AnalystAgentDeps( + time_budget_s=float(agent_config.timeout_seconds) + if agent_config + else None + ) + deps.tool_calls = ( + 0 # Ensure tool-call counter is reset per variant + ) + main_ref = deps.store(dataset_df) + + agent = Agent( + provider_model, + deps_type=AnalystAgentDeps, + output_type=EDAReport, + system_prompt=system_prompt, + model_settings=ModelSettings(parallel_tool_calls=False), + ) + + # Register tools with shared budget enforcement + for tool in AGENT_TOOLS: + agent.tool(wrapper(tool)) + + # Run analysis + user_prompt = build_user_prompt(main_ref, dataset_df) + result = agent.run_sync(user_prompt, deps=deps) + eda_report = result.output + + execution_time = time.time() - start_time + + # Score this variant using provided scoring configuration + score = compute_prompt_score( + eda_report, execution_time, scoring_config + ) + + scoreboard.append( + VariantScore( + prompt_id=prompt_id, + prompt=system_prompt, + score=score, + 
quality_score=eda_report.data_quality_score, + execution_time=execution_time, + findings_count=len(eda_report.key_findings), + success=True, + error=None, + ) + ) + + logger.info( + f"โœ… {prompt_id}: score={score:.1f}, time={execution_time:.1f}s" + ) + + except Exception as e: + execution_time = time.time() - start_time + logger.warning(f"โŒ {prompt_id} failed: {e}") + scoreboard.append( + VariantScore( + prompt_id=prompt_id, + prompt=system_prompt, + score=0.0, + quality_score=0.0, + execution_time=execution_time, + findings_count=0, + success=False, + error=str(e), + ) + ) + + # Determine best performer among successful variants + successful_results = [entry for entry in scoreboard if entry.success] + + if not successful_results: + logger.warning( + "All prompts failed, falling back to DEFAULT_SYSTEM_PROMPT" + ) + best_prompt = DEFAULT_SYSTEM_PROMPT + logger.info("โญ๏ธ Skipping best prompt tagging since all variants failed") + else: + best_result = max(successful_results, key=lambda x: x.score) + best_prompt = best_result.prompt + logger.info( + f"๐Ÿ† Best prompt: {best_result.prompt_id} (score: {best_result.score:.1f})" + ) + + logger.info( + "๐Ÿ’พ Best prompt will be stored with exclusive 'optimized' tag" + ) + # Explicitly target the named output artifact to avoid multi-output ambiguity + add_tags( + tags=[Tag(name="optimized", exclusive=True)], + output_name="best_prompt", + ) + + return best_prompt, scoreboard + + +def get_optimized_prompt() -> Tuple[str, bool]: + """Retrieve the optimized prompt from a tagged artifact, with safe fallback. + + This demonstrates ZenML's tag filtering by finding artifacts tagged with 'optimized'. + Since 'optimized' is an exclusive tag, there will be at most one such artifact. + + Returns: + Tuple[str, bool]: (prompt, from_artifact) where: + - prompt: The optimized prompt text if found; DEFAULT_SYSTEM_PROMPT otherwise. + - from_artifact: True if the prompt was retrieved from an artifact; False on fallback. + """ + try: + client = Client() + + # Find artifacts tagged with 'optimized' (our exclusive tag) + artifacts = client.list_artifact_versions(tags=["optimized"], size=1) + + if artifacts.items: + optimized_artifact = artifacts.items[0] + prompt_value = optimized_artifact.load() + logger.info( + f"๐ŸŽฏ Retrieved optimized prompt from artifact: {optimized_artifact.id}" + ) + logger.info(f" Artifact created: {optimized_artifact.created}") + return prompt_value, True + else: + logger.info( + "๐Ÿ” No optimized prompt found (no artifacts with 'optimized' tag). " + "Falling back to default (from_artifact=False)." + ) + + except Exception as e: + logger.warning( + f"Failed to retrieve optimized prompt: {e}. Falling back to DEFAULT_SYSTEM_PROMPT (from_artifact=False)." + ) + + # Fallback to default system prompt if lookup fails + logger.info( + "๐Ÿ“ Using default system prompt (run optimization pipeline first). from_artifact=False" + ) + return DEFAULT_SYSTEM_PROMPT, False diff --git a/examples/prompt_optimization/steps/prompt_text.py b/examples/prompt_optimization/steps/prompt_text.py new file mode 100644 index 0000000000..6e80eb2054 --- /dev/null +++ b/examples/prompt_optimization/steps/prompt_text.py @@ -0,0 +1,46 @@ +"""Shared prompt text for the prompt optimization example. + +This module centralizes the default system prompt and the user prompt +builder so both steps can reuse a single source of truth. 
+""" + +from __future__ import annotations + +import pandas as pd + +# Canonical default system prompt used by the EDA agent when no custom prompt +# is provided. Kept identical to the original in run_eda_agent to preserve behavior. +DEFAULT_SYSTEM_PROMPT: str = """You are a data analyst. Perform quick but insightful EDA. + +FOCUS ON: +- Data quality score (0-100) based on missing data and duplicates +- Key patterns and distributions +- Notable correlations or anomalies +- 2-3 actionable recommendations + +Be concise but specific with numbers. Aim for quality insights, not exhaustive analysis.""" + + +def build_user_prompt(main_ref: str, df: pd.DataFrame) -> str: + """Build the user prompt for the EDA agent. + + Produces the same structured instructions used previously, including dataset + shape metadata and the numbered action steps that guide the agent to use tools. + + Args: + main_ref: Reference key for the stored dataset (e.g., 'Out[1]'). + df: The pandas DataFrame being analyzed, used to derive rows and columns. + + Returns: + A formatted user prompt string guiding the agent through quick EDA steps. + """ + rows, cols = df.shape + return ( + f"Quick EDA analysis for dataset '{main_ref}' ({rows} rows, {cols} cols).\n\n" + "STEPS (keep it fast):\n" + f"1. display('{main_ref}') - check data structure \n" + f"2. describe('{main_ref}') - get key stats\n" + f"3. run_sql('{main_ref}', 'SELECT COUNT(*) as total FROM dataset') - check row count\n" + f"4. If multiple numeric columns: analyze_correlations('{main_ref}')\n\n" + "Generate EDAReport with data quality score and 2-3 key insights." + )