Skip to content

swamy18/prediction-guard--Lightweight-ML-inference-drift-failure-middleware

Folders and files

NameName
Last commit message
Last commit date

Latest commit

Β 

History

18 Commits
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 
Β 

Repository files navigation

Prediction Guard

A lightweight middleware for ML model failure detection and rollback.

This is a decision system, not a monitoring dashboard.

Python 3.8+ License: MIT Tests


🎯 What is Prediction Guard?

Prediction Guard is a thin middleware layer that sits in front of your ML inference endpoint and:

  1. Logs statistically useful prediction telemetry
  2. Analyzes logs for drift and failure signals
  3. Decides on model health with explicit reasoning
  4. Acts on decisions (rollback) with safeguards

Key Insight: Monitoring tells you something is wrong. Prediction Guard tells you what to do about it.


🧠 Core Philosophy

Principle What it means
Decision-first Every analysis leads to an explicit decision with reasons
Multi-signal required Drift alone is NOT enough to trigger rollback
Privacy-safe Never log raw user dataβ€”only hashes and summaries
Human-in-the-loop Auto-rollback is off by default; thresholds are manually tunable
Explainable Every decision includes reasons a non-ML engineer can understand
Minimal Only 2 dependencies: numpy and scipy

πŸ“¦ Installation

# Clone the repository
git clone https://github.com/swamy18/prediction-guard.git
cd prediction-guard

# Install in development mode
pip install -e .

# Or install dependencies only
pip install -r requirements.txt

Dependencies

numpy>=1.21.0
scipy>=1.7.0

That's it. No Kafka. No Redis. No heavy infrastructure.


πŸš€ Quick Start

1. Initialize Configuration

prediction-guard init

This creates prediction_guard_config.json with sensible defaults.

2. Integrate with Your Inference Endpoint

from prediction_guard.middleware import PredictionInterceptor
from prediction_guard.types import GuardConfig

# Configure
config = GuardConfig(
    current_model_version="v2.0",
    fallback_model_version="v1.9",
    log_directory="./logs",
)

# Create interceptor
interceptor = PredictionInterceptor(config)

# In your prediction endpoint
def predict(input_data):
    with interceptor.intercept(input_data, {"region": "us-east"}) as ctx:
        result = your_model.predict(input_data)
        ctx.set_result(
            prediction=result.prediction,
            confidence=result.confidence,
            probabilities=result.probabilities,
            embedding=result.embedding,
        )
    return result

3. Create a Baseline (from historical data)

prediction-guard baseline create --model v2.0 --days 7

4. Run Analysis and Get Decision

prediction-guard decide --model v2.0

Output:

=== Model Health Decision ===
Model Version: v2.0
State: HEALTHY
Confidence: 95%
Recommended Action: none

Reasons:
  - No issues detected

πŸ—οΈ Architecture

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                           YOUR APPLICATION                               β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”         β”‚
β”‚  β”‚   Request   β”‚ ──▢│  Interceptor     β”‚ ──▢│   ML Model      β”‚         β”‚
β”‚  β”‚             β”‚    β”‚  (logs telemetry)β”‚    β”‚   Prediction    β”‚         β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜    β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜         β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                               β”‚
                               β”‚ Append-only writes
                               β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                        JSONL LOG FILES                                   β”‚
β”‚                                                                          β”‚
β”‚  predictions_2024-01-15.jsonl                                           β”‚
β”‚  predictions_2024-01-16.jsonl                                           β”‚
β”‚  ...                                                                    β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                               β”‚
                               β”‚ Scheduled / Manual trigger
                               β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                      OFFLINE ANALYZER                                    β”‚
β”‚                                                                          β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”  β”‚
β”‚  β”‚  Drift Detectors                                                   β”‚  β”‚
β”‚  β”‚  β€’ Feature Drift (Kolmogorov-Smirnov test)                        β”‚  β”‚
β”‚  β”‚  β€’ Embedding Drift (Cosine distance from baseline centroid)       β”‚  β”‚
β”‚  β”‚  β€’ Prediction Drift (Population Stability Index)                  β”‚  β”‚
β”‚  β”‚  β€’ Confidence Entropy (Shannon entropy change)                    β”‚  β”‚
β”‚  β”‚  β€’ Latency Drift (P50/P99 percentile changes)                    β”‚  β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜  β”‚
β”‚                                                                          β”‚
β”‚  Compares current window against stored baseline                        β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                               β”‚
                               β”‚ AnalysisResult
                               β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                      DECISION ENGINE                                     β”‚
β”‚                                                                          β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”  β”‚
β”‚  β”‚  Multi-Signal Logic                                                β”‚  β”‚
β”‚  β”‚                                                                    β”‚  β”‚
β”‚  β”‚  if drift_signals >= 3:                                           β”‚  β”‚
β”‚  β”‚      state = UNSTABLE, action = ROLLBACK                          β”‚  β”‚
β”‚  β”‚                                                                    β”‚  β”‚
β”‚  β”‚  if drift_signals == 2 AND (embedding + confidence):              β”‚  β”‚
β”‚  β”‚      state = UNSTABLE, action = ROLLBACK                          β”‚  β”‚
β”‚  β”‚                                                                    β”‚  β”‚
β”‚  β”‚  if drift_signals == 1:                                           β”‚  β”‚
β”‚  β”‚      state = SUSPICIOUS, action = ALERT                           β”‚  β”‚
β”‚  β”‚                                                                    β”‚  β”‚
β”‚  β”‚  if business_proxy_healthy:                                       β”‚  β”‚
β”‚  β”‚      OVERRIDE drift signals β†’ HEALTHY                              β”‚  β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜  β”‚
β”‚                                                                          β”‚
β”‚  Output: HealthDecision with state, reasons, recommended_action         β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                               β”‚
                               β”‚ If action = ROLLBACK
                               β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                      ACTION EXECUTOR                                     β”‚
β”‚                                                                          β”‚
β”‚  Rollback Mechanisms:                                                   β”‚
β”‚  β€’ Config file update (prediction_guard_config.json)                    β”‚
β”‚  β€’ Environment variable (MODEL_VERSION)                                 β”‚
β”‚  β€’ Model alias file (model_alias.json)                                  β”‚
β”‚  β€’ Feature flag file (feature_flags.json)                               β”‚
β”‚                                                                          β”‚
β”‚  Safeguards:                                                            β”‚
β”‚  βœ“ Auto-rollback OFF by default                                        β”‚
β”‚  βœ“ Cooldown period (30 min default)                                    β”‚
β”‚  βœ“ All actions logged for audit                                        β”‚
β”‚  βœ“ Revert capability                                                   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

πŸ“Š Telemetry Events

Every prediction logs one structured event:

{
  "timestamp": "2024-01-15T10:30:00.123456",
  "model_version": "v2.0",
  "request_id": "550e8400-e29b-41d4-a716-446655440000",
  "input_hash": "a3f2b8c9d4e5f6...",
  "embedding_summary": [0.12, 0.34, 0.56, ...],
  "prediction": "positive",
  "confidence_score": 0.92,
  "prediction_entropy": 0.28,
  "latency_ms": 45.2,
  "request_context": {
    "region": "us-east-1",
    "user_type": "premium"
  }
}

Privacy Guarantees

Field Privacy Treatment
input_hash SHA256 hash of inputβ€”raw data NEVER stored
embedding_summary Mean/centroid onlyβ€”no individual embeddings
request_context Optional metadataβ€”you control what's included

πŸ“ˆ Drift Detection Methods

1. Feature Drift (Kolmogorov-Smirnov Test)

Compares the distribution of a feature (e.g., confidence scores) between current window and baseline.

from prediction_guard.analysis import DriftDetector

# Returns (ks_statistic, p_value)
stat, pvalue = DriftDetector.ks_test(current_values, baseline_values)

# Interpretation:
# stat > 0.15 AND pvalue < 0.05 β†’ Significant drift

When it fires: Input data distribution has shifted (e.g., new user demographics)

2. Embedding Drift (Cosine Distance)

Measures how far the current embedding centroid has moved from baseline.

distance = DriftDetector.cosine_distance(current_centroid, baseline_centroid)

# Interpretation:
# 0.0 = identical direction
# 1.0 = orthogonal
# 2.0 = opposite direction

When it fires: The semantic content of inputs has changed (e.g., new topics)

3. Prediction Drift (Population Stability Index)

Measures shift in prediction class distribution.

psi = DriftDetector.population_stability_index(current_dist, baseline_dist)

# Interpretation:
# PSI < 0.1  β†’ No significant change
# 0.1-0.25  β†’ Moderate change, investigate
# PSI > 0.25 β†’ Significant change, action needed

When it fires: Model is producing different class ratios than expected

4. Confidence Entropy

Measures change in prediction uncertainty.

change = DriftDetector.entropy_change(current_entropies, baseline_mean)

# Interpretation:
# Positive = more uncertainty (model less confident)
# Negative = less uncertainty (could be overconfident)

When it fires: Model is becoming more/less certain about predictions

5. Latency Drift

Detects performance regression.

p50_change, p99_change = DriftDetector.latency_drift(
    current_latencies, baseline_p50, baseline_p99
)

# Interpretation:
# Positive = slower (regression)
# Negative = faster (unlikely to be bad)

When it fires: Infrastructure or model performance has degraded


🎯 Decision Logic

Health States

State Meaning Typical Action
HEALTHY Model performing as expected None
SUSPICIOUS Some drift detected, not conclusive Alert, investigate
UNSTABLE Clear degradation, action needed Rollback

Decision Rules

The decision engine uses multi-signal logic. This is critical: drift alone is NOT enough.

# Pseudo-code for decision logic

if business_proxy_score >= 0.9:
    # Business is fine, ignore drift signals
    return HEALTHY

if business_proxy_score < 0.1:
    # Business is suffering, even without drift
    return UNSTABLE + ROLLBACK

drift_count = count_breached_thresholds()

if drift_count >= 3:
    # Strong evidence: multiple independent signals
    return UNSTABLE + ROLLBACK

if drift_count == 2:
    if has_embedding_drift AND has_confidence_drift:
        # Particularly concerning combination
        return UNSTABLE + ROLLBACK
    else:
        # Investigate but don't act yet
        return SUSPICIOUS + ALERT

if drift_count == 1:
    # Could be noise or early warning
    return SUSPICIOUS + ALERT

# No signals
return HEALTHY

Why Multi-Signal?

Scenario Single-Signal Response Multi-Signal Response
Random noise in one metric ❌ False alarm rollback βœ… Ignore (HEALTHY)
Seasonal traffic change ❌ Unnecessary rollback βœ… Alert only (SUSPICIOUS)
Actual model degradation βœ… Correct rollback βœ… Correct rollback

βš™οΈ Configuration Reference

from prediction_guard.types import GuardConfig, RollbackMechanism

config = GuardConfig(
    # === Drift Thresholds ===
    # Tune these based on your model's sensitivity
    feature_drift_threshold=0.15,       # KS statistic threshold
    embedding_drift_threshold=0.20,     # Cosine distance threshold
    prediction_drift_threshold=0.10,    # PSI threshold
    confidence_entropy_threshold=0.25,  # Relative entropy change
    latency_p99_threshold_ms=100.0,     # Absolute P99 threshold
    
    # === Analysis Windows ===
    analysis_window_minutes=60,         # How much recent data to analyze
    baseline_window_days=7,             # How much data for baseline
    min_samples_for_analysis=100,       # Minimum events for valid analysis
    
    # === Rollback Settings ===
    auto_rollback_enabled=False,        # CRITICAL: Off by default
    rollback_cooldown_minutes=30,       # Minimum time between rollbacks
    rollback_mechanism=RollbackMechanism.CONFIG_FILE,
    
    # === Paths ===
    log_directory="./logs",
    baseline_directory="./baselines",
    incident_directory="./incidents",
    
    # === Model Versions ===
    current_model_version="v2.0",
    fallback_model_version="v1.9",
    
    # === Business Proxy (Optional) ===
    business_proxy_enabled=False,
    business_proxy_threshold=0.10,
    business_proxy_overrides_drift=True,  # Business trumps drift
)

Configuration File (JSON)

{
  "feature_drift_threshold": 0.15,
  "embedding_drift_threshold": 0.20,
  "prediction_drift_threshold": 0.10,
  "confidence_entropy_threshold": 0.25,
  "latency_p99_threshold_ms": 100.0,
  "analysis_window_minutes": 60,
  "baseline_window_days": 7,
  "min_samples_for_analysis": 100,
  "auto_rollback_enabled": false,
  "rollback_cooldown_minutes": 30,
  "rollback_mechanism": "config_file",
  "log_directory": "./logs",
  "baseline_directory": "./baselines",
  "incident_directory": "./incidents",
  "current_model_version": "v2.0",
  "fallback_model_version": "v1.9"
}

Environment Variable

export PREDICTION_GUARD_CONFIG=/path/to/config.json

πŸ–₯️ CLI Reference

Initialize

# Create default configuration file
prediction-guard init

Analyze

# Run drift analysis
prediction-guard analyze --model v2.0 --window 60

# Output as JSON
prediction-guard analyze --model v2.0 --json

Decide

# Run analysis and make decision
prediction-guard decide --model v2.0

# With business proxy score
prediction-guard decide --model v2.0 --business-score 0.95

# JSON output
prediction-guard decide --json

Run Full Pipeline

# Analyze, decide, (optionally) act
prediction-guard run --model v2.0

# Actually execute rollback if recommended
prediction-guard run --model v2.0 --execute

Baseline Management

# Create baseline from last 7 days of data
prediction-guard baseline create --model v2.0 --days 7

# List available baselines
prediction-guard baseline list

# Show baseline details
prediction-guard baseline show --model v2.0

# Delete baseline
prediction-guard baseline delete --model v2.0

Status

# Show system status
prediction-guard status

Output:

{
  "current_model_version": "v2.0",
  "fallback_model_version": "v1.9",
  "auto_rollback_enabled": false,
  "has_baseline": true,
  "available_baselines": ["v1.9", "v2.0"],
  "recent_incidents": [],
  "cooldown_active": false,
  "cooldown_remaining_seconds": 0.0
}

Incidents

# List recent incidents
prediction-guard incidents --limit 10

# Filter by model
prediction-guard incidents --model v2.0

🐍 Python API Reference

PredictionGuard (Main Orchestrator)

from prediction_guard import PredictionGuard

guard = PredictionGuard()  # Loads config from file

# Run analysis
analysis = guard.analyze(model_version="v2.0")

# Make decision
decision = guard.decide(analysis, business_proxy_score=0.95)

# Or do both at once
decision = guard.analyze_and_decide(model_version="v2.0")

# Execute action
if decision.recommended_action == ActionType.ROLLBACK:
    action = guard.execute_action(decision, force=False)

# Full pipeline
result = guard.run_pipeline(
    model_version="v2.0",
    auto_execute=False,  # Don't auto-execute
)

# Get system status
status = guard.get_status()

PredictionInterceptor (Middleware)

from prediction_guard.middleware import PredictionInterceptor

interceptor = PredictionInterceptor(config)

# Context manager style (recommended)
with interceptor.intercept(input_data, {"region": "us-east"}) as ctx:
    result = model.predict(input_data)
    ctx.set_result(
        prediction=result.prediction,
        confidence=result.confidence,
        probabilities=result.probabilities,
        embedding=result.embedding,
    )

# Direct logging style
request_id = interceptor.log_prediction(
    input_data=input_data,
    prediction="positive",
    confidence=0.92,
    probabilities=[0.92, 0.08],
    embedding=[0.1, 0.2, 0.3],
    request_context={"region": "us-east"},
    latency_ms=45.2,
)

# Always close when done
interceptor.close()

HealthDecision (Output)

decision = guard.analyze_and_decide()

print(decision.model_version)     # "v2.0"
print(decision.state)             # ModelHealthState.UNSTABLE
print(decision.reasons)           # ["embedding_drift_high", "confidence_entropy_spike"]
print(decision.recommended_action) # ActionType.ROLLBACK
print(decision.confidence)        # 0.85
print(decision.analysis_summary)  # {"feature_drift_score": 0.12, ...}

# Serialize
data = decision.to_dict()

πŸ”„ Rollback Mechanisms

1. Config File (Default)

Updates prediction_guard_config.json:

{
  "current_model_version": "v1.9",
  "_rollback_at": "2024-01-15T10:30:00",
  "_rollback_from": "v2.0"
}

2. Environment Variable

Sets:

MODEL_VERSION=v1.9
MODEL_ROLLBACK_AT=2024-01-15T10:30:00

3. Model Alias File

Creates/updates model_alias.json:

{
  "current_alias": "v1.9",
  "previous_alias": "v2.0",
  "switched_at": "2024-01-15T10:30:00"
}

4. Feature Flag File

Creates/updates feature_flags.json:

{
  "active_model_version": "v1.9",
  "model_rollback_active": true,
  "rollback_at": "2024-01-15T10:30:00"
}

Custom Rollback Handler

For custom integrations (e.g., Kubernetes, service mesh):

# Extend ActionExecutor with custom handler
from prediction_guard.action import ActionExecutor

class CustomExecutor(ActionExecutor):
    def _rollback_custom(self, action):
        # Your custom rollback logic
        # e.g., update Kubernetes ConfigMap
        # e.g., call service mesh API
        pass

πŸ“š Learning Loop

After each incident, Prediction Guard saves a snapshot for post-mortem analysis:

from prediction_guard.incident import IncidentManager

manager = IncidentManager(config)

# Record an incident (automatic when state != HEALTHY)
incident_id = manager.record_incident(decision, analysis, action)

# Add human notes after investigation
manager.add_resolution_notes(
    incident_id=incident_id,
    notes="False positive. Traffic spike from marketing campaign.",
    threshold_adjustments={
        "feature_drift_threshold": 0.20,  # Should be higher
    }
)

# Get aggregated threshold recommendations
recommendations = manager.get_threshold_recommendations(model_version="v2.0")
# {"feature_drift_threshold": 0.18, ...}

Incident Snapshot Structure

{
  "incident_id": "550e8400-e29b-41d4-a716-446655440000",
  "model_version": "v2.0",
  "detected_at": "2024-01-15T10:30:00",
  "decision": {
    "state": "unstable",
    "reasons": ["embedding_drift_high", "confidence_entropy_spike"],
    "recommended_action": "rollback"
  },
  "analysis": {
    "sample_count": 1523,
    "feature_drift_score": 0.12,
    "embedding_drift_score": 0.45,
    "...": "..."
  },
  "action_taken": {
    "success": true,
    "mechanism": "config_file",
    "from_version": "v2.0",
    "to_version": "v1.9"
  },
  "resolution_notes": "Investigating root cause...",
  "threshold_adjustments": {}
}

Important: No auto-learning in v1. Human-in-the-loop tuning only.


πŸ§ͺ Testing

# Run all tests
pytest tests/ -v

# Run specific test file
pytest tests/test_decision_engine.py -v

# With coverage
pytest tests/ --cov=prediction_guard --cov-report=html

Test Structure

tests/
β”œβ”€β”€ test_drift_detector.py    # Statistical test verification
β”œβ”€β”€ test_decision_engine.py   # Decision logic validation
└── test_helpers.py           # Utility function tests

πŸ“ File Structure

prediction_guard/
β”œβ”€β”€ __init__.py                 # Package init, version info
β”œβ”€β”€ types.py                    # All types: enums, dataclasses
β”‚   β”œβ”€β”€ ModelHealthState        # HEALTHY, SUSPICIOUS, UNSTABLE
β”‚   β”œβ”€β”€ DriftType               # FEATURE, EMBEDDING, PREDICTION, etc.
β”‚   β”œβ”€β”€ ActionType              # NONE, ALERT, ROLLBACK
β”‚   β”œβ”€β”€ PredictionEvent         # Single prediction telemetry
β”‚   β”œβ”€β”€ DriftMetric             # Single drift measurement
β”‚   β”œβ”€β”€ AnalysisResult          # Complete analysis output
β”‚   β”œβ”€β”€ HealthDecision          # Decision with reasons
β”‚   β”œβ”€β”€ RollbackAction          # Executed rollback record
β”‚   └── GuardConfig             # All configuration options
β”œβ”€β”€ config.py                   # Load/save configuration
β”œβ”€β”€ guard.py                    # Main PredictionGuard orchestrator
β”œβ”€β”€ incident.py                 # Incident snapshots for learning loop
β”œβ”€β”€ cli.py                      # Command-line interface
β”‚
β”œβ”€β”€ logging/
β”‚   β”œβ”€β”€ __init__.py
β”‚   β”œβ”€β”€ telemetry_logger.py     # Append-only JSONL logging
β”‚   β”‚   └── TelemetryLogger     # Thread-safe, buffered writes
β”‚   └── log_reader.py           # Time-windowed log reading
β”‚       └── LogReader           # Memory-efficient streaming
β”‚
β”œβ”€β”€ analysis/
β”‚   β”œβ”€β”€ __init__.py
β”‚   β”œβ”€β”€ drift_detector.py       # Statistical tests
β”‚   β”‚   └── DriftDetector       # KS, PSI, cosine, entropy
β”‚   β”œβ”€β”€ baseline_manager.py     # Baseline storage
β”‚   β”‚   └── BaselineManager     # Save/load/compute baselines
β”‚   └── analyzer.py             # Orchestrates analysis
β”‚       └── OfflineAnalyzer     # Reads logs, computes all metrics
β”‚
β”œβ”€β”€ decision/
β”‚   β”œβ”€β”€ __init__.py
β”‚   └── engine.py               # Decision logic
β”‚       └── DecisionEngine      # Multi-signal evaluation
β”‚
β”œβ”€β”€ action/
β”‚   β”œβ”€β”€ __init__.py
β”‚   β”œβ”€β”€ executor.py             # Rollback execution
β”‚   β”‚   └── ActionExecutor      # Multiple mechanisms, logging
β”‚   └── cooldown.py             # Cooldown management
β”‚       └── CooldownManager     # Prevent rollback storms
β”‚
└── middleware/
    β”œβ”€β”€ __init__.py
    β”œβ”€β”€ interceptor.py          # FastAPI-compatible middleware
    β”‚   └── PredictionInterceptor
    └── helpers.py              # Utilities
        β”œβ”€β”€ compute_input_hash()
        β”œβ”€β”€ compute_embedding_summary()
        └── compute_entropy()

🚫 Non-Goals (Explicitly NOT Built)

Not Building Why
Dashboards Use Grafana/Datadog for visualization
Real-time streaming Adds complexity without proportional value
Auto-threshold tuning Requires more data and can be dangerous
Perfect thresholds No such thingβ€”tune based on your domain
Deep learning models Overkill for drift detection
Replace observability Complement, don't replace

Prediction Guard does ONE thing: Detect model failure and decide when to roll back.


πŸ”§ Production Deployment

Recommended Architecture

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚   Inference API     β”‚     β”‚   Cron Job          β”‚
β”‚   (FastAPI/Flask)   β”‚     β”‚   (every 15 min)    β”‚
β”‚                     β”‚     β”‚                     β”‚
β”‚   + Interceptor     β”‚     β”‚   prediction-guard  β”‚
β”‚     (logs events)   β”‚     β”‚   run --model v2.0  β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
           β”‚                           β”‚
           β”‚ writes                    β”‚ reads
           β–Ό                           β–Ό
    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
    β”‚          Shared Filesystem               β”‚
    β”‚          (or S3/GCS bucket)              β”‚
    β”‚                                          β”‚
    β”‚   logs/predictions_2024-01-15.jsonl      β”‚
    β”‚   baselines/baseline_v2.0.json           β”‚
    β”‚   incidents/incident_*.json              β”‚
    β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Kubernetes CronJob Example

apiVersion: batch/v1
kind: CronJob
metadata:
  name: prediction-guard-analysis
spec:
  schedule: "*/15 * * * *"  # Every 15 minutes
  jobTemplate:
    spec:
      template:
        spec:
          containers:
          - name: guard
            image: your-registry/prediction-guard:latest
            command:
            - prediction-guard
            - run
            - --model
            - v2.0
            - --execute  # Only if auto_rollback_enabled
            volumeMounts:
            - name: logs
              mountPath: /app/logs
          volumes:
          - name: logs
            persistentVolumeClaim:
              claimName: prediction-logs
          restartPolicy: OnFailure

Alerting Integration

# After running pipeline
result = guard.run_pipeline()

if result["decision"]["state"] in ["suspicious", "unstable"]:
    # Send to your alerting system
    send_to_pagerduty(
        severity="critical" if result["decision"]["state"] == "unstable" else "warning",
        summary=f"Model {result['model_version']} is {result['decision']['state']}",
        details=result["decision"],
    )

🀝 Contributing

  1. Fork the repository
  2. Create a feature branch: git checkout -b feature/my-feature
  3. Make changes and add tests
  4. Run tests: pytest tests/ -v
  5. Submit a pull request

πŸ“„ License

MIT License. See LICENSE for details.


πŸ™ Acknowledgments

Built with the philosophy that MLOps should be about decisions, not dashboards.

Inspired by real-world ML incidents where monitoring showed the problem but didn't tell anyone what to do about it.


Prediction Guard β€” The smallest system that actually decides and acts.

About

Lightweight middleware to detect silent ML model failures, prediction drift, and unstable inference in production

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages