A lightweight middleware for ML model failure detection and rollback.
This is a decision system, not a monitoring dashboard.
Prediction Guard is a thin middleware layer that sits in front of your ML inference endpoint and:
- Logs statistically useful prediction telemetry
- Analyzes logs for drift and failure signals
- Decides on model health with explicit reasoning
- Acts on decisions (rollback) with safeguards
Key Insight: Monitoring tells you something is wrong. Prediction Guard tells you what to do about it.
| Principle | What it means |
|---|---|
| Decision-first | Every analysis leads to an explicit decision with reasons |
| Multi-signal required | Drift alone is NOT enough to trigger rollback |
| Privacy-safe | Never log raw user data; only hashes and summaries |
| Human-in-the-loop | Auto-rollback is off by default; thresholds are manually tunable |
| Explainable | Every decision includes reasons a non-ML engineer can understand |
| Minimal | Only 2 dependencies: numpy and scipy |
```bash
# Clone the repository
git clone https://github.com/swamy18/prediction-guard.git
cd prediction-guard

# Install in development mode
pip install -e .

# Or install dependencies only
pip install -r requirements.txt
```

`requirements.txt` contains only:

```text
numpy>=1.21.0
scipy>=1.7.0
```
That's it. No Kafka. No Redis. No heavy infrastructure.
```bash
prediction-guard init
```

This creates `prediction_guard_config.json` with sensible defaults.
```python
from prediction_guard.middleware import PredictionInterceptor
from prediction_guard.types import GuardConfig

# Configure
config = GuardConfig(
    current_model_version="v2.0",
    fallback_model_version="v1.9",
    log_directory="./logs",
)

# Create interceptor
interceptor = PredictionInterceptor(config)

# In your prediction endpoint
def predict(input_data):
    with interceptor.intercept(input_data, {"region": "us-east"}) as ctx:
        result = your_model.predict(input_data)
        ctx.set_result(
            prediction=result.prediction,
            confidence=result.confidence,
            probabilities=result.probabilities,
            embedding=result.embedding,
        )
    return result
```

```bash
prediction-guard baseline create --model v2.0 --days 7
prediction-guard decide --model v2.0
```

Output:
```text
=== Model Health Decision ===
Model Version: v2.0
State: HEALTHY
Confidence: 95%
Recommended Action: none
Reasons:
  - No issues detected
```
```text
┌──────────────────────────────────────────────────────────────────┐
│                         YOUR APPLICATION                         │
│  ┌───────────┐      ┌──────────────────┐      ┌──────────────┐   │
│  │  Request  │ ───▶ │   Interceptor    │ ───▶ │   ML Model   │   │
│  │           │      │ (logs telemetry) │      │  Prediction  │   │
│  └───────────┘      └────────┬─────────┘      └──────────────┘   │
└──────────────────────────────┼───────────────────────────────────┘
                               │ Append-only writes
                               ▼
┌──────────────────────────────────────────────────────────────────┐
│                         JSONL LOG FILES                          │
│                                                                  │
│   predictions_2024-01-15.jsonl                                   │
│   predictions_2024-01-16.jsonl                                   │
│   ...                                                            │
└──────────────────────────────┬───────────────────────────────────┘
                               │ Scheduled / manual trigger
                               ▼
┌──────────────────────────────────────────────────────────────────┐
│                         OFFLINE ANALYZER                         │
│                                                                  │
│  Drift detectors:                                                │
│   • Feature drift (Kolmogorov-Smirnov test)                      │
│   • Embedding drift (cosine distance from baseline centroid)     │
│   • Prediction drift (Population Stability Index)                │
│   • Confidence entropy (Shannon entropy change)                  │
│   • Latency drift (P50/P99 percentile changes)                   │
│                                                                  │
│  Compares the current window against the stored baseline         │
└──────────────────────────────┬───────────────────────────────────┘
                               │ AnalysisResult
                               ▼
┌──────────────────────────────────────────────────────────────────┐
│                         DECISION ENGINE                          │
│                                                                  │
│  Multi-signal logic:                                             │
│   if drift_signals >= 3:  UNSTABLE + ROLLBACK                    │
│   if drift_signals == 2 and (embedding + confidence):            │
│       UNSTABLE + ROLLBACK                                        │
│   if drift_signals == 1:  SUSPICIOUS + ALERT                     │
│   if business_proxy_healthy:  override drift → HEALTHY           │
│                                                                  │
│  Output: HealthDecision with state, reasons, recommended_action  │
└──────────────────────────────┬───────────────────────────────────┘
                               │ If action = ROLLBACK
                               ▼
┌──────────────────────────────────────────────────────────────────┐
│                         ACTION EXECUTOR                          │
│                                                                  │
│  Rollback mechanisms:                                            │
│   • Config file update (prediction_guard_config.json)            │
│   • Environment variable (MODEL_VERSION)                         │
│   • Model alias file (model_alias.json)                          │
│   • Feature flag file (feature_flags.json)                       │
│                                                                  │
│  Safeguards:                                                     │
│   ✓ Auto-rollback OFF by default                                 │
│   ✓ Cooldown period (30 min default)                             │
│   ✓ All actions logged for audit                                 │
│   ✓ Revert capability                                            │
└──────────────────────────────────────────────────────────────────┘
```
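Because the log layer is plain daily JSONL files, no special tooling is needed to consume them. A time-windowed read can be a few lines of standard-library Python (this `read_window` helper is an illustration, not the library's `LogReader` API):

```python
import json
from datetime import datetime, timedelta
from pathlib import Path

def read_window(log_dir, minutes=60, now=None):
    """Yield events newer than `now - minutes` from daily JSONL log files."""
    now = now or datetime.now()
    cutoff = now - timedelta(minutes=minutes)
    for path in sorted(Path(log_dir).glob("predictions_*.jsonl")):
        with open(path) as f:
            for line in f:
                event = json.loads(line)
                # Timestamps are ISO-8601 strings, as in the event schema below
                if datetime.fromisoformat(event["timestamp"]) >= cutoff:
                    yield event
```

Streaming file-by-file keeps memory flat no matter how large the log directory grows.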
Every prediction logs one structured event:
```json
{
  "timestamp": "2024-01-15T10:30:00.123456",
  "model_version": "v2.0",
  "request_id": "550e8400-e29b-41d4-a716-446655440000",
  "input_hash": "a3f2b8c9d4e5f6...",
  "embedding_summary": [0.12, 0.34, 0.56, ...],
  "prediction": "positive",
  "confidence_score": 0.92,
  "prediction_entropy": 0.28,
  "latency_ms": 45.2,
  "request_context": {
    "region": "us-east-1",
    "user_type": "premium"
  }
}
```

| Field | Privacy Treatment |
|---|---|
| `input_hash` | SHA-256 hash of the input; raw data is NEVER stored |
| `embedding_summary` | Mean/centroid only; no individual embeddings |
| `request_context` | Optional metadata; you control what's included |
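As a sketch of what the privacy treatment amounts to in practice (the helper names `hash_input` and `summarize_embeddings` are illustrative, not the library's `compute_input_hash`/`compute_embedding_summary` API):

```python
import hashlib
import json

def hash_input(input_data):
    """SHA-256 over a canonical JSON encoding; the raw input is never stored."""
    canonical = json.dumps(input_data, sort_keys=True, default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

def summarize_embeddings(embeddings):
    """Element-wise mean (centroid) of a batch of embedding vectors."""
    n, dim = len(embeddings), len(embeddings[0])
    return [sum(vec[i] for vec in embeddings) / n for i in range(dim)]

print(len(hash_input({"text": "great product!"})))       # 64 hex characters
print(summarize_embeddings([[1.0, 3.0], [3.0, 5.0]]))    # [2.0, 4.0]
```

Sorting keys before hashing makes the digest stable across dict orderings, so identical inputs always map to the same hash.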
Compares the distribution of a feature (e.g., confidence scores) between the current window and the baseline.

```python
from prediction_guard.analysis import DriftDetector

# Returns (ks_statistic, p_value)
stat, pvalue = DriftDetector.ks_test(current_values, baseline_values)

# Interpretation:
# stat > 0.15 AND pvalue < 0.05 → significant drift
```

**When it fires:** Input data distribution has shifted (e.g., new user demographics).
Measures how far the current embedding centroid has moved from the baseline.

```python
distance = DriftDetector.cosine_distance(current_centroid, baseline_centroid)

# Interpretation:
# 0.0 = identical direction
# 1.0 = orthogonal
# 2.0 = opposite direction
```

**When it fires:** The semantic content of inputs has changed (e.g., new topics).
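Cosine distance itself reduces to a few lines of numpy; a sketch consistent with the interpretation above:

```python
import numpy as np

def cosine_distance(a, b):
    """1 - cosine similarity between two centroid vectors."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine_distance([1, 0], [1, 0]))   # 0.0  (identical direction)
print(cosine_distance([1, 0], [0, 1]))   # 1.0  (orthogonal)
print(cosine_distance([1, 0], [-1, 0]))  # 2.0  (opposite direction)
```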
Measures shift in the prediction class distribution.

```python
psi = DriftDetector.population_stability_index(current_dist, baseline_dist)

# Interpretation:
# PSI < 0.1   → No significant change
# 0.1 - 0.25  → Moderate change, investigate
# PSI > 0.25  → Significant change, action needed
```

**When it fires:** Model is producing different class ratios than expected.
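PSI is a simple sum over binned proportions; a reference implementation, assuming both inputs are already normalized class distributions (the `eps` clipping guards against empty classes):

```python
import numpy as np

def population_stability_index(current_dist, baseline_dist, eps=1e-6):
    """PSI = sum((cur - base) * ln(cur / base)) over classes/bins."""
    cur = np.clip(np.asarray(current_dist, dtype=float), eps, None)
    base = np.clip(np.asarray(baseline_dist, dtype=float), eps, None)
    return float(np.sum((cur - base) * np.log(cur / base)))

print(population_stability_index([0.5, 0.5], [0.5, 0.5]))  # 0.0: no change
print(population_stability_index([0.5, 0.5], [0.7, 0.3]))  # ~0.17: moderate, investigate
```

A 70/30 baseline drifting to 50/50 lands squarely in the "moderate change" band of the interpretation table above.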
Measures change in prediction uncertainty.

```python
change = DriftDetector.entropy_change(current_entropies, baseline_mean)

# Interpretation:
# Positive = more uncertainty (model less confident)
# Negative = less uncertainty (could be overconfident)
```

**When it fires:** Model is becoming more/less certain about predictions.
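The per-prediction quantity being averaged here is the Shannon entropy of the predicted probability vector; a sketch of that building block:

```python
import numpy as np

def prediction_entropy(probabilities):
    """Shannon entropy in nats; higher = less confident."""
    p = np.clip(np.asarray(probabilities, dtype=float), 1e-12, 1.0)
    return float(-np.sum(p * np.log(p)))

confident = prediction_entropy([0.98, 0.02])  # low entropy (~0.10)
unsure = prediction_entropy([0.5, 0.5])       # maximum for 2 classes (ln 2 ≈ 0.693)
print(confident < unsure)  # True
```

A rising window average means the model is hedging more than it did at baseline; a falling one can mean it has become overconfident.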
Detects performance regression.

```python
p50_change, p99_change = DriftDetector.latency_drift(
    current_latencies, baseline_p50, baseline_p99
)

# Interpretation:
# Positive = slower (regression)
# Negative = faster (unlikely to be bad)
```

**When it fires:** Infrastructure or model performance has degraded.
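Percentile drift is plain numpy arithmetic; a sketch that reports change relative to the stored baseline percentiles (returning fractional change is an assumption of this sketch, not necessarily the library's convention):

```python
import numpy as np

def latency_drift(current_latencies, baseline_p50, baseline_p99):
    """Relative change of current P50/P99 vs. baseline (positive = slower)."""
    p50 = float(np.percentile(current_latencies, 50))
    p99 = float(np.percentile(current_latencies, 99))
    return (p50 - baseline_p50) / baseline_p50, (p99 - baseline_p99) / baseline_p99

p50_change, p99_change = latency_drift(
    [40.0, 50.0, 60.0], baseline_p50=25.0, baseline_p99=80.0
)
print(p50_change)  # 1.0 -> current P50 is double the baseline: a regression
```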
| State | Meaning | Typical Action |
|---|---|---|
| `HEALTHY` | Model performing as expected | None |
| `SUSPICIOUS` | Some drift detected, not conclusive | Alert, investigate |
| `UNSTABLE` | Clear degradation, action needed | Rollback |
The decision engine uses multi-signal logic. This is critical: drift alone is NOT enough.
```python
# Pseudo-code for the decision logic
if business_proxy_score >= 0.9:
    # Business is fine, ignore drift signals
    return HEALTHY

if business_proxy_score < 0.1:
    # Business is suffering, even without drift
    return UNSTABLE + ROLLBACK

drift_count = count_breached_thresholds()

if drift_count >= 3:
    # Strong evidence: multiple independent signals
    return UNSTABLE + ROLLBACK

if drift_count == 2:
    if has_embedding_drift and has_confidence_drift:
        # Particularly concerning combination
        return UNSTABLE + ROLLBACK
    else:
        # Investigate but don't act yet
        return SUSPICIOUS + ALERT

if drift_count == 1:
    # Could be noise or early warning
    return SUSPICIOUS + ALERT

# No signals
return HEALTHY
```

| Scenario | Single-Signal Response | Multi-Signal Response |
|---|---|---|
| Random noise in one metric | ❌ False alarm rollback | ✅ Ignore (HEALTHY) |
| Seasonal traffic change | ❌ Unnecessary rollback | ✅ Alert only (SUSPICIOUS) |
| Actual model degradation | ✅ Correct rollback | ✅ Correct rollback |
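The pseudo-code above is easy to make concrete. This standalone sketch (the state names follow the document; the `decide` function itself is illustrative, not the library's `DecisionEngine`) shows why one noisy signal can never trigger a rollback:

```python
from enum import Enum

class State(Enum):
    HEALTHY = "healthy"
    SUSPICIOUS = "suspicious"
    UNSTABLE = "unstable"

def decide(signals, business_proxy_score=None):
    """Multi-signal decision: returns (state, recommended_action)."""
    if business_proxy_score is not None:
        if business_proxy_score >= 0.9:
            return State.HEALTHY, "none"       # business fine: overrides drift
        if business_proxy_score < 0.1:
            return State.UNSTABLE, "rollback"  # business suffering
    n = len(signals)
    if n >= 3:
        return State.UNSTABLE, "rollback"      # multiple independent signals
    if n == 2 and {"embedding", "confidence"} <= signals:
        return State.UNSTABLE, "rollback"      # particularly concerning pair
    if n >= 1:
        return State.SUSPICIOUS, "alert"       # could be noise or early warning
    return State.HEALTHY, "none"

print(decide({"latency"}))                        # single signal -> alert only
print(decide({"embedding", "confidence"}))        # -> rollback
print(decide({"embedding", "confidence"}, 0.95))  # healthy business overrides drift
```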
```python
from prediction_guard.types import GuardConfig, RollbackMechanism

config = GuardConfig(
    # === Drift Thresholds ===
    # Tune these based on your model's sensitivity
    feature_drift_threshold=0.15,       # KS statistic threshold
    embedding_drift_threshold=0.20,     # Cosine distance threshold
    prediction_drift_threshold=0.10,    # PSI threshold
    confidence_entropy_threshold=0.25,  # Relative entropy change
    latency_p99_threshold_ms=100.0,     # Absolute P99 threshold

    # === Analysis Windows ===
    analysis_window_minutes=60,         # How much recent data to analyze
    baseline_window_days=7,             # How much data for the baseline
    min_samples_for_analysis=100,       # Minimum events for a valid analysis

    # === Rollback Settings ===
    auto_rollback_enabled=False,        # CRITICAL: off by default
    rollback_cooldown_minutes=30,       # Minimum time between rollbacks
    rollback_mechanism=RollbackMechanism.CONFIG_FILE,

    # === Paths ===
    log_directory="./logs",
    baseline_directory="./baselines",
    incident_directory="./incidents",

    # === Model Versions ===
    current_model_version="v2.0",
    fallback_model_version="v1.9",

    # === Business Proxy (Optional) ===
    business_proxy_enabled=False,
    business_proxy_threshold=0.10,
    business_proxy_overrides_drift=True,  # Business trumps drift
)
```

The same settings are persisted in `prediction_guard_config.json`:

```json
{
  "feature_drift_threshold": 0.15,
  "embedding_drift_threshold": 0.20,
  "prediction_drift_threshold": 0.10,
  "confidence_entropy_threshold": 0.25,
  "latency_p99_threshold_ms": 100.0,
  "analysis_window_minutes": 60,
  "baseline_window_days": 7,
  "min_samples_for_analysis": 100,
  "auto_rollback_enabled": false,
  "rollback_cooldown_minutes": 30,
  "rollback_mechanism": "config_file",
  "log_directory": "./logs",
  "baseline_directory": "./baselines",
  "incident_directory": "./incidents",
  "current_model_version": "v2.0",
  "fallback_model_version": "v1.9"
}
```

Point the tools at a custom config location with:

```bash
export PREDICTION_GUARD_CONFIG=/path/to/config.json
```

```bash
# Create default configuration file
prediction-guard init
```

```bash
# Run drift analysis
prediction-guard analyze --model v2.0 --window 60

# Output as JSON
prediction-guard analyze --model v2.0 --json
```

```bash
# Run analysis and make decision
prediction-guard decide --model v2.0

# With business proxy score
prediction-guard decide --model v2.0 --business-score 0.95

# JSON output
prediction-guard decide --json
```

```bash
# Analyze, decide, (optionally) act
prediction-guard run --model v2.0

# Actually execute rollback if recommended
prediction-guard run --model v2.0 --execute
```

```bash
# Create baseline from last 7 days of data
prediction-guard baseline create --model v2.0 --days 7

# List available baselines
prediction-guard baseline list

# Show baseline details
prediction-guard baseline show --model v2.0

# Delete baseline
prediction-guard baseline delete --model v2.0
```

```bash
# Show system status
prediction-guard status
```

Output:

```json
{
  "current_model_version": "v2.0",
  "fallback_model_version": "v1.9",
  "auto_rollback_enabled": false,
  "has_baseline": true,
  "available_baselines": ["v1.9", "v2.0"],
  "recent_incidents": [],
  "cooldown_active": false,
  "cooldown_remaining_seconds": 0.0
}
```

```bash
# List recent incidents
prediction-guard incidents --limit 10

# Filter by model
prediction-guard incidents --model v2.0
```

The same operations are available from Python:

```python
from prediction_guard import PredictionGuard
from prediction_guard.types import ActionType

guard = PredictionGuard()  # Loads config from file

# Run analysis
analysis = guard.analyze(model_version="v2.0")

# Make decision
decision = guard.decide(analysis, business_proxy_score=0.95)

# Or do both at once
decision = guard.analyze_and_decide(model_version="v2.0")

# Execute action
if decision.recommended_action == ActionType.ROLLBACK:
    action = guard.execute_action(decision, force=False)

# Full pipeline
result = guard.run_pipeline(
    model_version="v2.0",
    auto_execute=False,  # Don't auto-execute
)

# Get system status
status = guard.get_status()
```

The interceptor API:

```python
from prediction_guard.middleware import PredictionInterceptor

interceptor = PredictionInterceptor(config)

# Context manager style (recommended)
with interceptor.intercept(input_data, {"region": "us-east"}) as ctx:
    result = model.predict(input_data)
    ctx.set_result(
        prediction=result.prediction,
        confidence=result.confidence,
        probabilities=result.probabilities,
        embedding=result.embedding,
    )

# Direct logging style
request_id = interceptor.log_prediction(
    input_data=input_data,
    prediction="positive",
    confidence=0.92,
    probabilities=[0.92, 0.08],
    embedding=[0.1, 0.2, 0.3],
    request_context={"region": "us-east"},
    latency_ms=45.2,
)

# Always close when done
interceptor.close()
```

Working with the returned `HealthDecision`:

```python
decision = guard.analyze_and_decide()

print(decision.model_version)       # "v2.0"
print(decision.state)               # ModelHealthState.UNSTABLE
print(decision.reasons)             # ["embedding_drift_high", "confidence_entropy_spike"]
print(decision.recommended_action)  # ActionType.ROLLBACK
print(decision.confidence)          # 0.85
print(decision.analysis_summary)    # {"feature_drift_score": 0.12, ...}

# Serialize
data = decision.to_dict()
```

The config-file mechanism updates `prediction_guard_config.json`:
```json
{
  "current_model_version": "v1.9",
  "_rollback_at": "2024-01-15T10:30:00",
  "_rollback_from": "v2.0"
}
```

The environment-variable mechanism sets:

```bash
MODEL_VERSION=v1.9
MODEL_ROLLBACK_AT=2024-01-15T10:30:00
```

The model-alias mechanism creates/updates `model_alias.json`:

```json
{
  "current_alias": "v1.9",
  "previous_alias": "v2.0",
  "switched_at": "2024-01-15T10:30:00"
}
```

The feature-flag mechanism creates/updates `feature_flags.json`:

```json
{
  "active_model_version": "v1.9",
  "model_rollback_active": true,
  "rollback_at": "2024-01-15T10:30:00"
}
```

For custom integrations (e.g., Kubernetes, service mesh):
```python
# Extend ActionExecutor with a custom handler
from prediction_guard.action import ActionExecutor

class CustomExecutor(ActionExecutor):
    def _rollback_custom(self, action):
        # Your custom rollback logic,
        # e.g., update a Kubernetes ConfigMap
        # or call a service mesh API
        pass
```

After each incident, Prediction Guard saves a snapshot for post-mortem analysis:
```python
from prediction_guard.incident import IncidentManager

manager = IncidentManager(config)

# Record an incident (automatic when state != HEALTHY)
incident_id = manager.record_incident(decision, analysis, action)

# Add human notes after investigation
manager.add_resolution_notes(
    incident_id=incident_id,
    notes="False positive. Traffic spike from marketing campaign.",
    threshold_adjustments={
        "feature_drift_threshold": 0.20,  # Should be higher
    },
)

# Get aggregated threshold recommendations
recommendations = manager.get_threshold_recommendations(model_version="v2.0")
# {"feature_drift_threshold": 0.18, ...}
```

An incident snapshot looks like:

```json
{
  "incident_id": "550e8400-e29b-41d4-a716-446655440000",
  "model_version": "v2.0",
  "detected_at": "2024-01-15T10:30:00",
  "decision": {
    "state": "unstable",
    "reasons": ["embedding_drift_high", "confidence_entropy_spike"],
    "recommended_action": "rollback"
  },
  "analysis": {
    "sample_count": 1523,
    "feature_drift_score": 0.12,
    "embedding_drift_score": 0.45,
    "...": "..."
  },
  "action_taken": {
    "success": true,
    "mechanism": "config_file",
    "from_version": "v2.0",
    "to_version": "v1.9"
  },
  "resolution_notes": "Investigating root cause...",
  "threshold_adjustments": {}
}
```

**Important:** No auto-learning in v1. Human-in-the-loop tuning only.
```bash
# Run all tests
pytest tests/ -v

# Run specific test file
pytest tests/test_decision_engine.py -v

# With coverage
pytest tests/ --cov=prediction_guard --cov-report=html
```

```text
tests/
├── test_drift_detector.py    # Statistical test verification
├── test_decision_engine.py   # Decision logic validation
└── test_helpers.py           # Utility function tests
```
```text
prediction_guard/
├── __init__.py               # Package init, version info
├── types.py                  # All types: enums, dataclasses
│   ├── ModelHealthState      # HEALTHY, SUSPICIOUS, UNSTABLE
│   ├── DriftType             # FEATURE, EMBEDDING, PREDICTION, etc.
│   ├── ActionType            # NONE, ALERT, ROLLBACK
│   ├── PredictionEvent       # Single prediction telemetry
│   ├── DriftMetric           # Single drift measurement
│   ├── AnalysisResult        # Complete analysis output
│   ├── HealthDecision        # Decision with reasons
│   ├── RollbackAction        # Executed rollback record
│   └── GuardConfig           # All configuration options
├── config.py                 # Load/save configuration
├── guard.py                  # Main PredictionGuard orchestrator
├── incident.py               # Incident snapshots for learning loop
├── cli.py                    # Command-line interface
│
├── logging/
│   ├── __init__.py
│   ├── telemetry_logger.py   # Append-only JSONL logging
│   │   └── TelemetryLogger   # Thread-safe, buffered writes
│   └── log_reader.py         # Time-windowed log reading
│       └── LogReader         # Memory-efficient streaming
│
├── analysis/
│   ├── __init__.py
│   ├── drift_detector.py     # Statistical tests
│   │   └── DriftDetector     # KS, PSI, cosine, entropy
│   ├── baseline_manager.py   # Baseline storage
│   │   └── BaselineManager   # Save/load/compute baselines
│   └── analyzer.py           # Orchestrates analysis
│       └── OfflineAnalyzer   # Reads logs, computes all metrics
│
├── decision/
│   ├── __init__.py
│   └── engine.py             # Decision logic
│       └── DecisionEngine    # Multi-signal evaluation
│
├── action/
│   ├── __init__.py
│   ├── executor.py           # Rollback execution
│   │   └── ActionExecutor    # Multiple mechanisms, logging
│   └── cooldown.py           # Cooldown management
│       └── CooldownManager   # Prevent rollback storms
│
└── middleware/
    ├── __init__.py
    ├── interceptor.py        # FastAPI-compatible middleware
    │   └── PredictionInterceptor
    └── helpers.py            # Utilities
        ├── compute_input_hash()
        ├── compute_embedding_summary()
        └── compute_entropy()
```
| Not Building | Why |
|---|---|
| Dashboards | Use Grafana/Datadog for visualization |
| Real-time streaming | Adds complexity without proportional value |
| Auto-threshold tuning | Requires more data and can be dangerous |
| Perfect thresholds | No such thing; tune based on your domain |
| Deep learning models | Overkill for drift detection |
| Replace observability | Complement, don't replace |
Prediction Guard does ONE thing: Detect model failure and decide when to roll back.
```text
┌─────────────────────┐      ┌─────────────────────┐
│    Inference API    │      │      Cron Job       │
│   (FastAPI/Flask)   │      │   (every 15 min)    │
│                     │      │                     │
│   + Interceptor     │      │  prediction-guard   │
│   (logs events)     │      │  run --model v2.0   │
└──────────┬──────────┘      └──────────┬──────────┘
           │                            │
           │ writes                     │ reads
           ▼                            ▼
┌──────────────────────────────────────────────────┐
│                Shared Filesystem                 │
│                (or S3/GCS bucket)                │
│                                                  │
│   logs/predictions_2024-01-15.jsonl              │
│   baselines/baseline_v2.0.json                   │
│   incidents/incident_*.json                      │
└──────────────────────────────────────────────────┘
```
```yaml
apiVersion: batch/v1
kind: CronJob
metadata:
  name: prediction-guard-analysis
spec:
  schedule: "*/15 * * * *"  # Every 15 minutes
  jobTemplate:
    spec:
      template:
        spec:
          containers:
            - name: guard
              image: your-registry/prediction-guard:latest
              command:
                - prediction-guard
                - run
                - --model
                - v2.0
                - --execute  # Only if auto_rollback_enabled
              volumeMounts:
                - name: logs
                  mountPath: /app/logs
          volumes:
            - name: logs
              persistentVolumeClaim:
                claimName: prediction-logs
          restartPolicy: OnFailure
```

```python
# After running the pipeline
result = guard.run_pipeline()

if result["decision"]["state"] in ["suspicious", "unstable"]:
    # Send to your alerting system
    send_to_pagerduty(
        severity="critical" if result["decision"]["state"] == "unstable" else "warning",
        summary=f"Model {result['model_version']} is {result['decision']['state']}",
        details=result["decision"],
    )
```

- Fork the repository
- Create a feature branch: `git checkout -b feature/my-feature`
- Make changes and add tests
- Run tests: `pytest tests/ -v`
- Submit a pull request
MIT License. See LICENSE for details.
Built with the philosophy that MLOps should be about decisions, not dashboards.
Inspired by real-world ML incidents where monitoring showed the problem but didn't tell anyone what to do about it.
Prediction Guard: the smallest system that actually decides and acts.