
Commit 768bda2 ("PR review", 1 parent: 9b2060b)

15 files changed (+216, −218 lines)

floracast/README.md (18 additions, 20 deletions)

@@ -97,7 +97,7 @@ python run.py --config configs/inference.yaml --pipeline inference
 ```
 
 3. **View results**:
-   - Check `outputs/forecast_inference.csv` for predictions
+   - Check the predictions artifact for predictions
   - Use ZenML dashboard to view artifacts and metrics
 
 ## ⚙️ Configuration Files
@@ -118,35 +118,33 @@ Edit the appropriate config file to customize:
 - **Evaluation**: Forecasting horizon, metrics
 - **Output**: File paths and formats
 
-
-
 ```
 floracast/
 ├── README.md
 ├── requirements.txt
 ├── .env.example
 ├── configs/
-│   ├── training.yaml          # Training pipeline config
-│   └── inference.yaml         # Inference pipeline config
+│   ├── training.yaml              # Training pipeline config
+│   └── inference.yaml             # Inference pipeline config
 ├── data/
-│   └── ecommerce_daily.csv    # Generated sample data
+│   └── ecommerce_daily.csv        # Example input data
 ├── pipelines/
-│   ├── train_forecast_pipeline.py
-│   └── batch_inference_pipeline.py
+│   ├── train_forecast_pipeline.py # Training pipeline definition
+│   └── batch_inference_pipeline.py # Batch inference pipeline definition
 ├── steps/
-│   ├── ingest.py              # Data loading
-│   ├── preprocess.py          # Time series preprocessing
-│   ├── train.py               # Model training
-│   ├── evaluate.py            # Model evaluation
-│   ├── promote.py             # Model registration
-│   ├── batch_infer.py         # Batch predictions
-│   └── load_model.py          # Model loading utilities
+│   ├── ingest.py                  # Data ingestion step
+│   ├── preprocess.py              # Preprocessing step (train/val split, scaling)
+│   ├── train.py                   # Model training step
+│   ├── evaluate.py                # Model evaluation step
+│   ├── promote.py                 # Model registration/promotion step
+│   ├── batch_infer.py             # Batch inference step
+│   └── load_model.py              # Model loading utilities
 ├── materializers/
-│   ├── tft_materializer.py    # Custom TFTModel materializer
-|   └── timeseries_materializer.py # Custom TimeSeries materializer
+│   ├── tft_materializer.py        # Custom TFTModel materializer
+│   └── timeseries_materializer.py # Custom TimeSeries materializer
 ├── utils/
-│   └── metrics.py             # Forecasting metrics
-└── run.py                     # Main entry point
+│   └── metrics.py                 # Forecasting metrics (e.g., SMAPE)
+└── run.py                         # CLI entry point for running pipelines
 ```
 
 ### Key Components
@@ -199,7 +197,7 @@ Read more:
 
 - **Set up an MLOps stack on Azure**: [ZenML Azure guide](https://docs.zenml.io/stacks/popular-stacks/azure-guide)
 - **Kubernetes Orchestrator (AKS)**: [Docs](https://docs.zenml.io/stacks/stack-components/orchestrators/kubernetes)
-- **Azure Blob Artifact Store**: [Docs](https://docs.zenml.io/stacks/stack-components/artifact-stores/azuree)
+- **Azure Blob Artifact Store**: [Docs](https://docs.zenml.io/stacks/stack-components/artifact-stores/azure)
 - **Azure Container Registry**: [Docs](https://docs.zenml.io/stacks/stack-components/container-registries/azure)
 - **AzureML Step Operator**: [Docs](https://docs.zenml.io/stacks/stack-components/step-operators/azureml)
 - **Terraform stack recipe for Azure**: [Hashicorp Registry](https://registry.terraform.io/modules/zenml-io/zenml-stack/azure/latest)
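The README now directs users to the `predictions` output artifact instead of a CSV on disk. As a minimal sketch of what "check the predictions artifact" can look like in practice, assuming the default pipeline name matches the function name (client accessors differ slightly across ZenML versions):

```python
# Sketch: load the latest "predictions" artifact produced by the batch
# inference pipeline via the ZenML client. Pipeline, step, and artifact names
# come from this commit; the accessor shapes vary across ZenML versions.
from zenml.client import Client

run = Client().get_pipeline("batch_inference_pipeline").last_run
output = run.steps["batch_inference_predict"].outputs["predictions"]
# Newer ZenML releases return a list of artifact versions here, older ones a
# single artifact version; handle both.
artifact = output[-1] if isinstance(output, list) else output
pred_df = artifact.load()  # pandas DataFrame with the forecast
print(pred_df.head())
```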

floracast/configs/inference.yaml (1 addition, 4 deletions)

@@ -18,12 +18,9 @@ steps:
       datetime_col: "ds"
       target_col: "y"
 
-  preprocess_for_inference:
+  batch_inference_predict:
     parameters:
       datetime_col: "ds"
       target_col: "y"
       freq: "D"
-
-  batch_inference_predict:
-    parameters:
       horizon: 14

floracast/configs/training.yaml (2 additions, 3 deletions)

@@ -32,7 +32,7 @@ steps:
       datetime_col: "ds"
       target_col: "y"
 
-  preprocess_for_training:
+  preprocess_data:
     parameters:
       datetime_col: "ds"
       target_col: "y"
@@ -41,7 +41,6 @@ steps:
 
   train_model:
     parameters:
-      model_name: "TFTModel"
       input_chunk_length: 90
       output_chunk_length: 14
       hidden_size: 256
@@ -57,5 +56,5 @@ steps:
 
   evaluate:
     parameters:
-      horizon: 7
+      horizon: 14
       metric: "smape"
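The evaluation horizon is bumped from 7 to 14 so it matches the model's `output_chunk_length`, meaning the whole evaluation window is covered by direct model output. For reference, a minimal SMAPE along the lines of what `utils/metrics.py` provides (illustrative sketch, not the repo's code):

```python
import numpy as np


def smape(actual: np.ndarray, forecast: np.ndarray) -> float:
    """Symmetric mean absolute percentage error, in percent (0-200 range)."""
    actual = np.asarray(actual, dtype=float)
    forecast = np.asarray(forecast, dtype=float)
    denom = (np.abs(actual) + np.abs(forecast)) / 2.0
    # Avoid division by zero when both actual and forecast are exactly 0.
    ratio = np.where(denom == 0.0, 0.0, np.abs(forecast - actual) / denom)
    return float(100.0 * np.mean(ratio))


print(smape(np.array([10, 12, 14]), np.array([11, 12, 13])))  # ≈ 5.6
```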

floracast/materializers/tft_materializer.py (2 additions, 2 deletions)

@@ -168,12 +168,12 @@ def _load_with_pytorch_state(self) -> Any:
         )
 
         dates = pd.date_range("2020-01-01", periods=dummy_length, freq="D")
-        values = np.random.randn(dummy_length)
+        values = np.random.randn(dummy_length).astype(np.float32)
         dummy_series = TimeSeries.from_dataframe(
             pd.DataFrame({"ds": dates, "y": values}),
             time_col="ds",
             value_cols="y",
-        )
+        ).astype(np.float32)
 
         # Partially fit to create the internal model structure
         temp_model.fit(dummy_series, epochs=1, verbose=False)
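The dummy series used to rebuild the model's internal structure is now pinned to float32; without the cast it would default to float64 (the dtype of `np.random.randn`) and no longer match the dtype of the training data. A standalone check of that default (illustrative, not repo code):

```python
import numpy as np
import pandas as pd
from darts import TimeSeries

dates = pd.date_range("2020-01-01", periods=30, freq="D")
ts = TimeSeries.from_dataframe(
    pd.DataFrame({"ds": dates, "y": np.random.randn(30)}),  # randn -> float64
    time_col="ds",
    value_cols="y",
)
print(ts.dtype)                     # float64 without an explicit cast
print(ts.astype(np.float32).dtype)  # float32, matching the training data
```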

floracast/pipelines/batch_inference_pipeline.py (2 additions, 6 deletions)

@@ -6,7 +6,6 @@
 from zenml.logger import get_logger
 
 from steps.ingest import ingest_data
-from steps.preprocess import preprocess_for_inference
 from steps.batch_infer import batch_inference_predict
 
 logger = get_logger(__name__)
@@ -22,10 +21,7 @@ def batch_inference_pipeline() -> None:
     # Step 1: Ingest data (simulate real-time data sources)
     raw_data = ingest_data(infer=True)
 
-    # Step 2: Preprocess data (use full series for inference context)
-    inference_series = preprocess_for_inference(df=raw_data)
-
-    # Step 3: Generate predictions using model from MCP
-    batch_inference_predict(series=inference_series)
+    # Step 2: Generate predictions using model from MCP (with scaling handled internally)
+    batch_inference_predict(df=raw_data)
 
     logger.info("Batch inference completed. Returning predictions DataFrame.")

floracast/pipelines/train_forecast_pipeline.py (2 additions, 2 deletions)

@@ -6,7 +6,7 @@
 from zenml.logger import get_logger
 
 from steps.ingest import ingest_data
-from steps.preprocess import preprocess_for_training
+from steps.preprocess import preprocess_data
 from steps.train import train_model
 from steps.evaluate import evaluate
 from steps.promote import promote_model
@@ -24,7 +24,7 @@ def train_forecast_pipeline() -> None:
     raw_data = ingest_data()
 
     # Step 2: Preprocess data into Darts TimeSeries with train/val split
-    train_series, val_series = preprocess_for_training(df=raw_data)
+    train_series, val_series, _ = preprocess_data(df=raw_data)
 
     # Step 3: Train the forecasting model
     trained_model = train_model(train_series=train_series)
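`preprocess_data` now returns a third value that the training pipeline discards here; judging by the `fitted_scaler` model artifact consumed in `steps/batch_infer.py`, it is presumably the fitted scaler. A hypothetical sketch of the step signature implied by this call site (parameter names and body are assumptions, not the repo's actual implementation):

```python
from typing import Annotated, Tuple

import pandas as pd
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from zenml import step


@step
def preprocess_data(
    df: pd.DataFrame,
    datetime_col: str = "ds",
    target_col: str = "y",
    freq: str = "D",
    train_fraction: float = 0.8,  # hypothetical parameter name
) -> Tuple[
    Annotated[TimeSeries, "train_series"],
    Annotated[TimeSeries, "val_series"],
    Annotated[Scaler, "fitted_scaler"],
]:
    """Split the raw frame into scaled train/val series plus the fitted scaler."""
    series = TimeSeries.from_dataframe(
        df, time_col=datetime_col, value_cols=target_col, freq=freq
    ).astype("float32")
    train, val = series.split_after(train_fraction)
    scaler = Scaler()  # fit on the training split only
    train_scaled = scaler.fit_transform(train)
    val_scaled = scaler.transform(val)
    return train_scaled, val_scaled, scaler
```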

floracast/requirements.txt (1 addition, 2 deletions)

@@ -19,5 +19,4 @@ azure-identity>=1.16.0,<2.0.0
 azure-storage-blob>=12.20.0,<13.0.0
 
 # Utilities
-matplotlib>=3.7.0,<4.0.0
-seaborn>=0.12.0,<0.14.0
+matplotlib>=3.7.0,<4.0.0

floracast/run.py (0 additions, 1 deletion)

@@ -3,7 +3,6 @@
 """
 
 import click
-from datetime import datetime
 from pathlib import Path
 from pipelines import batch_inference_pipeline, train_forecast_pipeline
 from zenml.logger import get_logger

floracast/steps/__init__.py (2 additions, 3 deletions)

@@ -1,16 +1,15 @@
 """ZenML pipeline steps for FloraCast."""
 
 from .ingest import ingest_data
-from .preprocess import preprocess_for_training, preprocess_for_inference
+from .preprocess import preprocess_data
 from .train import train_model
 from .evaluate import evaluate
 from .promote import promote_model
 from .batch_infer import batch_inference_predict
 
 __all__ = [
     "ingest_data",
-    "preprocess_for_training",
-    "preprocess_for_inference",
+    "preprocess_data",
     "train_model",
     "evaluate",
     "promote_model",

floracast/steps/batch_infer.py (63 additions, 33 deletions)

@@ -4,24 +4,32 @@
 
 from typing import Annotated
 import pandas as pd
+import numpy as np
 from darts import TimeSeries
 from zenml import step, get_step_context, log_metadata
 from zenml.logger import get_logger
 from zenml.client import Client
+from utils.prediction import iterative_predict
 
 logger = get_logger(__name__)
 
 
 @step
 def batch_inference_predict(
-    series: TimeSeries,
+    df: pd.DataFrame,
+    datetime_col: str = "ds",
+    target_col: str = "y",
+    freq: str = "D",
     horizon: int = 14,
 ) -> Annotated[pd.DataFrame, "predictions"]:
     """
     Perform batch inference using the trained model from Model Control Plane.
 
     Args:
-        series: Time series data for forecasting
+        df: Raw DataFrame with datetime and target columns
+        datetime_col: Name of datetime column
+        target_col: Name of target column
+        freq: Frequency string for time series
        horizon: Number of time steps to forecast
 
     Returns:
@@ -30,6 +38,21 @@ def batch_inference_predict(
     logger.info(f"Performing batch inference with horizon: {horizon}")
 
     try:
+        # Convert DataFrame to TimeSeries
+        logger.info("Converting DataFrame to TimeSeries")
+        series = TimeSeries.from_dataframe(
+            df, time_col=datetime_col, value_cols=target_col, freq=freq
+        )
+
+        # Cast to float32 for consistency with training data
+        logger.info("Converting TimeSeries to float32 for consistency")
+        series = series.astype(np.float32)
+
+        logger.info(f"Created TimeSeries with {len(series)} points")
+        logger.info(
+            f"Series range: {series.start_time()} to {series.end_time()}"
+        )
+
         # Get the model from Model Control Plane
         context = get_step_context()
         if not context.model:
@@ -72,41 +95,48 @@ def batch_inference_predict(
             f"Loaded model from Model Control Plane: {type(trained_model).__name__}"
         )
 
-        # Generate predictions using improved multi-step approach (same as evaluation)
-        logger.info(
-            f"Using iterative multi-step prediction for horizon={horizon}"
-        )
-
-        # Use multiple prediction steps for better long-term accuracy
-        predictions_list = []
-        context_series = series
-
-        # Predict in chunks of output_chunk_length (14 days)
-        remaining_steps = horizon
-        while remaining_steps > 0:
-            chunk_size = min(
-                14, remaining_steps
-            )  # Model's output_chunk_length
-            chunk_pred = trained_model.predict(
-                n=chunk_size, series=context_series
+        # Load the fitted scaler artifact
+        fitted_scaler = None
+        try:
+            scaler_artifact = context.model.get_artifact("fitted_scaler")
+            if scaler_artifact is None:
+                raise ValueError(
+                    "fitted_scaler artifact not found in model version"
+                )
+            fitted_scaler = scaler_artifact.load()
+            logger.info("Loaded fitted scaler artifact from model version")
+
+            # Apply scaling to the input series
+            logger.info("Applying scaling to input series for inference")
+            series = fitted_scaler.transform(series)
+            logger.info("Scaling applied successfully")
+        except Exception as scaler_error:
+            logger.error(f"Failed to load or apply scaler: {scaler_error}")
+            logger.warning(
+                "Proceeding without scaling - predictions may be incorrect!"
             )
-            predictions_list.append(chunk_pred)
+            # Continue without scaling for backward compatibility
 
-            # Extend context with the prediction for next iteration
-            context_series = context_series.concatenate(chunk_pred)
-            remaining_steps -= chunk_size
+        # Generate predictions using improved multi-step approach
+        predictions = iterative_predict(trained_model, series, horizon)
 
-        # Combine all predictions
-        if len(predictions_list) == 1:
-            predictions = predictions_list[0]
+        # Inverse transform predictions back to original scale
+        if fitted_scaler is not None:
+            try:
+                logger.info(
+                    "Inverse transforming predictions back to original scale"
+                )
+                predictions = fitted_scaler.inverse_transform(predictions)
+                logger.info("Inverse transformation applied successfully")
+            except Exception as inverse_error:
+                logger.error(
+                    f"Failed to inverse transform predictions: {inverse_error}"
                )
+                logger.warning("Predictions remain in scaled format!")
         else:
-            predictions = predictions_list[0]
-            for pred_chunk in predictions_list[1:]:
-                predictions = predictions.concatenate(pred_chunk)
-
-        logger.info(
-            f"Generated {len(predictions)} predictions using multi-step approach"
-        )
+            logger.warning(
+                "No scaler available - predictions remain in original format"
            )
 
         # Convert to DataFrame
         pred_df = predictions.pd_dataframe().reset_index()
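The chunked prediction loop removed above now lives behind `iterative_predict` in `utils/prediction.py`, which this commit imports but does not show. Reconstructed from the deleted lines, the helper plausibly looks like this (details such as where `output_chunk_length` comes from are assumptions):

```python
from typing import Optional

from darts import TimeSeries


def iterative_predict(
    model,
    series: TimeSeries,
    horizon: int,
    output_chunk_length: int = 14,  # assumed; may be read from the model instead
) -> TimeSeries:
    """Forecast `horizon` steps by repeatedly predicting up to
    `output_chunk_length` steps and feeding each chunk back in as context.

    Mirrors the loop removed from batch_infer.py; the real utils/prediction.py
    may differ in details.
    """
    predictions: Optional[TimeSeries] = None
    context_series = series
    remaining_steps = horizon

    while remaining_steps > 0:
        chunk_size = min(output_chunk_length, remaining_steps)
        chunk_pred = model.predict(n=chunk_size, series=context_series)

        # Accumulate the forecast and extend the context for the next chunk.
        predictions = (
            chunk_pred
            if predictions is None
            else predictions.concatenate(chunk_pred)
        )
        context_series = context_series.concatenate(chunk_pred)
        remaining_steps -= chunk_size

    return predictions
```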
