Skip to content

Commit 1908ab2

Browse files
author
dev
committed
feat: OLS forecast + rolling z-score anomaly detection
- transform/forecast.py: dependency-free per-country OLS projection with widening 95% prediction band, plus rolling-baseline z-score anomalies (shift-1 lookback so spikes do not contaminate their own score) - orchestration: new build_forecast_and_anomalies task wired into the flow after dbt build, writing mart.mart_emissions_forecast and mart.mart_emissions_anomalies - migration 004: forecast + anomalies tables with PK + indexes - dashboard page 05_forecast.py: actuals vs forecast with shaded band and anomaly markers, plus side-by-side data tables - dbt exposures: declare streamlit_dashboard and emissions_forecast_model as downstream consumers so dbt docs lineage is complete - tests: 5 new unit tests covering linear extrapolation, short history guard, multi-country fan-out, spike detection, and steady-series quiet
1 parent 3afe260 commit 1908ab2

7 files changed

Lines changed: 392 additions & 0 deletions

File tree

dashboard/pages/05_forecast.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Forecast + anomaly page: actuals, OLS projection, and flagged outliers."""
2+
3+
from __future__ import annotations
4+
5+
import plotly.graph_objects as go
6+
import streamlit as st
7+
8+
from dashboard.components.filters import sidebar_filters
9+
from dashboard.utils.db import query
10+
11+
st.title("Forecast & Anomalies")
12+
sidebar_filters()
13+
14+
countries = query("select distinct country_code from mart.mart_country_emissions order by 1")
15+
if countries.empty:
16+
st.warning("No data — run `make seed` first.")
17+
st.stop()
18+
19+
country = st.selectbox("Country", countries["country_code"].tolist())
20+
21+
actuals = query(
22+
"select year, total_emissions_tonnes from mart.mart_country_emissions "
23+
"where country_code = %s order by year",
24+
(country,),
25+
)
26+
forecast = query(
27+
"select year, forecast_tonnes, lower_band, upper_band from mart.mart_emissions_forecast "
28+
"where country_code = %s order by year",
29+
(country,),
30+
)
31+
anomalies = query(
32+
"select year, total_emissions_tonnes, z_score, severity "
33+
"from mart.mart_emissions_anomalies where country_code = %s order by year",
34+
(country,),
35+
)
36+
37+
fig = go.Figure()
38+
fig.add_trace(
39+
go.Scatter(
40+
x=actuals["year"],
41+
y=actuals["total_emissions_tonnes"],
42+
mode="lines+markers",
43+
name="actual",
44+
)
45+
)
46+
if not forecast.empty:
47+
fig.add_trace(
48+
go.Scatter(
49+
x=forecast["year"],
50+
y=forecast["forecast_tonnes"],
51+
mode="lines+markers",
52+
name="forecast",
53+
line={"dash": "dash"},
54+
)
55+
)
56+
fig.add_trace(
57+
go.Scatter(
58+
x=list(forecast["year"]) + list(forecast["year"][::-1]),
59+
y=list(forecast["upper_band"]) + list(forecast["lower_band"][::-1]),
60+
fill="toself",
61+
fillcolor="rgba(255,140,0,0.15)",
62+
line={"width": 0},
63+
name="95% band",
64+
showlegend=True,
65+
)
66+
)
67+
if not anomalies.empty:
68+
fig.add_trace(
69+
go.Scatter(
70+
x=anomalies["year"],
71+
y=anomalies["total_emissions_tonnes"],
72+
mode="markers",
73+
marker={"size": 14, "color": "red", "symbol": "x"},
74+
name="anomaly",
75+
)
76+
)
77+
fig.update_layout(
78+
title=f"{country}: actuals, forecast, and anomalies",
79+
xaxis_title="year",
80+
yaxis_title="tonnes CO₂e",
81+
height=550,
82+
)
83+
st.plotly_chart(fig, use_container_width=True)
84+
85+
c1, c2 = st.columns(2)
86+
with c1:
87+
st.subheader("Forecast")
88+
st.dataframe(forecast, use_container_width=True)
89+
with c2:
90+
st.subheader("Anomalies")
91+
st.dataframe(anomalies, use_container_width=True)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
version: 2
2+
3+
exposures:
4+
- name: streamlit_dashboard
5+
type: dashboard
6+
maturity: high
7+
url: http://localhost:8501
8+
description: "EU ETS multi-page Streamlit dashboard."
9+
depends_on:
10+
- ref('mart_country_emissions')
11+
- ref('mart_sector_trends')
12+
- ref('mart_top_emitters')
13+
- ref('mart_compliance_gap')
14+
owner:
15+
name: data-platform
16+
email: data@example.com
17+
18+
- name: emissions_forecast_model
19+
type: ml
20+
maturity: medium
21+
description: "OLS-based per-country forecast + rolling z-score anomaly detection. Built by orchestration.tasks.build_forecast_and_anomalies and stored in mart.mart_emissions_forecast / mart.mart_emissions_anomalies."
22+
depends_on:
23+
- ref('mart_country_emissions')
24+
owner:
25+
name: data-platform
26+
email: data@example.com

orchestration/flows.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from orchestration.alerts import send_alert
1313
from orchestration.tasks import (
14+
build_forecast_and_anomalies,
1415
clean_table,
1516
download_source,
1617
load_country_codes,
@@ -42,6 +43,10 @@ def energy_pipeline() -> dict[str, int]:
4243

4344
run_dbt("deps")
4445
run_dbt("build --target dev")
46+
47+
forecast_rows, anomaly_rows = build_forecast_and_anomalies()
48+
loaded["forecast"] = forecast_rows
49+
loaded["anomalies"] = anomaly_rows
4550
return loaded
4651
except Exception as exc:
4752
send_alert(

orchestration/tasks.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ingest import loaders, sources
1313
from ingest.schemas import AllowanceSchema, EmissionSchema, InstallationSchema
1414
from transform import clean
15+
from transform.forecast import detect_anomalies, forecast_country_emissions
1516

1617
log = structlog.get_logger(__name__)
1718

@@ -113,3 +114,57 @@ def run_dbt(command: str) -> None:
113114
if result.returncode != 0:
114115
log.error("dbt_failed", stderr=result.stderr[-2000:])
115116
raise RuntimeError(f"dbt {command} failed: {result.stderr[-500:]}")
117+
118+
119+
@task(retries=2, retry_delay_seconds=30, tags=["transform", "ml"])
120+
def build_forecast_and_anomalies() -> tuple[int, int]:
121+
"""Read mart_country_emissions, fit per-country forecasts, flag anomalies.
122+
123+
Writes results to ``mart.mart_emissions_forecast`` and
124+
``mart.mart_emissions_anomalies``. Returns ``(forecast_rows, anomaly_rows)``.
125+
"""
126+
with loaders.get_conn() as conn, conn.cursor() as cur:
127+
cur.execute(
128+
"select country_code, year, total_emissions_tonnes from mart.mart_country_emissions"
129+
)
130+
rows = cur.fetchall()
131+
history = pd.DataFrame(rows, columns=["country_code", "year", "total_emissions_tonnes"])
132+
log.info("forecast_input", rows=len(history))
133+
134+
forecast = forecast_country_emissions(history)
135+
anomalies = detect_anomalies(history)
136+
137+
with loaders.get_conn() as conn:
138+
loaders.truncate(conn, "mart.mart_emissions_forecast")
139+
loaders.truncate(conn, "mart.mart_emissions_anomalies")
140+
if not forecast.empty:
141+
loaders.copy_dataframe(
142+
conn,
143+
forecast,
144+
"mart.mart_emissions_forecast",
145+
["country_code", "year", "forecast_tonnes", "lower_band", "upper_band", "model"],
146+
source_file="forecast_task",
147+
)
148+
if not anomalies.empty:
149+
with conn.cursor() as cur:
150+
cur.executemany(
151+
"INSERT INTO mart.mart_emissions_anomalies "
152+
"(country_code, year, total_emissions_tonnes, yoy_pct, z_score, severity) "
153+
"VALUES (%s, %s, %s, %s, %s, %s)",
154+
list(
155+
anomalies[
156+
[
157+
"country_code",
158+
"year",
159+
"total_emissions_tonnes",
160+
"yoy_pct",
161+
"z_score",
162+
"severity",
163+
]
164+
].itertuples(index=False, name=None)
165+
),
166+
)
167+
conn.commit()
168+
169+
log.info("forecast_done", forecast_rows=len(forecast), anomaly_rows=len(anomalies))
170+
return len(forecast), len(anomalies)

tests/test_forecast.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Tests for the forecasting and anomaly detection module."""
2+
3+
from __future__ import annotations
4+
5+
import pandas as pd
6+
7+
from transform.forecast import ForecastConfig, detect_anomalies, forecast_country_emissions
8+
9+
10+
def _history(country: str, values: list[float], start: int = 2015) -> pd.DataFrame:
11+
return pd.DataFrame(
12+
{
13+
"country_code": [country] * len(values),
14+
"year": list(range(start, start + len(values))),
15+
"total_emissions_tonnes": values,
16+
}
17+
)
18+
19+
20+
def test_forecast_extrapolates_linear_trend() -> None:
21+
df = _history("DE", [100.0, 110.0, 120.0, 130.0, 140.0])
22+
out = forecast_country_emissions(df, ForecastConfig(horizon_years=3))
23+
assert len(out) == 3
24+
# OLS slope is 10/year — first forecast year should be ~150.
25+
assert abs(out.iloc[0]["forecast_tonnes"] - 150.0) < 1e-6
26+
assert (out["upper_band"] >= out["forecast_tonnes"]).all()
27+
assert (out["lower_band"] <= out["forecast_tonnes"]).all()
28+
29+
30+
def test_forecast_skips_short_history() -> None:
31+
df = _history("FR", [100.0, 110.0])
32+
out = forecast_country_emissions(df, ForecastConfig(min_history_years=4))
33+
assert out.empty
34+
35+
36+
def test_forecast_handles_multiple_countries() -> None:
37+
df = pd.concat(
38+
[
39+
_history("DE", [100.0, 110.0, 120.0, 130.0, 140.0]),
40+
_history("FR", [200.0, 195.0, 190.0, 185.0, 180.0]),
41+
]
42+
)
43+
out = forecast_country_emissions(df, ForecastConfig(horizon_years=2))
44+
assert set(out["country_code"]) == {"DE", "FR"}
45+
assert len(out) == 4
46+
47+
48+
def test_anomaly_detects_large_spike() -> None:
49+
df = _history("DE", [100.0, 102.0, 101.0, 103.0, 102.0, 500.0, 105.0])
50+
out = detect_anomalies(df, ForecastConfig(anomaly_z_threshold=2.0))
51+
assert not out.empty
52+
assert 2020 in out["year"].tolist()
53+
assert out["severity"].iloc[0] in {"warning", "critical"}
54+
55+
56+
def test_anomaly_returns_empty_for_steady_series() -> None:
57+
df = _history("DE", [100.0, 101.0, 102.0, 103.0, 104.0, 105.0])
58+
out = detect_anomalies(df)
59+
assert out.empty

transform/forecast.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""Forecasting and anomaly detection on country-level emissions.
2+
3+
Simple, dependency-free approach designed to run inside the Prefect flow
4+
*after* dbt has built ``mart.mart_country_emissions``:
5+
6+
* **Forecast:** ordinary least squares on (year, total_emissions) per country,
7+
projected ``horizon`` years forward, with a +/- 1.96 * residual-stderr band.
8+
* **Anomalies:** rolling z-score (window=3) on year-over-year change; rows with
9+
``|z| > 2.5`` are flagged.
10+
11+
Both outputs are written to dedicated mart tables so the dashboard can read
12+
them with the same cached ``query()`` helper as everything else.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
from dataclasses import dataclass
18+
19+
import numpy as np
20+
import pandas as pd
21+
import structlog
22+
23+
log = structlog.get_logger(__name__)
24+
25+
26+
@dataclass(frozen=True)
27+
class ForecastConfig:
28+
"""Configuration for the forecast task."""
29+
30+
horizon_years: int = 5
31+
min_history_years: int = 4
32+
anomaly_z_threshold: float = 2.5
33+
anomaly_window: int = 3
34+
35+
36+
def _fit_linear(years: np.ndarray, values: np.ndarray) -> tuple[float, float, float]:
37+
"""Return (slope, intercept, residual_stderr) from a 1-D OLS fit."""
38+
n = len(years)
39+
if n < 2:
40+
return 0.0, float(values.mean() if n else 0.0), 0.0
41+
x_mean = years.mean()
42+
y_mean = values.mean()
43+
denom = ((years - x_mean) ** 2).sum()
44+
if denom == 0:
45+
return 0.0, float(y_mean), 0.0
46+
slope = float(((years - x_mean) * (values - y_mean)).sum() / denom)
47+
intercept = float(y_mean - slope * x_mean)
48+
fitted = slope * years + intercept
49+
residuals = values - fitted
50+
dof = max(n - 2, 1)
51+
stderr = float(np.sqrt((residuals**2).sum() / dof))
52+
return slope, intercept, stderr
53+
54+
55+
def forecast_country_emissions(
56+
history: pd.DataFrame, config: ForecastConfig | None = None
57+
) -> pd.DataFrame:
58+
"""Project emissions ``horizon_years`` forward for every country.
59+
60+
Args:
61+
history: Output of ``mart.mart_country_emissions`` with columns
62+
``country_code``, ``year``, ``total_emissions_tonnes``.
63+
config: Forecast configuration; defaults to ``ForecastConfig()``.
64+
65+
Returns:
66+
DataFrame with columns ``country_code``, ``year``, ``forecast_tonnes``,
67+
``lower_band``, ``upper_band``, ``model``.
68+
"""
69+
cfg = config or ForecastConfig()
70+
out_rows: list[dict[str, object]] = []
71+
72+
for country, group in history.groupby("country_code", sort=True):
73+
g = group.sort_values("year")
74+
if len(g) < cfg.min_history_years:
75+
log.warning("forecast_skipped_insufficient_history", country=country, years=len(g))
76+
continue
77+
years = g["year"].to_numpy(dtype=float)
78+
values = g["total_emissions_tonnes"].to_numpy(dtype=float)
79+
slope, intercept, stderr = _fit_linear(years, values)
80+
81+
last_year = int(years.max())
82+
for h in range(1, cfg.horizon_years + 1):
83+
yr = last_year + h
84+
point = slope * yr + intercept
85+
band = 1.96 * stderr * np.sqrt(h) # widening band
86+
out_rows.append(
87+
{
88+
"country_code": country,
89+
"year": yr,
90+
"forecast_tonnes": max(0.0, point),
91+
"lower_band": max(0.0, point - band),
92+
"upper_band": max(0.0, point + band),
93+
"model": "ols_linear",
94+
}
95+
)
96+
97+
return pd.DataFrame(out_rows)
98+
99+
100+
def detect_anomalies(history: pd.DataFrame, config: ForecastConfig | None = None) -> pd.DataFrame:
101+
"""Flag country-years whose YoY change is a rolling z-score outlier.
102+
103+
Args:
104+
history: ``mart.mart_country_emissions`` rows.
105+
config: Anomaly window + threshold.
106+
107+
Returns:
108+
DataFrame containing only the anomalous rows, with the computed
109+
``yoy_pct``, ``z_score``, and a string ``severity`` label.
110+
"""
111+
cfg = config or ForecastConfig()
112+
df = history.sort_values(["country_code", "year"]).copy()
113+
df["yoy_pct"] = df.groupby("country_code")["total_emissions_tonnes"].pct_change() * 100.0
114+
115+
# Compute the rolling baseline from prior rows only (shift by 1) so that
116+
# an anomalous point does not contaminate its own z-score.
117+
grouped = df.groupby("country_code")["yoy_pct"]
118+
prior = grouped.shift(1)
119+
rolling_mean = prior.groupby(df["country_code"]).transform(
120+
lambda s: s.rolling(cfg.anomaly_window, min_periods=2).mean()
121+
)
122+
rolling_std = prior.groupby(df["country_code"]).transform(
123+
lambda s: s.rolling(cfg.anomaly_window, min_periods=2).std()
124+
)
125+
df["z_score"] = (df["yoy_pct"] - rolling_mean) / rolling_std.replace(0, np.nan)
126+
127+
anomalies = df[df["z_score"].abs() > cfg.anomaly_z_threshold].copy()
128+
anomalies["severity"] = np.where(anomalies["z_score"].abs() > 4.0, "critical", "warning")
129+
return anomalies[
130+
["country_code", "year", "total_emissions_tonnes", "yoy_pct", "z_score", "severity"]
131+
].reset_index(drop=True)

0 commit comments

Comments
 (0)